
Enable running softmax with TPPs #584

Open
chelini opened this issue Jun 12, 2023 · 3 comments
Labels
low-priority Things that go in the back burner


@chelini
Contributor

chelini commented Jun 12, 2023

To enable running softmax with TPPs we need more operations:

  1. max/sum reduce op (%2 and %8)
  2. sub operation (%5). We need to support broadcast semantics in sub or implement an explicit broadcast op (see %4).
  3. exp (%6)
  4. div (%10)

The IR below shows a Softmax example in Linalg, extracted from a self-attention layer. The lowering path is: TF dialect -> StableHLO -> Linalg IR. To lower from the TF dialect to StableHLO we use tf-opt, while from StableHLO to Linalg we use the IREE compiler and print after iree-stablehlo-to-iree-input.

The dimensions of %arg0 are [B, heads, T, T], where B is the batch dimension, heads is the number of heads, and T is the sequence length.
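For orientation, the IR computes the numerically stable softmax along the last dimension; the intermediate SSA values map onto the scalar recipe roughly as follows (a sketch for readability, not part of the generated IR):

  m[b,h,t]     = max_j x[b,h,t,j]        // %2  (max reduce)
  y[b,h,t,j]   = x[b,h,t,j] - m[b,h,t]   // %5  (sub, with %4 broadcasting the expanded max)
  e[b,h,t,j]   = exp(y[b,h,t,j])         // %6  (exp)
  s[b,h,t]     = sum_j e[b,h,t,j]        // %8  (sum reduce)
  out[b,h,t,j] = e[b,h,t,j] / s[b,h,t]   // %10 (div, with %9 broadcasting the expanded sum)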

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, 0)>
module {
  func.func @softmax(%arg0: tensor<64x8x5x5xf32>) -> tensor<64x8x5x5xf32> {
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant 0xFF800000 : f32
    %0 = tensor.empty() : tensor<64x8x5xf32>
    %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<64x8x5xf32>) -> tensor<64x8x5xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%arg0 : tensor<64x8x5x5xf32>) outs(%1 : tensor<64x8x5xf32>) {
    ^bb0(%in: f32, %out: f32):
      %11 = arith.maxf %out, %in : f32
      linalg.yield %11 : f32
    } -> tensor<64x8x5xf32>
    %expanded = tensor.expand_shape %2 [[0], [1], [2, 3]] : tensor<64x8x5xf32> into tensor<64x8x5x1xf32>
    %3 = tensor.empty() : tensor<64x8x5x5xf32>
    %4 = linalg.generic {indexing_maps = [#map2, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded : tensor<64x8x5x1xf32>) outs(%3 : tensor<64x8x5x5xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<64x8x5x5xf32>
    %5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg0, %4 : tensor<64x8x5x5xf32>, tensor<64x8x5x5xf32>) outs(%3 : tensor<64x8x5x5xf32>) {
    ^bb0(%in: f32, %in_2: f32, %out: f32):
      %11 = arith.subf %in, %in_2 : f32
      linalg.yield %11 : f32
    } -> tensor<64x8x5x5xf32>
    %6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%5 : tensor<64x8x5x5xf32>) outs(%3 : tensor<64x8x5x5xf32>) {
    ^bb0(%in: f32, %out: f32):
      %11 = math.exp %in : f32
      linalg.yield %11 : f32
    } -> tensor<64x8x5x5xf32>
    %7 = linalg.fill ins(%cst : f32) outs(%0 : tensor<64x8x5xf32>) -> tensor<64x8x5xf32>
    %8 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%6 : tensor<64x8x5x5xf32>) outs(%7 : tensor<64x8x5xf32>) {
    ^bb0(%in: f32, %out: f32):
      %11 = arith.addf %out, %in : f32
      linalg.yield %11 : f32
    } -> tensor<64x8x5xf32>
    %expanded_1 = tensor.expand_shape %8 [[0], [1], [2, 3]] : tensor<64x8x5xf32> into tensor<64x8x5x1xf32>
    %9 = linalg.generic {indexing_maps = [#map2, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_1 : tensor<64x8x5x1xf32>) outs(%3 : tensor<64x8x5x5xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
    } -> tensor<64x8x5x5xf32>
    %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%6, %9 : tensor<64x8x5x5xf32>, tensor<64x8x5x5xf32>) outs(%3 : tensor<64x8x5x5xf32>) {
    ^bb0(%in: f32, %in_2: f32, %out: f32):
      %11 = arith.divf %in, %in_2 : f32
      linalg.yield %11 : f32
    } -> tensor<64x8x5x5xf32>
    return %10 : tensor<64x8x5x5xf32>
  }
}

Related to #414.

TBD:

  • It looks reasonable to fuse along the batch and head dimensions; this would require extending the fusion pass to detect a softmax and fuse along these dimensions. Today we only anchor fusion around matmul-like operations (i.e., blocked convolution/matmul, or linalg.matmul and named convolutions). The other option is to use Linalg fusion on tensors but limit the pass to the softmax operations, to avoid having to split the body of the generic later on before mapping to TPPs. A sketch of the fused loop structure for the first option is shown below.
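A minimal sketch of the loop structure the first option could produce, assuming we fuse over the batch and head dimensions with scf.forall and run the softmax steps on 2-D slices (the body is elided; this is hand-written for illustration, not the output of an existing pass):

  %empty = tensor.empty() : tensor<64x8x5x5xf32>
  %softmax = scf.forall (%b, %h) in (64, 8) shared_outs(%acc = %empty) -> (tensor<64x8x5x5xf32>) {
    // Rank-reducing slice: one 5x5 tile per (batch, head) pair.
    %slice = tensor.extract_slice %arg0[%b, %h, 0, 0] [1, 1, 5, 5] [1, 1, 1, 1]
        : tensor<64x8x5x5xf32> to tensor<5x5xf32>
    // ... 2-D max reduce, sub, exp, sum reduce, and div on %slice,
    //     producing %res : tensor<5x5xf32> ...
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %res into %acc[%b, %h, 0, 0] [1, 1, 5, 5] [1, 1, 1, 1]
          : tensor<5x5xf32> into tensor<64x8x5x5xf32>
    }
  }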
@chelini chelini mentioned this issue Jun 12, 2023
@rengolin
Contributor

This IR looks similar to what we generate from mlir-gen, but the shapes are weird (5x5), where is that from?

The MHA softmax is a little different, we need both styles covered.

> The other option is to use Linalg fusion on tensors but limit the pass to the softmax operations, to avoid having to split the body of the generic later on before mapping to TPPs.

softmax in libxsmm is lowered as an equation, and just calling the kernels one after another is very close to optimal. I would not create complicated machinery that is specific to certain complex patterns unless the benefit was very large and there was no other way.

Softmax will eventually be lowered as an equation, which is the right way long term, so we can live with most of the performance now and the rest later.

@chelini
Contributor Author

chelini commented Jun 12, 2023

Yes, calling the kernels one after the other would be the plan. Still, we must either fuse along the 64 and 8 dimensions to extract 2-D tensors, or materialize the two outermost dimensions for each Linalg op and replace the body with a TPP operation (rough sketch of the second option below). Do you have an example of the IR generated by mlir-gen? 5 is an arbitrary number for the sequence length; it does not matter in this context.
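For the second option, a rough sketch of what materializing the two outermost dimensions could look like after bufferization; %result is an assumed pre-allocated output buffer, and the tpp.* names in the comment are illustrative placeholders for the missing ops listed at the top of the issue, not ops that exist today:

  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c8 = arith.constant 8 : index
  %c64 = arith.constant 64 : index
  scf.for %b = %c0 to %c64 step %c1 {
    scf.for %h = %c0 to %c8 step %c1 {
      // Rank-reducing views of the 5x5 input and output tiles.
      %in = memref.subview %arg0[%b, %h, 0, 0] [1, 1, 5, 5] [1, 1, 1, 1]
          : memref<64x8x5x5xf32> to memref<5x5xf32, strided<[5, 1], offset: ?>>
      %out = memref.subview %result[%b, %h, 0, 0] [1, 1, 5, 5] [1, 1, 1, 1]
          : memref<64x8x5x5xf32> to memref<5x5xf32, strided<[5, 1], offset: ?>>
      // Placeholder TPP-like ops, one per step of the recipe, e.g.
      // tpp.reduce_max, tpp.sub (broadcasting), tpp.exp, tpp.reduce_add, tpp.div.
    }
  }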

@rengolin
Contributor

rengolin commented Jun 12, 2023

> Do you have an example of the IR generated by mlir-gen?

Yup, just run mlir-gen and you'll see.

Also, just to be clear, this is really low priority. Finding the right shapes for MHA and finally getting TPP on tensors in the main pipeline are still the most important tasks right now.

@rengolin rengolin added the low-priority label Jun 12, 2023