Uncovering Nested Data Parallelism and Data Reuse in DNN Computation with FractalTensor

SOSP

To speed up computation, deep neural networks (DNNs) usually rely on highly optimized tensor operators. Despite their effectiveness, tensor operators are often defined empirically with ad hoc semantics, which hinders analysis and optimization across operator boundaries. FractalTensor is a programming framework that addresses this challenge. At its core, FractalTensor is a nested list-based abstract data type (ADT), where each element is either a tensor with a static shape or another FractalTensor (i.e., nested). DNNs are then defined by applying higher-order compute operators such as map/reduce/scan and data access operators such as window/stride to FractalTensors. Defining DNNs this way explicitly exposes nested data parallelism and fine-grained data access patterns, opening new opportunities for whole-program analysis and optimization. To exploit these opportunities, the compiler extracts from the FractalTensor-based code a nested multi-dimensional dataflow graph called the Extended Task Dependence Graph (ETDG), which provides a holistic view of data dependencies across different granularities. The ETDG is then transformed into an efficient implementation through graph coarsening, data reordering, and access materialization. Evaluation on six representative DNNs, such as RNN and FlashAttention, on an NVIDIA A100 shows that FractalTensor achieves speedups of up to 5.44x, and 1.97x on average, through a unified solution for diverse optimizations.
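To make the programming model concrete, the following is a minimal sketch in plain Python of how a stacked RNN might be written as nested map/reduce/scan over a nested list. It is not the FractalTensor API: the helper names `fmap`, `fscan`, and `window`, the scalar stand-in for a static-shape tensor, and the toy `cell` with fixed weights are all assumptions made for illustration.

```python
from functools import reduce
from itertools import accumulate

# Hypothetical helpers for illustration only; not the FractalTensor API.

def fmap(f, xs):
    """Apply f to every element independently (data-parallel)."""
    return [f(x) for x in xs]

def fscan(f, xs, init):
    """Running fold: return every intermediate state, excluding init."""
    return list(accumulate(xs, f, initial=init))[1:]

def window(xs, size, stride=1):
    """Access operator: overlapping views of xs of a given size."""
    return [xs[i:i + size] for i in range(0, len(xs) - size + 1, stride)]

def cell(h, x, w=0.5, u=0.5):
    """Toy recurrent cell; a float stands in for a static-shape tensor."""
    return w * x + u * h

def rnn_layer(seq):
    # scan over time steps: each state depends on the previous one
    return fscan(cell, seq, 0.0)

def stacked_rnn(batch, depth):
    # map over the batch (independent sequences),
    # reduce over depth (layer d feeds layer d + 1),
    # scan over time inside each layer
    def run(seq):
        return reduce(lambda s, _: rnn_layer(s), range(depth), seq)
    return fmap(run, batch)

# A nested list: batch -> sequence -> element. Sequences may differ in length.
batch = [[1.0, 2.0, 3.0], [4.0, 0.0]]
print(stacked_rnn(batch, depth=1))
print(window([1, 2, 3, 4, 5], size=3, stride=2))
```

In this sketch the three loop dimensions (batch, depth, time) stay visible as separate operators rather than being hidden inside one opaque RNN kernel, which is the kind of structure the abstract describes the compiler as analyzing.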