# Introduce unstructured-to-memref pass (#216)

## Conversation
@nhat-nguyen Hi, I tried your PR and noticed that

```mlir
%ptrs = "tts.make_unstructured_tptr"(%base_ptr, %offsets) : (!tt.ptr<f16>, tensor<1024xi32>) -> tensor<1024x!tt.ptr<f16>>
tt.store %ptrs, %vals, %mask : tensor<1024x!tt.ptr<f16>>
```

is converted to:

```mlir
%base_mem = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
affine.for %i = 0 to 1024 {
  %cond = tensor.extract %mask[%i] : tensor<1024xi1>
  scf.if %cond {
    %offset_i32 = tensor.extract %offsets[%i] : tensor<1024xi32>
    %val = tensor.extract %vals[%i] : tensor<1024xf16>
    %offset = arith.index_cast %offset_i32 : i32 to index
    memref.store %val, %base_mem[%offset] : memref<?xf16>
  }
}
```

However, as far as I know this `affine.for` form stays sequential and cannot later be turned into a parallel loop, so I suggest converting it to:

```mlir
#map = affine_map<(d0) -> (d0)>
%base_mem = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
%cst = arith.constant 0.000000e+00 : f16
%dummy = tensor.empty() : tensor<1024xf16> // dummy output tensor
%11 = linalg.generic {
  indexing_maps = [#map, #map, #map, #map],
  iterator_types = ["parallel"]
}
ins(%offsets, %vals, %mask : tensor<1024xi32>, tensor<1024xf16>, tensor<1024xi1>)
outs(%dummy : tensor<1024xf16>) {
^bb0(%offset_i32 : i32, %val : f16, %cond : i1, %out_dummy : f16):
  %dummy_val = scf.if %cond -> f16 {
    %offset = arith.index_cast %offset_i32 : i32 to index
    memref.store %val, %base_mem[%offset] : memref<?xf16>
    scf.yield %cst : f16
  } else {
    scf.yield %cst : f16
  }
  linalg.yield %dummy_val : f16
} -> tensor<1024xf16>
```

Later we can lower this `linalg.generic` to an `scf.parallel` loop:

```mlir
%base_mem = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
scf.parallel (%i) = (%c0) to (%c1024) step (%c1) {
  %offset_i32 = memref.load %offsets_buf[%i] : memref<1024xi32>
  %val = memref.load %vals_buf[%i] : memref<1024xf16>
  %cond = memref.load %mask_buf[%i] : memref<1024xi1>
  scf.if %cond {
    %offset = arith.index_cast %offset_i32 : i32 to index
    memref.store %val, %base_mem[%offset] : memref<?xf16>
  }
  scf.reduce
}
```
@Nullkooland The issue with using linalg on tensor for these write operations is that the whole op can be removed through canonicalization if we haven't converted to memref yet. I think that is a pretty big drawback -- worse than not being able to leverage the conversion to parallel loops. With the following IR:

```mlir
#map = affine_map<(d0) -> (d0)>
module {
tt.func public @masked_gather_scatter(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
%0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
%1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
%cst = arith.constant 9.900000e+01 : f32
%dummy_const = arith.constant 1 : i1
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c2_i32 = arith.constant 2 : i32
%cst_0 = arith.constant dense<4> : tensor<4xi32>
%cst_1 = arith.constant dense<64> : tensor<4xi32>
%cst_2 = arith.constant dense<3> : tensor<4xi32>
%2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
%3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
%4 = arith.divsi %arg3, %cst_2 : tensor<4xi32>
%5 = tt.splat %arg2 : i32 -> tensor<4xi32>
%6 = arith.addi %4, %5 : tensor<4xi32>
%7 = arith.cmpi slt, %6, %cst_1 : tensor<4xi32>
%cast = memref.cast %1 : memref<*xf32> to memref<?xf32>
%8 = bufferization.to_tensor %cast restrict : memref<?xf32>
%9 = tensor.empty() : tensor<4xf32>
%10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) outs(%9 : tensor<4xf32>) {
^bb0(%in: i32, %in_4: i1, %out: f32):
%15 = scf.if %in_4 -> (f32) {
%16 = arith.index_cast %in : i32 to index
%extracted = tensor.extract %8[%16] : tensor<?xf32>
scf.yield %extracted : f32
} else {
scf.yield %cst : f32
}
linalg.yield %15 : f32
} -> tensor<4xf32>
%cast_3 = memref.cast %0 : memref<*xf32> to memref<?xf32>
%11 = tensor.empty() : tensor<4xi1>
%alloc = memref.alloc() : memref<4xi1>
linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%10, %6, %7 : tensor<4xf32>, tensor<4xi32>, tensor<4xi1>) outs(%alloc : memref<4xi1>) {
^bb0(%in: f32, %in_4: i32, %in_5: i1, %out: i1):
%15 = arith.index_cast %in_4 : i32 to index
%yield = scf.if %in_5 -> i1 {
memref.store %in, %cast_3[%15] : memref<?xf32>
scf.yield %dummy_const : i1
} else {
scf.yield %dummy_const : i1
}
linalg.yield %yield : i1
}
%13 = arith.addi %6, %cst_0 : tensor<4xi32>
%14 = arith.addi %arg4, %cst_0 : tensor<4xi32>
scf.yield %13, %14 : tensor<4xi32>, tensor<4xi32>
}
tt.return
}
}
```
@nhat-nguyen This might be an upstream MLIR bug; you could apply the upstream fix patch to the LLVM that triton-shared builds against.
@nhat-nguyen I tried your example IR (with a minor modification so that the second `linalg.generic` writes to a tensor `outs` instead of a memref) and ran `triton-shared-opt --canonicalize --cse masked_gather_scatter.mlir`. The output IR is:

```mlir
#map = affine_map<(d0) -> (d0)>
module {
tt.func public @masked_gather_scatter(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
%cst = arith.constant dense<3> : tensor<4xi32>
%cst_0 = arith.constant dense<64> : tensor<4xi32>
%cst_1 = arith.constant dense<4> : tensor<4xi32>
%c2_i32 = arith.constant 2 : i32
%c1_i32 = arith.constant 1 : i32
%c0_i32 = arith.constant 0 : i32
%true = arith.constant true
%cst_2 = arith.constant 9.900000e+01 : f32
%0 = builtin.unrealized_conversion_cast %arg1 : !tt.ptr<f32> to memref<*xf32>
%1 = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
%2 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
%3:2 = scf.for %arg2 = %c0_i32 to %c2_i32 step %c1_i32 iter_args(%arg3 = %2, %arg4 = %2) -> (tensor<4xi32>, tensor<4xi32>) : i32 {
%4 = arith.divsi %arg3, %cst : tensor<4xi32>
%5 = tt.splat %arg2 : i32 -> tensor<4xi32>
%6 = arith.addi %4, %5 : tensor<4xi32>
%7 = arith.cmpi slt, %6, %cst_0 : tensor<4xi32>
%cast = memref.cast %1 : memref<*xf32> to memref<?xf32>
%8 = bufferization.to_tensor %cast restrict : memref<?xf32>
%9 = tensor.empty() : tensor<4xf32>
%10 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%6, %7 : tensor<4xi32>, tensor<4xi1>) outs(%9 : tensor<4xf32>) {
^bb0(%in: i32, %in_4: i1, %out: f32):
%15 = scf.if %in_4 -> (f32) {
%16 = arith.index_cast %in : i32 to index
%extracted = tensor.extract %8[%16] : tensor<?xf32>
scf.yield %extracted : f32
} else {
scf.yield %cst_2 : f32
}
linalg.yield %15 : f32
} -> tensor<4xf32>
%cast_3 = memref.cast %0 : memref<*xf32> to memref<?xf32>
%11 = tensor.empty() : tensor<4xi1>
%12 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%10, %6, %7 : tensor<4xf32>, tensor<4xi32>, tensor<4xi1>) outs(%11 : tensor<4xi1>) {
^bb0(%in: f32, %in_4: i32, %in_5: i1, %out: i1):
%15 = arith.index_cast %in_4 : i32 to index
scf.if %in_5 {
memref.store %in, %cast_3[%15] : memref<?xf32>
}
linalg.yield %true : i1
} -> tensor<4xi1>
%13 = arith.addi %6, %cst_1 : tensor<4xi32>
%14 = arith.addi %arg4, %cst_1 : tensor<4xi32>
scf.yield %13, %14 : tensor<4xi32>, tensor<4xi32>
}
tt.return
}
}
```

The `linalg.generic` containing the `memref.store` is not removed by canonicalization.
That is perfect. Sorry about the incorrect IR (I was playing around with different ways to get this to work and forgot to revert). Looks like we need to update triton, which in turn will update llvm, to get the fix.
This PR introduces the `triton-to-unstructured` pass, which is the first step towards allowing triton-shared to compile pointer sequences that cannot be analyzed by `triton-to-structured` (gather / scatter). This pass attempts to lower all loads and stores of unstructured pointers to `tts.gather` or `tts.scatter`, which take a single base, a tensor of offsets, an optional tensor of mask values, and a default value in the case of a load. In addition, all pointer-producing ops are eliminated and replaced by offset-producing ops. `tts.gather` and `tts.scatter` use the pointer directly from the kernel arguments, as opposed to pointers produced by ops such as `tt.addptr` and `tt.splat`.

Example:

```mlir
module {
  tt.func public @gather_simple_no_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %cst = arith.constant dense<5> : tensor<64xi32>
    %cst_0 = arith.constant dense<10> : tensor<64xi32>
    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %1 = arith.divsi %0, %cst_0 : tensor<64xi32>
    %2 = arith.addi %1, %cst : tensor<64xi32>
    %3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
    %4 = tt.addptr %3, %2 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
    %5 = tt.load %4 : tensor<64x!tt.ptr<f32>>
    %6 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<64x!tt.ptr<f32>>
    %7 = tt.addptr %6, %0 : tensor<64x!tt.ptr<f32>>, tensor<64xi32>
    tt.store %7, %5 : tensor<64x!tt.ptr<f32>>
    tt.return
  }
}
```

becomes

```mlir
module {
  tt.func public @gather_simple_no_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>) attributes {noinline = false} {
    %cst = arith.constant dense<5> : tensor<64xi32>
    %cst_0 = arith.constant dense<10> : tensor<64xi32>
    %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32>
    %1 = arith.divsi %0, %cst_0 : tensor<64xi32>
    %2 = arith.addi %1, %cst : tensor<64xi32>
    %3 = tts.gather %arg0[%2] : (<f32>, tensor<64xi32>) -> tensor<64xf32>
    tts.scatter %3 into %arg1[%0] : tensor<64xf32> into (<f32>, tensor<64xi32>)
    tt.return
  }
}
```

Current assumptions and limitations:

- For simplicity, the pass assumes that gather / scatter operations load / store from / to a single base with a tensor of random offsets. As a result, the following triton program would not work:

```python
@triton.jit
def gather_simple(in0, in1, out0):
    offs = tl.arange(0, 8)
    in0_ptrs = in0 + offs
    in1_ptrs = in1 + offs
    ptrs = tl.cat(in0_ptrs, in1_ptrs, can_reorder=True)
    c = tl.load(ptrs)
    out_offs = tl.arange(0, 16)
    tl.store(out0 + out_offs, c)
```

In the above program, `ptrs` contains 2 bases, `in0` and `in1`, after the `cat` operation. For more details on the algorithm, see the `TritonToUnstructuredPass.cpp` file.

# Future work

Future work may include scaling the algorithm to support multiple bases -- one possible solution is to let `tts.gather` and `tts.scatter` take an additional tensor of base pointers corresponding to the tensor of offsets. But because we do not want pointer-producing ops to be present after this pass, we can instead use a tensor of indices where each element indicates the index of the pointer argument to be used. The drawback is that a gather or scatter operation then needs one extra lookup to get the base, which will affect performance.
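To make the multi-base idea concrete, here is a purely hypothetical sketch of what such an extended form might look like; the op name, operand layout, and syntax are invented for illustration only and are not defined by this PR or the `tts` dialect:

```mlir
// Hypothetical extended gather: %base_idx selects, per element, which pointer
// argument (0 -> %arg0, 1 -> %arg1) the matching offset in %offs applies to,
// so no tensor of pointers is ever materialized.
// Today's tts.gather takes a single base; this op is illustrative only.
%base_idx = arith.constant dense<[0, 0, 0, 0, 1, 1, 1, 1]> : tensor<8xi32>
%offs = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32>
%c = "tts.gather_multi"(%arg0, %arg1, %base_idx, %offs)
       : (!tt.ptr<f32>, !tt.ptr<f32>, tensor<8xi32>, tensor<8xi32>) -> tensor<8xf32>
```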
---

# Intended lowering pipeline

- triton-to-structured (no changes):
  - analyzes structured addptr sequences
  - introduces `tts.make_tptr %ptr_arg with offsets and strides`
  - introduces `tts.load` and `tts.store`
  - leaves unstructured addptr sequences and their corresponding `tt.load` and `tt.store` intact
- triton-to-unstructured (#210):
  - introduces `tts.gather` and `tts.scatter`
  - removes all pointer-producing ops such as `tt.addptr` and `tt.splat` and replaces them with offset-producing ops
- structured-to-memref (#217):
  - currently converts everything to memref including scalar addptr and kernel arguments
  - will change to just convert ops in the `tts` dialect to `memref` with the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
  - converts the remaining unstructured `tts.gather`, `tts.scatter` into memref
- triton-ptr-to-memref (#211):
  - converts kernel arguments with pointer type to memref
@Nullkooland The update to the latest triton is still ongoing. We might end up starting with the affine version and updating to linalg as you suggest, in order not to delay landing these new features.
@kile01 This is the continuation of the work on lowering gather / scatter to linalg. It might not be super relevant to you, but I thought you might want to take a look too.
This PR introduces the `triton-ptr-to-memref` pass, responsible for converting function signatures that use triton pointers to use memref instead. This is part of the work to allow triton-shared to lower gather / scatter pointer sequences. Much of this code is copied from the current `StructuredToMemref` pass and will be cleaned up in a later PR.

---

# Intended lowering pipeline

- triton-to-structured (no changes):
  - analyzes structured addptr sequences
  - introduces `tts.make_tptr %ptr_arg with offsets and strides`
  - introduces `tts.load` and `tts.store`
  - leaves unstructured addptr sequences and their corresponding `tt.load` and `tt.store` intact
- triton-to-unstructured (#210):
  - introduces `tts.gather` and `tts.scatter`
  - removes all pointer-producing ops such as `tt.addptr` and `tt.splat` and replaces them with offset-producing ops
- structured-to-memref (#217):
  - currently converts everything to memref including scalar addptr and kernel arguments
  - will change to just convert ops in the `tts` dialect to `memref` with the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
  - converts the remaining unstructured `tts.gather`, `tts.scatter` into memref
- triton-ptr-to-memref (#211):
  - converts kernel arguments with pointer type to memref
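The pointer-to-memref mapping this pass targets is visible elsewhere in this thread: a `!tt.ptr<f32>` kernel argument corresponds to an unranked `memref<*xf32>`. A minimal sketch of that mapping; the cast chain and the constant index are chosen purely for illustration:

```mlir
// A !tt.ptr<f32> argument is modeled as an unranked memref; downstream
// passes cast it to a ranked memref before loading or storing through it.
%m = builtin.unrealized_conversion_cast %arg0 : !tt.ptr<f32> to memref<*xf32>
%cast = memref.cast %m : memref<*xf32> to memref<?xf32>
%c0 = arith.constant 0 : index
%val = memref.load %cast[%c0] : memref<?xf32>
```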
This PR simplifies the `structured-to-memref` pass, responsible for converting structured triton load / store ops to memref load / store ops. This is part of the work to allow triton-shared to lower gather / scatter pointer sequences. Previously, this pass was also responsible for converting scalar pointer loads and stores into memref; that transformation has now been moved to `unstructured-to-memref`. In addition, the PR updates the `triton-to-linalg-experimental` pass to fully utilize all the new passes. Once merged, triton-shared fully supports gather / scatter. An example test (`test_gather_scatter.py`) is also added to demonstrate this new capability.

---

# Intended lowering pipeline

- triton-to-structured (no changes):
  - analyzes structured addptr sequences
  - introduces `tts.make_tptr %ptr_arg with offsets and strides`
  - introduces `tts.load` and `tts.store`
  - leaves unstructured addptr sequences and their corresponding `tt.load` and `tt.store` intact
- triton-to-unstructured (#210):
  - introduces `tts.gather` and `tts.scatter`
  - removes all pointer-producing ops such as `tt.addptr` and `tt.splat` and replaces them with offset-producing ops
- structured-to-memref (#217):
  - currently converts everything to memref including scalar addptr and kernel arguments
  - will change to just convert ops in the `tts` dialect to `memref` with the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
  - converts the remaining unstructured `tts.gather`, `tts.scatter` into memref
- triton-ptr-to-memref (#211):
  - converts kernel arguments with pointer type to memref
This PR introduces the `unstructured-to-memref` pass, responsible for converting unstructured triton load / store ops to memref load / store ops. This is part of the work to allow triton-shared to lower gather / scatter pointer sequences. The pass is intended to be used after running `--fold-unstructured-ptr`.
A triton load op (gather) is lowered to a `linalg.generic` whose body contains a load from the offset provided by `tts.make_unstructured_tptr`. For a load op with mask, an inner-most `scf.if` is used to return a default value (or the `other` value in `tt.load` if provided) when the corresponding mask value is false.

Example of a load:
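The inline example from the original description did not survive extraction; as a stand-in, here is a sketch of the general shape of the generated IR for a masked load, adapted from the gather lowering shown earlier in this conversation (the tensor sizes, value names, and the `%cst_other` default are illustrative):

```mlir
#map = affine_map<(d0) -> (d0)>
// %base_ptr is the unranked memref for the pointer argument; %offsets and
// %mask come from the offset-producing ops left by earlier passes.
%base_mem = memref.cast %base_ptr : memref<*xf32> to memref<?xf32>
%src = bufferization.to_tensor %base_mem restrict : memref<?xf32>
%init = tensor.empty() : tensor<1024xf32>
%res = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
    ins(%offsets, %mask : tensor<1024xi32>, tensor<1024xi1>)
    outs(%init : tensor<1024xf32>) {
^bb0(%offset_i32: i32, %cond: i1, %out: f32):
  %val = scf.if %cond -> (f32) {
    %offset = arith.index_cast %offset_i32 : i32 to index
    %extracted = tensor.extract %src[%offset] : tensor<?xf32>
    scf.yield %extracted : f32
  } else {
    // default value (or the `other` operand of tt.load, if provided)
    scf.yield %cst_other : f32
  }
  linalg.yield %val : f32
} -> tensor<1024xf32>
```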
A triton store op (scatter) is lowered to an `affine.for` loop nest that stores the value to the appropriate offset provided by `tts.make_unstructured_tptr`. Store ops with mask are also supported.

Example of a store:
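The original store example is likewise missing; the `affine.for` form quoted at the top of this conversation gives the gist of the masked lowering (sizes and value names are illustrative):

```mlir
%base_mem = memref.cast %base_ptr : memref<*xf16> to memref<?xf16>
affine.for %i = 0 to 1024 {
  %cond = tensor.extract %mask[%i] : tensor<1024xi1>
  scf.if %cond {
    %offset_i32 = tensor.extract %offsets[%i] : tensor<1024xi32>
    %val = tensor.extract %vals[%i] : tensor<1024xf16>
    %offset = arith.index_cast %offset_i32 : i32 to index
    memref.store %val, %base_mem[%offset] : memref<?xf16>
  }
}
```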
---

# Intended lowering pipeline

- triton-to-structured (no changes):
  - analyzes structured addptr sequences
  - introduces `tts.make_tptr %ptr_arg with offsets and strides`
  - introduces `tts.load` and `tts.store`
  - leaves unstructured addptr sequences and their corresponding `tt.load` and `tt.store` intact
- triton-to-unstructured (#210):
  - introduces `tts.gather` and `tts.scatter`
  - removes all pointer-producing ops such as `tt.addptr` and `tt.splat` and replaces them with offset-producing ops
- structured-to-memref (#217):
  - currently converts everything to memref including scalar addptr and kernel arguments
  - will change to just convert ops in the `tts` dialect to `memref` with the exception of `tts.gather` and `tts.scatter`
- unstructured-to-memref (#216):
  - converts the remaining unstructured `tts.gather`, `tts.scatter` into memref
- triton-ptr-to-memref (#211):
  - converts kernel arguments with pointer type to memref