Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce unstructured-to-memref pass #216

Merged
merged 9 commits into from
Jan 15, 2025
Merged

Introduce unstructured-to-memref pass #216

merged 9 commits into from
Jan 15, 2025

Conversation

nhat-nguyen
Copy link
Collaborator

@nhat-nguyen nhat-nguyen commented Jan 7, 2025

This PR introduces the unstructured-to-memref pass responsible for converting unstructured triton load / store ops to memref load / store ops. This is part of the work to allow triton-shared to lower gather / scatter pointer sequences. The pass is intended to be used after running --fold-unstructured-ptr.

Triton load op (gather) is lowered to a linalg.generic whose body contains a load from the offset indicated by the offset provided by tts.make_unstructured_tptr. For load op with mask, an inner-most scf.if is used to return a default value (or the other in tt.load if provided) if the corresponding mask value is false.

Example of a load:

  func.func @gather_simple_mask_with_other(%arg0: memref<*xf32>, %arg1: memref<*xf32>) {
      %cst = arith.constant -1.000000e+00 : f32
      %cast = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
      %load_tensor = bufferization.to_tensor %cast restrict : memref<?xf32>
      %out = tensor.empty() : tensor<64xf32>
      %gather = linalg.generic {
        iterator_types = ["parallel"]
      } ins(%offset_tensor, %mask_tensor : tensor<64xi32>, tensor<64xi1>)
        outs(%out : tensor<64xf32>) {
      ^bb0(%offset: i32, %mask: i1, %out: f32):
        %yield = scf.if %mask -> (f32) {
          %index = arith.index_cast %offset : i32 to index
          %extracted = tensor.extract %load_tensor[%index] : tensor<?xf32>
          scf.yield %extracted : f32
        } else {
          scf.yield %cst : f32
        }
        linalg.yield %yield : f32
      } -> tensor<64xf32>

Triton store op (scatter) is lowered to an affine.for loop nest that stores the value to the appropriate offset provided by tts.make_unstructured_tptr. Store op with mask is also supported.

Example of a store:

  func.func @masked_gather_scatter(%arg0: memref<*xf32>, %arg1: memref<*xf32>) {
    %store_memref = memref.cast %arg1 : memref<*xf32> to memref<?xf32>
    affine.for %i = 0 to 4 {
      %mask_val = tensor.extract %mask[%i] : tensor<4xi1>
      scf.if %mask_val {
        %offset_val = tensor.extract %offset_tensor[%i] : tensor<4xi32>
        %store_value = tensor.extract %tensor[%i] : tensor<4xf32>
        %offset_index = arith.index_cast %offset_val : i32 to index
        memref.store %store_value, %store_memref[%offset_index] : memref<?xf32>
      }
    }

Intended lowering pipeline

  • triton-to-structured (no changes):
    • analyzes structured addptr sequences
      • introduces tts.make_tptr %ptr_arg with offsets and strides
      • introduces tts.load and tts.store
    • leaves unstructured addptr sequences and their corresponding tt.load and tt.store intact
  • triton-to-unstructured (Introduce triton-to-unstructured pass #210):
    • introduces tts.gather and tts.scatter
    • removes all pointer-producing ops such as tt.addptr and tt.splat and replaces them with offset-producing ops
  • structured-to-memref (Update structured-to-memref pass to support the new pass pipeline #217):
    • currently converts everything to memref including scalar addptr and kernel arguments
    • will change to just convert ops in the tts dialect to memref with the exception of tts.gather and tts.scatter
  • unstructured-to-memref (Introduce unstructured-to-memref pass #216):
    • converts the remaining unstructured tts.gather, tts.scatter into memref
  • triton-ptr-to-memref (Introduce triton-ptr-to-memref pass #211):
    • converts kernel arguments with pointer type to memref

@Nullkooland
Copy link
Contributor

@nhat-nguyen Hi, I tried your PR, and noticed that tt.store on unstructured ptrs tensor is lowered to an affine.for, for instance:

%ptrs = "tts.make_unstructured_tptr"(%base_ptr, %offsets) : (!tt.ptr<f16>, tensor<1024xi32>) -> tensor<1024x!tt.ptr<f16>>
tt.store %ptrs, %vals, %mask : tensor<1024x!tt.ptr<f16>>

is converted to:

%base_mem = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
affine.for %i = 0 to 1024 {
  %cond = tensor.extract %mask[%i] : tensor<1024xi1>
  scf.if %cond {
    %offset_i32 = tensor.extract %offsets[%i] : tensor<1024xi32>
    %val = tensor.extract %vals[%i] : tensor<1024xf16>
    %offset = arith.index_cast %offset_i32 : i32 to index
    memref.store %val, %base_mem[%offset] : memref<?xf16>
  }
}

However, as far as I know affine dialect is not actively used today and there is no pass to parallelize it, see MLIR discourse

So I suggest converting tt.store to linalg.generic, same as the conversion for for tt.load you implemented.
I understand that the linalg.generic op must have an output operand to imply iteration range, so you could add a dummy output tensor, like:

%base_mem = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
%cst = arith.constant 0.000000e+00 : f16
%dummy = tensor.empty() : tensor<1024xf16> // dummy output tensor.
%11 = linalg.generic {
  indexing_maps = [#map, #map, #map, #map],
  iterator_types = ["parallel"]
}
ins(%offsets, %vals, %mask : tensor<1024xi32>, tensor<1024xf16>, tensor<1024xi1>)
outs(%dummy: tensor<1024xf16>) {
^bb0(%offset_i32 : i32, %val : f16, %cond : i1, %out_dummy : f16):
  %dummy_val = scf.if %cond -> f16 {
    %offset = arith.index_cast %offset_i32 : i32 to index
    memref.store %val, %base_mem[%offset] : memref<?xf16>
    scf.yield %cst : f16
  } else {
    scf.yield %cst : f16
  }
  linalg.yield %dummy_val : f16
} -> tensor<1024xf16>

later we can lower the linalg.generic to scf.parallel loop with one-shot-bufferize and convert-linalg-to-parallel-loops, the dummy output is eliminated by dce:

%base_mem = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
scf.parallel (%i) = (%c0) to (%c1024) step (%c1) {
  %offset_i32 = memref.load %offsets_buf[%i] : memref<1024xi32>
  %val = memref.load %vals_buf[%i] : memref<1024xf16>
  %cond = memref.load %mask_buf[%i] : memref<1024xi1>
  scf.if %4 {
    %offset = arith.index_cast %offset_i32 : i32 to index
    memref.store %val, %base_mem[%offset] : memref<?xf16>
  }
  scf.reduce 
}

@nhat-nguyen
Copy link
Collaborator Author

@Nullkooland The issue with using linalg on tensor for these write operations is the whole op can be removed through canonicalization if we haven't converted to memref yet. I think that is a pretty big drawback -- worse than not being able to leverage the conversion to parallel loop.

With the following IR:

#map = affine_map<(d0) -> (d0)>
module {
  tt.func public @masked_gather_scatter(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
    %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
    %cst = arith.constant 9.900000e+01 : f32
    %dummy_const = arith.constant 1 : i1
    %c0_i32 = arith.constant 0 : i32
    %c1_i32 = arith.constant 1 : i32
    %c2_i32 = arith.constant 2 : i32
    %cst_0 = arith.constant dense<4> : tensor<4xi32>
    %cst_1 = arith.constant dense<64> : tensor<4xi32>
    %cst_2 = arith.constant dense<3> : tensor<4xi32>
    %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
    %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>)  : i32 {
      %4 = arith.divsi %arg3, %cst_2 : tensor<4xi32>
      %5 = tt.splat %arg2 : i32 -> tensor<4xi32>
      %6 = arith.addi %4, %5 : tensor<4xi32>
      %7 = arith.cmpi slt, %6, %cst_1 : tensor<4xi32>
      %cast = memref.cast %1 : memref<*xf32> to memref<?xf32>
      %8 = bufferization.to_tensor %cast restrict : memref<?xf32>
      %9 = tensor.empty() : tensor<4xf32>
      %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) outs(%9 : tensor<4xf32>) {
      ^bb0(%in: i32, %in_4: i1, %out: f32):
        %15 = scf.if %in_4 -> (f32) {
          %16 = arith.index_cast %in : i32 to index
          %extracted = tensor.extract %8[%16] : tensor<?xf32>
          scf.yield %extracted : f32
        } else {
          scf.yield %cst : f32
        }
        linalg.yield %15 : f32
      } -> tensor<4xf32>
      %cast_3 = memref.cast %0 : memref<*xf32> to memref<?xf32>
      %11 = tensor.empty() : tensor<4xi1>
        %alloc = memref.alloc() : memref<4xi1>
      linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%10, %6, %7 : tensor<4xf32>, tensor<4xi32>, tensor<4xi1>) outs(%alloc : memref<4xi1>) {
      ^bb0(%in: f32, %in_4: i32, %in_5: i1, %out: i1):
        %15 = arith.index_cast %in_4 : i32 to index
        %yield = scf.if %in_5 -> i1 {
          memref.store %in, %cast_3[%15] : memref<?xf32>
          scf.yield %dummy_const : i1
        } else {
        scf.yield %dummy_const : i1
        }
        linalg.yield %yield : i1
      }
      %13 = arith.addi %6, %cst_0 : tensor<4xi32>
      %14 = arith.addi %arg4, %cst_0 : tensor<4xi32>
      scf.yield %13, %14 : tensor<4xi32>, tensor<4xi32>
    }
    tt.return
  }
}

--canonicalize will end up removing the whole body. Let me know if you know of any ways to prevent this. Otherwise, I think just leaving the scatter into the affine loop is the best we can do. The loop nest itself is pretty simple, so we can pattern match and parallelize it later if necessary.

@Nullkooland
Copy link
Contributor

@nhat-nguyen This might be an upstream MLIR bug, due to linalg op not implementing RecursiveMemoryEffects that take into account ops with memory side effect in its body, see llvm/llvm-project#114045.

You could apply this upstream fix patch to the llvm-project dependency and try again to see if this linalg.generic with memref.store in its body gets removed or not. If this works, I guess triton-shared needs to update the dependent triton version with newer dependent llvm-project version that includes this fix.

@Nullkooland
Copy link
Contributor

@nhat-nguyen I trited your example IR (with minor modification that outs %alloc = memref.alloc() : memref<4xi1> is changed to %alloc = tensor.empty() : tensor<4xi1> since we cannot mix tensor and memref I/O operands in linalg ops) using a triton-shared-opt with a llvm-project built from source with that llvm/llvm-project#114045 upstream fix.

triton-shared-opt --canonicalize --cse masked_gather_scatter.mlir

the output IR is:

#map = affine_map<(d0) -> (d0)>
module {
  tt.func public @masked_gather_scatter(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %cst = arith.constant dense<3> : tensor<4xi32>
    %cst_0 = arith.constant dense<64> : tensor<4xi32>
    %cst_1 = arith.constant dense<4> : tensor<4xi32>
    %c2_i32 = arith.constant 2 : i32
    %c1_i32 = arith.constant 1 : i32
    %c0_i32 = arith.constant 0 : i32
    %true = arith.constant true
    %cst_2 = arith.constant 9.900000e+01 : f32
    %0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
    %1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
    %2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
    %3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>)  : i32 {
      %4 = arith.divsi %arg3, %cst : tensor<4xi32>
      %5 = tt.splat %arg2 : i32 -> tensor<4xi32>
      %6 = arith.addi %4, %5 : tensor<4xi32>
      %7 = arith.cmpi slt, %6, %cst_0 : tensor<4xi32>
      %cast = memref.cast %1 : memref<*xf32> to memref<?xf32>
      %8 = bufferization.to_tensor %cast restrict : memref<?xf32>
      %9 = tensor.empty() : tensor<4xf32>
      %10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) outs(%9 : tensor<4xf32>) {
      ^bb0(%in: i32, %in_4: i1, %out: f32):
        %15 = scf.if %in_4 -> (f32) {
          %16 = arith.index_cast %in : i32 to index
          %extracted = tensor.extract %8[%16] : tensor<?xf32>
          scf.yield %extracted : f32
        } else {
          scf.yield %cst_2 : f32
        }
        linalg.yield %15 : f32
      } -> tensor<4xf32>
      %cast_3 = memref.cast %0 : memref<*xf32> to memref<?xf32>
      %11 = tensor.empty() : tensor<4xi1>
      %12 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%10, %6, %7 : tensor<4xf32>, tensor<4xi32>, tensor<4xi1>) outs(%11 : tensor<4xi1>) {
      ^bb0(%in: f32, %in_4: i32, %in_5: i1, %out: i1):
        %15 = arith.index_cast %in_4 : i32 to index
        scf.if %in_5 {
          memref.store %in, %cast_3[%15] : memref<?xf32>
        }
        linalg.yield %true : i1
      } -> tensor<4xi1>
      %13 = arith.addi %6, %cst_1 : tensor<4xi32>
      %14 = arith.addi %arg4, %cst_1 : tensor<4xi32>
      scf.yield %13, %14 : tensor<4xi32>, tensor<4xi32>
    }
    tt.return
  }
}

The linalg.generic with memref.store is not removed by canonicalization.

@nhat-nguyen
Copy link
Collaborator Author

That is perfect. Sorry about the incorrect IR (I was playing around with different ways to get this to work and forgot to revert). Looks like we need to update triton which in turn will update llvm to get the fix.

nhat-nguyen added a commit that referenced this pull request Jan 14, 2025
This PR introduces the `triton-to-unstructured` pass which is the first
step towards allowing triton-shared to compile pointer sequences that
cannot be analyzed by `triton-to-structured` (gather / scatter).

This pass attempts to lower all loads and stores of unstructured
pointers to
tts.gather or tts.scatter that take a single base, a tensor of offsets,
an
optional tensor of mask values, and a default value in case of load.

In addition, all pointer-producing ops will be eliminated and replaced
by
offset-producing ops. tts.gather and tts.scatter will use the pointer
directly from the kernel arguments as opposed to pointer produced by ops
such
as tt.addptr and tt.splat.

Example:

```mlir
module {
  tt.func public @gather_simple_no_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %cst = arith.constant dense<5> : tensor<64xi32>
    %cst_0 = arith.constant dense<10> : tensor<64xi32>
    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %1 = arith.divsi %0, %cst_0 : tensor<64xi32>
    %2 = arith.addi %1, %cst : tensor<64xi32>
    %3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
    %4 = tt.addptr %3, %2 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
    %5 = tt.load %4 : tensor<64x!tt.ptr<f32>>
    %6 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
    %7 = tt.addptr %6, %0 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
    tt.store %7, %5 : tensor<64x!tt.ptr<f32>>
    tt.return
  }
}
```

becomes

```mlir
module {
  tt.func public @gather_simple_no_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %cst = arith.constant dense<5> : tensor<64xi32>
    %cst_0 = arith.constant dense<10> : tensor<64xi32>
    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %1 = arith.divsi %0, %cst_0 : tensor<64xi32>
    %2 = arith.addi %1, %cst : tensor<64xi32>
    %3 = tts.gather %arg0[%2] : (<f32>, tensor<64xi32>) -> tensor<64xf32>
    tts.scatter %3 into %arg1[%0] : tensor<64xf32> into (<f32>, tensor<64xi32>)
    tt.return
  }
}
```

Current assumptions and limitations:

- For simplicity, the pass assumes that gather / scatter operations load
/
store from / to a single base with a tensor of random offsets. As a
result, the following triton program would not work:

```python
@triton.jit
def gather_simple(in0, in1, out0):
    offs = tl.arange(0, 8)
    in0_ptrs = in0 + offs
    in1_ptrs = in1 + offs
    ptrs = tl.cat(in0_ptrs, in1_ptrs, can_reorder=True)
    c = tl.load(ptrs)
    out_offs = tl.arange(0, 16)
    tl.store(out0 + out_offs, c)
```

In the above program, `ptrs` contains 2 bases: `in0` and `in1` after the
`cat` operation.

For more details on the algorithm, see the
`TritonToUnstructuredPass.cpp` file.

# Future work

Future work may include scaling the algorithm to support multiple bases
-- one
possible solution is to let tts.gather and tts.scatter take in an
additional
tensor of base pointers corresponding to the tensor of offsets. But
because
we do not want pointer-producing ops to be present after this pass, we
can
use a tensor of index where each element indicates the index of the
pointer
argument to be used. The drawback is a gather or scatter operation now
needs
one extract lookup to get the base which will affect performance.

---

# Intended lowering pipeline
- triton-to-structured (no changes):
    - analyzes structured addptr sequences
        - introduces `tts.make_tptr %ptr_arg with offsets and strides`
        - introduces `tts.load` and `tts.store`
- leaves unstructured addptr sequences and their corresponding `tt.load`
and `tt.store` intact
- triton-to-unstructured (#210):
    - introduces `tts.gather` and `tts.scatter`
- removes all pointer-producing ops such as `tt.addptr` and `tt.splat`
and replaces them with offset-producing ops
- structured-to-memref (#217):
- currently converts everything to memref including scalar addptr and
kernel arguments
- will change to just convert ops in the `tts` dialect to `memref` with
the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
- converts the remaining unstructured `tts.gather`, `tts.scatter` into
memref
- triton-ptr-to-memref (#211):
    - converts kernel arguments with pointer type to memref
@nhat-nguyen nhat-nguyen marked this pull request as ready for review January 14, 2025 19:04
@nhat-nguyen
Copy link
Collaborator Author

@Nullkooland The update to latest triton is still ongoing. We might end up starting out with the affine version and update to linalg like you suggest in order to not delay landing these new features.

@nhat-nguyen
Copy link
Collaborator Author

@kile01 This is the continuation for lowering gather / scatter to linalg. This might not be super relevant to you but thought you might want to take a look too.

nhat-nguyen added a commit that referenced this pull request Jan 14, 2025
This PR introduces the `triton-ptr-to-memref` pass responsible for
converting function signature that uses triton ptr to use memref
instead. This is part of the work to allow triton-shared to lower gather
/ scatter pointer sequences.

Much of this code is copied from the current `StructuredToMemref` pass
which will be cleaned up in a later PR.

---

# Intended lowering pipeline
- triton-to-structured (no changes):
    - analyzes structured addptr sequences
        - introduces `tts.make_tptr %ptr_arg with offsets and strides`
        - introduces `tts.load` and `tts.store`
- leaves unstructured addptr sequences and their corresponding `tt.load`
and `tt.store` intact
- triton-to-unstructured (#210):
    - introduces `tts.gather` and `tts.scatter`
- removes all pointer-producing ops such as `tt.addptr` and `tt.splat`
and replaces them with offset-producing ops
- structured-to-memref (#217):
- currently converts everything to memref including scalar addptr and
kernel arguments
- will change to just convert ops in the `tts` dialect to `memref` with
the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
- converts the remaining unstructured `tts.gather`, `tts.scatter` into
memref
- triton-ptr-to-memref (#211):
    - converts kernel arguments with pointer type to memref
@beicy beicy self-requested a review January 15, 2025 15:30
@nhat-nguyen nhat-nguyen merged commit e9262ad into main Jan 15, 2025
3 checks passed
@nhat-nguyen nhat-nguyen deleted the nhat/unstructured branch January 15, 2025 17:50
nhat-nguyen added a commit that referenced this pull request Jan 16, 2025
…217)

This PR simplifies the `structured-to-memref` pass responsible for
converting structured triton load / store ops to memref load / store
ops. This is part of the work to allow triton-shared to lower gather /
scatter pointer sequences. Previously, this pass is also responsible for
converting scalar pointer load and store into memref; that
transformation has now been moved to `unstructured-to-memref`.

In addition, the PR also updates the `triton-to-linalg-experimental`
pass to fully utilize all the new passes. Once merged, triton-shared now
fully supports gather / scatter. An example test
(`test_gather_scatter.py`) is also added to demonstrate this new
capability.

---

# Intended lowering pipeline
- triton-to-structured (no changes):
    - analyzes structured addptr sequences
        - introduces `tts.make_tptr %ptr_arg with offsets and strides`
        - introduces `tts.load` and `tts.store`
- leaves unstructured addptr sequences and their corresponding `tt.load`
and `tt.store` intact
- triton-to-unstructured (#210):
    - introduces `tts.gather` and `tts.scatter`
- removes all pointer-producing ops such as `tt.addptr` and `tt.splat`
and replaces them with offset-producing ops
- structured-to-memref (#217):
- currently converts everything to memref including scalar addptr and
kernel arguments
- will change to just convert ops in the `tts` dialect to `memref` with
the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
- converts the remaining unstructured `tts.gather`, `tts.scatter` into
memref
- triton-ptr-to-memref (#211):
    - converts kernel arguments with pointer type to memref
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants