[MLIR][MaterializeVectors] Add a MaterializeVector pass via unrolling.
This CL adds an MLIR-MLIR pass which materializes super-vectors to
hardware-dependent sized vectors.

While the physical vector size is target-dependent, the pass is written in
a target-independent way: the target vector size is specified as a parameter
to the pass. This pass is thus a partial lowering that opens the "greybox"
that is the super-vector abstraction.

This CL adds a first materialization pass that iterates over
vector_transfer_write operations and:
  1. computes the program slice including the current vector_transfer_write;
  2. computes the multi-dimensional ratio of super-vector shape to hardware
     vector shape;
  3. for each possible multi-dimensional value within the bounds of ratio,
     instantiates a new slice (i.e. clones and rewrites it) so that all
     operations in this instance operate on the hardware vector type.

As a simple example, given:
```mlir
mlfunc @vector_add_2d(%M : index, %N : index) -> memref<?x?xf32> {
  %A = alloc (%M, %N) : memref<?x?xf32>
  %B = alloc (%M, %N) : memref<?x?xf32>
  %C = alloc (%M, %N) : memref<?x?xf32>
  for %i0 = 0 to %M {
    for %i1 = 0 to %N {
      %a1 = load %A[%i0, %i1] : memref<?x?xf32>
      %b1 = load %B[%i0, %i1] : memref<?x?xf32>
      %s1 = addf %a1, %b1 : f32
      store %s1, %C[%i0, %i1] : memref<?x?xf32>
    }
  }
  return %C : memref<?x?xf32>
}
```

and the following options:
```
-vectorize -virtual-vector-size 32 --test-fastest-varying=0 -materialize-vectors -vector-size=8
```

materialization emits:
```mlir
#map0 = (d0, d1) -> (d0, d1)
#map1 = (d0, d1) -> (d0, d1 + 8)
#map2 = (d0, d1) -> (d0, d1 + 16)
#map3 = (d0, d1) -> (d0, d1 + 24)
mlfunc @vector_add_2d(%arg0 : index, %arg1 : index) -> memref<?x?xf32> {
  %0 = alloc(%arg0, %arg1) : memref<?x?xf32>
  %1 = alloc(%arg0, %arg1) : memref<?x?xf32>
  %2 = alloc(%arg0, %arg1) : memref<?x?xf32>
  for %i0 = 0 to %arg0 {
    for %i1 = 0 to %arg1 step 32 {
      %3 = affine_apply #map0(%i0, %i1)
      %4 = "vector_transfer_read"(%0, %3#0, %3#1) : (memref<?x?xf32>, index, index) -> vector<8xf32>
      %5 = affine_apply #map1(%i0, %i1)
      %6 = "vector_transfer_read"(%0, %5#0, %5#1) : (memref<?x?xf32>, index, index) -> vector<8xf32>
      %7 = affine_apply #map2(%i0, %i1)
      %8 = "vector_transfer_read"(%0, %7#0, %7#1) : (memref<?x?xf32>, index, index) -> vector<8xf32>
      %9 = affine_apply #map3(%i0, %i1)
      %10 = "vector_transfer_read"(%0, %9#0, %9#1) : (memref<?x?xf32>, index, index) -> vector<8xf32>
      %11 = affine_apply #map0(%i0, %i1)
      %12 = "vector_transfer_read"(%1, %11#0, %11#1) : (memref<?x?xf32>, index, index) -> vector<8xf32>
      %13 = affine_apply #map1(%i0, %i1)
      %14 = "vector_transfer_read"(%1, %13#0, %13#1) : (memref<?x?xf32>, index, index) -> vector<8xf32>
      %15 = affine_apply #map2(%i0, %i1)
      %16 = "vector_transfer_read"(%1, %15#0, %15#1) : (memref<?x?xf32>, index, index) -> vector<8xf32>
      %17 = affine_apply #map3(%i0, %i1)
      %18 = "vector_transfer_read"(%1, %17#0, %17#1) : (memref<?x?xf32>, index, index) -> vector<8xf32>
      %19 = addf %4, %12 : vector<8xf32>
      %20 = addf %6, %14 : vector<8xf32>
      %21 = addf %8, %16 : vector<8xf32>
      %22 = addf %10, %18 : vector<8xf32>
      %23 = affine_apply #map0(%i0, %i1)
      "vector_transfer_write"(%19, %2, %23#0, %23#1) : (vector<8xf32>, memref<?x?xf32>, index, index) -> ()
      %24 = affine_apply #map1(%i0, %i1)
      "vector_transfer_write"(%20, %2, %24#0, %24#1) : (vector<8xf32>, memref<?x?xf32>, index, index) -> ()
      %25 = affine_apply #map2(%i0, %i1)
      "vector_transfer_write"(%21, %2, %25#0, %25#1) : (vector<8xf32>, memref<?x?xf32>, index, index) -> ()
      %26 = affine_apply #map3(%i0, %i1)
      "vector_transfer_write"(%22, %2, %26#0, %26#1) : (vector<8xf32>, memref<?x?xf32>, index, index) -> ()
    }
  }
  return %2 : memref<?x?xf32>
}
```

PiperOrigin-RevId: 222455351
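
Steps 2 and 3 of the pass description can be illustrated with a small Python sketch. It is not the pass's actual implementation: `shape_ratio` and `instance_offsets` are hypothetical helper names, and aligning the hardware vector shape on the trailing dimensions of the super-vector is an assumption made for illustration. The sketch only computes the per-instance index offsets that show up as `#map0` to `#map3` in the example; it does not clone or rewrite any operations.

```python
from itertools import product


def shape_ratio(super_shape, hw_shape):
    """Per-dimension ratio of super-vector shape to hardware vector shape."""
    if len(hw_shape) > len(super_shape):
        raise ValueError("hardware vector rank exceeds super-vector rank")
    # Assumption: shapes are aligned on trailing dimensions, and any
    # missing leading hardware dimension is treated as size 1.
    padded = [1] * (len(super_shape) - len(hw_shape)) + list(hw_shape)
    if any(s % h != 0 for s, h in zip(super_shape, padded)):
        raise ValueError("super-vector shape is not a multiple of the hardware shape")
    return [s // h for s, h in zip(super_shape, padded)]


def instance_offsets(super_shape, hw_shape):
    """Offsets of every hardware-vector instance inside one super-vector."""
    ratio = shape_ratio(super_shape, hw_shape)
    padded = [1] * (len(super_shape) - len(hw_shape)) + list(hw_shape)
    # Enumerate each multi-dimensional value within the bounds of `ratio`
    # and scale it by the hardware vector shape to get element offsets.
    return [
        tuple(i * h for i, h in zip(index, padded))
        for index in product(*(range(r) for r in ratio))
    ]


# The example above: a 32-wide super-vector on 8-wide hardware vectors.
print(shape_ratio([32], [8]))       # [4]
print(instance_offsets([32], [8]))  # [(0,), (8,), (16,), (24,)]
```

The four offsets correspond to the `d1`, `d1 + 8`, `d1 + 16` and `d1 + 24` affine maps in the materialized output.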