diff --git a/tests/filecheck/backend/csl/print_csl.mlir b/tests/filecheck/backend/csl/print_csl.mlir index 9ed76201b6..83e42966b7 100644 --- a/tests/filecheck/backend/csl/print_csl.mlir +++ b/tests/filecheck/backend/csl/print_csl.mlir @@ -413,7 +413,7 @@ csl.func @builtins() { %u32_pointer = "csl.addressof"(%u32_value) : (ui32) -> !csl.ptr, #csl> %A = memref.get_global @A : memref<24xf32> - %dsd_2d = "csl.get_mem_dsd"(%A, %i32_value, %i32_value) <{"strides" = [3, 4], "offsets" = [1, 2]}> : (memref<24xf32>, si32, si32) -> !csl + %dsd_2d = "csl.get_mem_dsd"(%A, %i32_value, %i32_value) <{"tensor_access" = affine_map<(d0, d1) -> (((d0 * 3) + 1), ((d1 * 4) + 2))>}> : (memref<24xf32>, si32, si32) -> !csl %dest_dsd = "csl.get_mem_dsd"(%A, %i32_value) : (memref<24xf32>, si32) -> !csl %src_dsd1 = "csl.get_mem_dsd"(%A, %i32_value) : (memref<24xf32>, si32) -> !csl %src_dsd2 = "csl.get_mem_dsd"(%A, %i32_value) : (memref<24xf32>, si32) -> !csl @@ -426,7 +426,9 @@ csl.func @builtins() { %fabin_dsd = "csl.get_fab_dsd"(%i32_value) <{"fabric_color" = 2 : ui5 , "queue_id" = 0 : i3}> : (si32) -> !csl %fabout_dsd = "csl.get_fab_dsd"(%i32_value) <{"fabric_color" = 3 : ui5 , "queue_id" = 1 : i3, "control"= true, "wavelet_index_offset" = false}>: (si32) -> !csl - %zero_stride_dsd = "csl.get_mem_dsd"(%A, %i16_value, %i16_value, %i16_value) <{"strides" = [0 : si16, 0 : si16, 1 : si16]}> : (memref<24xf32>, si16, si16, si16) -> !csl + %zero_stride_dsd = "csl.get_mem_dsd"(%A, %i16_value, %i16_value, %i16_value) <{"tensor_access" = affine_map<(d0, d1, d2) -> (d2)>}> : (memref<24xf32>, si16, si16, si16) -> !csl + %B = memref.get_global @B : memref<3x64xf32> + %oned_access_into_twod = "csl.get_mem_dsd"(%B, %i16_value) <{"tensor_access" = affine_map<(d0) -> (1, d0)>}> : (memref<3x64xf32>, si16) -> !csl "csl.add16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl, !csl, !csl) -> () "csl.addc16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl, si16, !csl) -> () @@ -795,7 +797,7 @@ csl.func @builtins() { // CHECK-NEXT: var u16_pointer : *u16 = &u16_value; // CHECK-NEXT: var u32_pointer : *u32 = &u32_value; // CHECK-NEXT: const dsd_2d : mem4d_dsd = @get_dsd( mem4d_dsd, .{ -// CHECK-NEXT: .tensor_access = | d0, d1 | { i32_value, i32_value } -> A[ 3 * d0 + 1, 4 * d1 + 2 ] +// CHECK-NEXT: .tensor_access = | d0, d1 | { i32_value, i32_value } -> A[ ((d0 * 3) + 1), ((d1 * 4) + 2) ] // CHECK-NEXT: }); // CHECK-NEXT: const dest_dsd : mem1d_dsd = @get_dsd( mem1d_dsd, .{ // CHECK-NEXT: .tensor_access = | d0 | { i32_value } -> A[ d0 ] @@ -825,6 +827,9 @@ csl.func @builtins() { // CHECK-NEXT: const zero_stride_dsd : mem4d_dsd = @get_dsd( mem4d_dsd, .{ // CHECK-NEXT: .tensor_access = | d0, d1, d2 | { i16_value, i16_value, i16_value } -> A[ d2 ] // CHECK-NEXT: }); +// CHECK-NEXT: const oned_access_into_twod : mem1d_dsd = @get_dsd( mem1d_dsd, .{ +// CHECK-NEXT: .tensor_access = | d0 | { i16_value } -> B[ 1, d0 ] +// CHECK-NEXT: }); // CHECK-NEXT: @add16(dest_dsd, src_dsd1, src_dsd2); // CHECK-NEXT: @addc16(dest_dsd, i16_value, src_dsd1); // CHECK-NEXT: @and16(dest_dsd, u16_value, src_dsd1); diff --git a/tests/filecheck/dialects/csl/csl-canonicalize.mlir b/tests/filecheck/dialects/csl/csl-canonicalize.mlir index 713595bf39..7188986cb7 100644 --- a/tests/filecheck/dialects/csl/csl-canonicalize.mlir +++ b/tests/filecheck/dialects/csl/csl-canonicalize.mlir @@ -8,20 +8,20 @@ builtin.module { %1 = "csl.zeros"() : () -> memref<512xf32> %2 = "csl.get_mem_dsd"(%1, %0) : (memref<512xf32>, i16) -> !csl -%3 = arith.constant 1 : si16 -%4 = "csl.increment_dsd_offset"(%2, %3) <{"elem_type" = f32}> : (!csl, si16) -> !csl +%int8 = arith.constant 3 : si8 +%3 = "csl.set_dsd_stride"(%2, %int8) : (!csl, si8) -> !csl -%5 = arith.constant 510 : ui16 -%6 = "csl.set_dsd_length"(%4, %5) : (!csl, ui16) -> !csl +%4 = arith.constant 1 : si16 +%5 = "csl.increment_dsd_offset"(%3, %4) <{"elem_type" = f32}> : (!csl, si16) -> !csl -%int8 = arith.constant 1 : si8 -%7 = "csl.set_dsd_stride"(%6, %int8) : (!csl, si8) -> !csl +%6 = arith.constant 510 : ui16 +%7 = "csl.set_dsd_length"(%5, %6) : (!csl, ui16) -> !csl "test.op"(%7) : (!csl) -> () // CHECK-NEXT: %0 = "csl.zeros"() : () -> memref<512xf32> // CHECK-NEXT: %1 = arith.constant 510 : ui16 -// CHECK-NEXT: %2 = "csl.get_mem_dsd"(%0, %1) <{"offsets" = [1 : si16], "strides" = [1 : si8]}> : (memref<512xf32>, ui16) -> !csl +// CHECK-NEXT: %2 = "csl.get_mem_dsd"(%0, %1) <{"tensor_access" = affine_map<(d0) -> (((d0 * 3) + 1))>}> : (memref<512xf32>, ui16) -> !csl // CHECK-NEXT: "test.op"(%2) : (!csl) -> () diff --git a/tests/filecheck/dialects/csl/ops.mlir b/tests/filecheck/dialects/csl/ops.mlir index ac125f832e..94ca709154 100644 --- a/tests/filecheck/dialects/csl/ops.mlir +++ b/tests/filecheck/dialects/csl/ops.mlir @@ -89,7 +89,7 @@ csl.func @initialize() { %dir = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction %dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl - %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3, 4], "offsets" = [1, 2]}> : (memref<10xf32>, i32, i32) -> !csl + %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"tensor_access" = affine_map<(d0, d1) -> (((d0 * 3) + 1), ((d1 * 4) + 2))>}> : (memref<10xf32>, i32, i32) -> !csl %dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl %dsd_4d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32, i32) -> !csl %dsd_1d1 = "csl.set_dsd_base_addr"(%dsd_1d, %many_arr_ptr) : (!csl, !csl.ptr, #csl>) -> !csl @@ -392,7 +392,7 @@ csl.func @builtins() { // CHECK-NEXT: %function_ptr = "csl.addressof_fn"() <{"fn_name" = @initialize}> : () -> !csl.ptr<() -> (), #csl, #csl> // CHECK-NEXT: %dir = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction // CHECK-NEXT: %dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl -// CHECK-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3 : i64, 4 : i64], "offsets" = [1 : i64, 2 : i64]}> : (memref<10xf32>, i32, i32) -> !csl +// CHECK-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"tensor_access" = affine_map<(d0, d1) -> (((d0 * 3) + 1), ((d1 * 4) + 2))>}> : (memref<10xf32>, i32, i32) -> !csl // CHECK-NEXT: %dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl // CHECK-NEXT: %dsd_4d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32, i32) -> !csl // CHECK-NEXT: %dsd_1d1 = "csl.set_dsd_base_addr"(%dsd_1d, %many_arr_ptr) : (!csl, !csl.ptr, #csl>) -> !csl @@ -639,7 +639,7 @@ csl.func @builtins() { // CHECK-GENERIC-NEXT: %function_ptr = "csl.addressof_fn"() <{"fn_name" = @initialize}> : () -> !csl.ptr<() -> (), #csl, #csl> // CHECK-GENERIC-NEXT: %dir = "csl.get_dir"() <{"dir" = #csl}> : () -> !csl.direction // CHECK-GENERIC-NEXT: %dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl -// CHECK-GENERIC-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3 : i64, 4 : i64], "offsets" = [1 : i64, 2 : i64]}> : (memref<10xf32>, i32, i32) -> !csl +// CHECK-GENERIC-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"tensor_access" = affine_map<(d0, d1) -> (((d0 * 3) + 1), ((d1 * 4) + 2))>}> : (memref<10xf32>, i32, i32) -> !csl // CHECK-GENERIC-NEXT: %dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl // CHECK-GENERIC-NEXT: %dsd_4d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32, i32) -> !csl // CHECK-GENERIC-NEXT: %dsd_1d1 = "csl.set_dsd_base_addr"(%dsd_1d, %many_arr_ptr) : (!csl, !csl.ptr, #csl>) -> !csl diff --git a/tests/filecheck/transforms/lower-csl-stencil.mlir b/tests/filecheck/transforms/lower-csl-stencil.mlir index 14a03ac340..99c8451403 100644 --- a/tests/filecheck/transforms/lower-csl-stencil.mlir +++ b/tests/filecheck/transforms/lower-csl-stencil.mlir @@ -108,7 +108,7 @@ builtin.module { // CHECK-NEXT: %offset_1 = arith.index_cast %offset : i16 to index // CHECK-NEXT: %42 = memref.subview %accumulator[%offset_1] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>> // CHECK-NEXT: %43 = arith.constant 4 : i16 -// CHECK-NEXT: %44 = "csl.get_mem_dsd"(%accumulator, %43, %29, %31) <{"strides" = [0 : i16, 0 : i16, 1 : i16]}> : (memref<510xf32>, i16, i16, i16) -> !csl +// CHECK-NEXT: %44 = "csl.get_mem_dsd"(%accumulator, %43, %29, %31) <{"tensor_access" = affine_map<(d0, d1, d2) -> (d2)>}> : (memref<510xf32>, i16, i16, i16) -> !csl // CHECK-NEXT: %45 = arith.index_cast %offset_1 : index to si16 // CHECK-NEXT: %46 = "csl.increment_dsd_offset"(%44, %45) <{"elem_type" = f32}> : (!csl, si16) -> !csl // CHECK-NEXT: %47 = "csl.member_call"(%34) <{"field" = "getRecvBufDsd"}> : (!csl.imported_module) -> !csl @@ -308,7 +308,7 @@ builtin.module { // CHECK-NEXT: %offset_3 = arith.index_cast %offset_2 : i16 to index // CHECK-NEXT: %88 = memref.subview %accumulator_1[%offset_3] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>> // CHECK-NEXT: %89 = arith.constant 4 : i16 -// CHECK-NEXT: %90 = "csl.get_mem_dsd"(%accumulator_1, %89, %arg3_1, %arg5_1) <{"strides" = [0 : i16, 0 : i16, 1 : i16]}> : (memref<510xf32>, i16, i16, i16) -> !csl +// CHECK-NEXT: %90 = "csl.get_mem_dsd"(%accumulator_1, %89, %arg3_1, %arg5_1) <{"tensor_access" = affine_map<(d0, d1, d2) -> (d2)>}> : (memref<510xf32>, i16, i16, i16) -> !csl // CHECK-NEXT: %91 = arith.index_cast %offset_3 : index to si16 // CHECK-NEXT: %92 = "csl.increment_dsd_offset"(%90, %91) <{"elem_type" = f32}> : (!csl, si16) -> !csl // CHECK-NEXT: %93 = "csl.member_call"(%69) <{"field" = "getRecvBufDsd"}> : (!csl.imported_module) -> !csl diff --git a/xdsl/backend/csl/print_csl.py b/xdsl/backend/csl/print_csl.py index fa57238fe6..ae06f770b5 100644 --- a/xdsl/backend/csl/print_csl.py +++ b/xdsl/backend/csl/print_csl.py @@ -29,6 +29,7 @@ i1, ) from xdsl.ir import Attribute, Block, Operation, OpResult, Region, SSAValue +from xdsl.ir.affine import AffineMap from xdsl.irdl import Operand from xdsl.traits import is_side_effect_free from xdsl.utils.comparisons import to_unsigned @@ -755,36 +756,22 @@ def print_block(self, body: Block): inner.print(f"@rpc(@get_data_task_id({id}));") case csl.GetMemDsdOp( base_addr=base_addr, - offsets=offsets, - strides=strides, + tensor_access=tensor_access, sizes=sizes, result=result, ): sizes_str = ", ".join( self._get_variable_name_for(size) for size in sizes ) + t_accesses = ( + tensor_access.data + if tensor_access + else AffineMap.identity(len(sizes)) + ) + ind_vars = ["d" + str(i) for i in range(len(sizes))] ind_vars_str = ", ".join(ind_vars) - accesses = [ - ( - f"{str(s)} * " - if strides and (s := strides.data[i].value.data) != 1 - else "" - ) - + ind_vars[i] - + (f" + {str(offsets.data[i].value.data)}" if offsets else "") - for i in range(len(ind_vars)) - ] - if strides and 0 in ( - strides_data := [s.value.data for s in strides.data] - ): - non_zero_stride_idx = [ - idx for idx, sd in enumerate(strides_data) if sd != 0 - ] - # if all except one strides are 0, print only the non-0 part (default to printing all dims) - if len(non_zero_stride_idx) == 1: - accesses = [accesses[non_zero_stride_idx[0]]] - accesses_str = ", ".join(accesses) + accesses_str = ", ".join(str(expr) for expr in t_accesses.results) self.print( f"{self._var_use(result)} = @get_dsd( {self.mlir_type_to_csl_type(result.type)}, .{{" ) diff --git a/xdsl/dialects/csl/csl.py b/xdsl/dialects/csl/csl.py index e8f6da22e0..fe9192d47a 100644 --- a/xdsl/dialects/csl/csl.py +++ b/xdsl/dialects/csl/csl.py @@ -16,6 +16,7 @@ from xdsl.dialects import builtin from xdsl.dialects.builtin import ( + AffineMapAttr, AnyFloatAttr, AnyFloatAttrConstr, AnyIntegerAttr, @@ -1143,8 +1144,7 @@ class GetMemDsdOp(_GetDsdOp): name = "csl.get_mem_dsd" base_addr = operand_def(base(MemRefType[Attribute]) | base(TensorType[Attribute])) - offsets = opt_prop_def(ArrayAttr[AnyIntegerAttr]) - strides = opt_prop_def(ArrayAttr[AnyIntegerAttr]) + tensor_access = opt_prop_def(AffineMapAttr) traits = traits_def( Pure(), @@ -1166,14 +1166,13 @@ def verify_(self) -> None: raise VerifyException( "DSD of type mem4d_dsd must have between 1 and 4 dimensions" ) - if self.offsets is not None and len(self.offsets) != len(self.sizes): - raise VerifyException( - "Dimensions of offsets must match dimensions of sizes" - ) - if self.strides is not None and len(self.strides) != len(self.sizes): - raise VerifyException( - "Dimensions of strides must match dimensions of sizes" - ) + if self.tensor_access: + if len(self.sizes) != self.tensor_access.data.num_dims: + raise VerifyException( + "Dsd must have sizes specified for each dimension of the affine map" + ) + if self.tensor_access.data.num_symbols != 0: + raise VerifyException("Symbols on affine map not supported") @irdl_op_definition diff --git a/xdsl/transforms/canonicalization_patterns/csl.py b/xdsl/transforms/canonicalization_patterns/csl.py index 74ab5b6e07..e8f1771577 100644 --- a/xdsl/transforms/canonicalization_patterns/csl.py +++ b/xdsl/transforms/canonicalization_patterns/csl.py @@ -1,13 +1,17 @@ from xdsl.dialects import arith -from xdsl.dialects.builtin import AnyIntegerAttrConstr, ArrayAttr +from xdsl.dialects.builtin import ( + AffineMapAttr, + AnyIntegerAttr, +) from xdsl.dialects.csl import csl from xdsl.ir import OpResult +from xdsl.ir.affine import AffineMap from xdsl.pattern_rewriter import ( PatternRewriter, RewritePattern, op_type_rewrite_pattern, ) -from xdsl.utils.isattr import isattr +from xdsl.utils.hints import isa class GetDsdAndOffsetFolding(RewritePattern): @@ -23,20 +27,28 @@ def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter) -> N ): return # only works on 1d - if op.offsets and len(op.offsets) > 1: + if len(op.sizes) > 1: return # check if we can promote arith.const to property if ( isinstance(offset_op.offset, OpResult) and isinstance(cnst := offset_op.offset.op, arith.ConstantOp) - and isattr(cnst.value, AnyIntegerAttrConstr) + and isa(attr_val := cnst.value, AnyIntegerAttr) ): + tensor_access = AffineMap.from_callable( + lambda x: (x + attr_val.value.data,) + ) + if op.tensor_access: + tensor_access = tensor_access.compose(op.tensor_access.data) rewriter.replace_matched_op( new_op := csl.GetMemDsdOp.build( operands=[op.base_addr, op.sizes], result_types=op.result_types, - properties={**op.properties, "offsets": ArrayAttr([cnst.value])}, + properties={ + **op.properties, + "tensor_access": AffineMapAttr(tensor_access), + }, ) ) rewriter.replace_op(offset_op, [], new_results=[new_op.result]) @@ -81,21 +93,27 @@ def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter) -> N stride_op := next(iter(op.result.uses)).operation, csl.SetDsdStrideOp ): return - # only works on 1d - if op.offsets and len(op.offsets) > 1: + # only works on 1d and default (unspecified) tensor_access + if len(op.sizes) > 1 or op.tensor_access: return # check if we can promote arith.const to property if ( isinstance(stride_op.stride, OpResult) and isinstance(cnst := stride_op.stride.op, arith.ConstantOp) - and isattr(cnst.value, AnyIntegerAttrConstr) + and isa(attr_val := cnst.value, AnyIntegerAttr) ): + tensor_access = AffineMap.from_callable( + lambda x: (x * attr_val.value.data,) + ) rewriter.replace_matched_op( new_op := csl.GetMemDsdOp.build( operands=[op.base_addr, op.sizes], result_types=op.result_types, - properties={**op.properties, "strides": ArrayAttr([cnst.value])}, + properties={ + **op.properties, + "tensor_access": AffineMapAttr(tensor_access), + }, ) ) rewriter.replace_op(stride_op, [], new_results=[new_op.result]) diff --git a/xdsl/transforms/lower_csl_stencil.py b/xdsl/transforms/lower_csl_stencil.py index df427022f4..3aa43841e6 100644 --- a/xdsl/transforms/lower_csl_stencil.py +++ b/xdsl/transforms/lower_csl_stencil.py @@ -4,9 +4,9 @@ from xdsl.context import MLContext from xdsl.dialects import arith, func, memref, stencil from xdsl.dialects.builtin import ( + AffineMapAttr, AnyFloatAttr, AnyMemRefType, - ArrayAttr, DenseIntOrFPElementsAttr, Float16Type, Float32Type, @@ -28,6 +28,7 @@ Region, SSAValue, ) +from xdsl.ir.affine import AffineMap from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -458,7 +459,11 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter, acc_dsd = csl.GetMemDsdOp.build( operands=[alloc, [direction_count, pattern, chunk_size]], result_types=[dsd_t], - properties={"strides": ArrayAttr([IntegerAttr(i, 16) for i in [0, 0, 1]])}, + properties={ + "tensor_access": AffineMapAttr( + AffineMap.from_callable(lambda x, y, z: (z,)) + ) + }, ) new_acc = acc_dsd