dialects: (csl) Switch dsds to use affine maps (#3657)
Co-authored-by: n-io <[email protected]>
Authored by n-io on Dec 19, 2024
1 parent 0e1cec0 commit 48642de
Showing 8 changed files with 72 additions and 58 deletions.
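
In short, csl.get_mem_dsd drops its separate integer-array offsets and strides properties in favour of a single tensor_access affine map that carries the complete index expression for each dimension. A minimal sketch of the equivalence, using only the xdsl affine-map helpers that appear in the diffs below; the concrete values are the ones from the updated tests:

```python
from xdsl.ir.affine import AffineMap

# Old encoding on csl.get_mem_dsd:  <{"strides" = [3, 4], "offsets" = [1, 2]}>
# New encoding: one affine map with the full per-dimension index expression.
tensor_access = AffineMap.from_callable(lambda d0, d1: (d0 * 3 + 1, d1 * 4 + 2))
# i.e. affine_map<(d0, d1) -> (((d0 * 3) + 1), ((d1 * 4) + 2))>, as printed in
# the generic-format CHECK lines below.
```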
11 changes: 8 additions & 3 deletions tests/filecheck/backend/csl/print_csl.mlir
@@ -413,7 +413,7 @@ csl.func @builtins() {
%u32_pointer = "csl.addressof"(%u32_value) : (ui32) -> !csl.ptr<ui32, #csl<ptr_kind single>, #csl<ptr_const var>>

%A = memref.get_global @A : memref<24xf32>
%dsd_2d = "csl.get_mem_dsd"(%A, %i32_value, %i32_value) <{"strides" = [3, 4], "offsets" = [1, 2]}> : (memref<24xf32>, si32, si32) -> !csl<dsd mem4d_dsd>
%dsd_2d = "csl.get_mem_dsd"(%A, %i32_value, %i32_value) <{"tensor_access" = affine_map<(d0, d1) -> (((d0 * 3) + 1), ((d1 * 4) + 2))>}> : (memref<24xf32>, si32, si32) -> !csl<dsd mem4d_dsd>
%dest_dsd = "csl.get_mem_dsd"(%A, %i32_value) : (memref<24xf32>, si32) -> !csl<dsd mem1d_dsd>
%src_dsd1 = "csl.get_mem_dsd"(%A, %i32_value) : (memref<24xf32>, si32) -> !csl<dsd mem1d_dsd>
%src_dsd2 = "csl.get_mem_dsd"(%A, %i32_value) : (memref<24xf32>, si32) -> !csl<dsd mem1d_dsd>
@@ -426,7 +426,9 @@ csl.func @builtins() {
%fabin_dsd = "csl.get_fab_dsd"(%i32_value) <{"fabric_color" = 2 : ui5 , "queue_id" = 0 : i3}> : (si32) -> !csl<dsd fabin_dsd>
%fabout_dsd = "csl.get_fab_dsd"(%i32_value) <{"fabric_color" = 3 : ui5 , "queue_id" = 1 : i3, "control"= true, "wavelet_index_offset" = false}>: (si32) -> !csl<dsd fabout_dsd>

%zero_stride_dsd = "csl.get_mem_dsd"(%A, %i16_value, %i16_value, %i16_value) <{"strides" = [0 : si16, 0 : si16, 1 : si16]}> : (memref<24xf32>, si16, si16, si16) -> !csl<dsd mem4d_dsd>
%zero_stride_dsd = "csl.get_mem_dsd"(%A, %i16_value, %i16_value, %i16_value) <{"tensor_access" = affine_map<(d0, d1, d2) -> (d2)>}> : (memref<24xf32>, si16, si16, si16) -> !csl<dsd mem4d_dsd>
%B = memref.get_global @B : memref<3x64xf32>
%oned_access_into_twod = "csl.get_mem_dsd"(%B, %i16_value) <{"tensor_access" = affine_map<(d0) -> (1, d0)>}> : (memref<3x64xf32>, si16) -> !csl<dsd mem1d_dsd>

"csl.add16"(%dest_dsd, %src_dsd1, %src_dsd2) : (!csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>, !csl<dsd mem1d_dsd>) -> ()
"csl.addc16"(%dest_dsd, %i16_value, %src_dsd1) : (!csl<dsd mem1d_dsd>, si16, !csl<dsd mem1d_dsd>) -> ()
@@ -795,7 +797,7 @@ csl.func @builtins() {
// CHECK-NEXT: var u16_pointer : *u16 = &u16_value;
// CHECK-NEXT: var u32_pointer : *u32 = &u32_value;
// CHECK-NEXT: const dsd_2d : mem4d_dsd = @get_dsd( mem4d_dsd, .{
// CHECK-NEXT: .tensor_access = | d0, d1 | { i32_value, i32_value } -> A[ 3 * d0 + 1, 4 * d1 + 2 ]
// CHECK-NEXT: .tensor_access = | d0, d1 | { i32_value, i32_value } -> A[ ((d0 * 3) + 1), ((d1 * 4) + 2) ]
// CHECK-NEXT: });
// CHECK-NEXT: const dest_dsd : mem1d_dsd = @get_dsd( mem1d_dsd, .{
// CHECK-NEXT: .tensor_access = | d0 | { i32_value } -> A[ d0 ]
@@ -825,6 +827,9 @@ csl.func @builtins() {
// CHECK-NEXT: const zero_stride_dsd : mem4d_dsd = @get_dsd( mem4d_dsd, .{
// CHECK-NEXT: .tensor_access = | d0, d1, d2 | { i16_value, i16_value, i16_value } -> A[ d2 ]
// CHECK-NEXT: });
// CHECK-NEXT: const oned_access_into_twod : mem1d_dsd = @get_dsd( mem1d_dsd, .{
// CHECK-NEXT: .tensor_access = | d0 | { i16_value } -> B[ 1, d0 ]
// CHECK-NEXT: });
// CHECK-NEXT: @add16(dest_dsd, src_dsd1, src_dsd2);
// CHECK-NEXT: @addc16(dest_dsd, i16_value, src_dsd1);
// CHECK-NEXT: @and16(dest_dsd, u16_value, src_dsd1);
14 changes: 7 additions & 7 deletions tests/filecheck/dialects/csl/csl-canonicalize.mlir
@@ -8,20 +8,20 @@ builtin.module {
%1 = "csl.zeros"() : () -> memref<512xf32>
%2 = "csl.get_mem_dsd"(%1, %0) : (memref<512xf32>, i16) -> !csl<dsd mem1d_dsd>

%3 = arith.constant 1 : si16
%4 = "csl.increment_dsd_offset"(%2, %3) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>
%int8 = arith.constant 3 : si8
%3 = "csl.set_dsd_stride"(%2, %int8) : (!csl<dsd mem1d_dsd>, si8) -> !csl<dsd mem1d_dsd>

%5 = arith.constant 510 : ui16
%6 = "csl.set_dsd_length"(%4, %5) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>
%4 = arith.constant 1 : si16
%5 = "csl.increment_dsd_offset"(%3, %4) <{"elem_type" = f32}> : (!csl<dsd mem1d_dsd>, si16) -> !csl<dsd mem1d_dsd>

%int8 = arith.constant 1 : si8
%7 = "csl.set_dsd_stride"(%6, %int8) : (!csl<dsd mem1d_dsd>, si8) -> !csl<dsd mem1d_dsd>
%6 = arith.constant 510 : ui16
%7 = "csl.set_dsd_length"(%5, %6) : (!csl<dsd mem1d_dsd>, ui16) -> !csl<dsd mem1d_dsd>

"test.op"(%7) : (!csl<dsd mem1d_dsd>) -> ()

// CHECK-NEXT: %0 = "csl.zeros"() : () -> memref<512xf32>
// CHECK-NEXT: %1 = arith.constant 510 : ui16
// CHECK-NEXT: %2 = "csl.get_mem_dsd"(%0, %1) <{"offsets" = [1 : si16], "strides" = [1 : si8]}> : (memref<512xf32>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %2 = "csl.get_mem_dsd"(%0, %1) <{"tensor_access" = affine_map<(d0) -> (((d0 * 3) + 1))>}> : (memref<512xf32>, ui16) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: "test.op"(%2) : (!csl<dsd mem1d_dsd>) -> ()
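
The folded map in the CHECK line above comes from composing one small map per canonicalized op. Roughly, and assuming AffineMap.compose applies the left-hand map to the results of the right-hand one (as in MLIR):

```python
from xdsl.ir.affine import AffineMap

# csl.set_dsd_stride 3 folds into a pure stride map ...
stride_map = AffineMap.from_callable(lambda d0: (d0 * 3,))
# ... and csl.increment_dsd_offset 1 is then composed around it:
offset_map = AffineMap.from_callable(lambda d0: (d0 + 1,))
folded = offset_map.compose(stride_map)
# folded corresponds to (d0) -> ((d0 * 3) + 1), the tensor_access checked above.
```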


6 changes: 3 additions & 3 deletions tests/filecheck/dialects/csl/ops.mlir
@@ -89,7 +89,7 @@ csl.func @initialize() {
%dir = "csl.get_dir"() <{"dir" = #csl<dir_kind north>}> : () -> !csl.direction

%dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl<dsd mem1d_dsd>
%dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3, 4], "offsets" = [1, 2]}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
%dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"tensor_access" = affine_map<(d0, d1) -> (((d0 * 3) + 1), ((d1 * 4) + 2))>}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
%dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl<dsd mem4d_dsd>
%dsd_4d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32, i32) -> !csl<dsd mem4d_dsd>
%dsd_1d1 = "csl.set_dsd_base_addr"(%dsd_1d, %many_arr_ptr) : (!csl<dsd mem1d_dsd>, !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const const>>) -> !csl<dsd mem1d_dsd>
@@ -392,7 +392,7 @@ csl.func @builtins() {
// CHECK-NEXT: %function_ptr = "csl.addressof_fn"() <{"fn_name" = @initialize}> : () -> !csl.ptr<() -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-NEXT: %dir = "csl.get_dir"() <{"dir" = #csl<dir_kind north>}> : () -> !csl.direction
// CHECK-NEXT: %dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl<dsd mem1d_dsd>
// CHECK-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3 : i64, 4 : i64], "offsets" = [1 : i64, 2 : i64]}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"tensor_access" = affine_map<(d0, d1) -> (((d0 * 3) + 1), ((d1 * 4) + 2))>}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %dsd_4d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %dsd_1d1 = "csl.set_dsd_base_addr"(%dsd_1d, %many_arr_ptr) : (!csl<dsd mem1d_dsd>, !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const const>>) -> !csl<dsd mem1d_dsd>
@@ -639,7 +639,7 @@ csl.func @builtins() {
// CHECK-GENERIC-NEXT: %function_ptr = "csl.addressof_fn"() <{"fn_name" = @initialize}> : () -> !csl.ptr<() -> (), #csl<ptr_kind single>, #csl<ptr_const const>>
// CHECK-GENERIC-NEXT: %dir = "csl.get_dir"() <{"dir" = #csl<dir_kind north>}> : () -> !csl.direction
// CHECK-GENERIC-NEXT: %dsd_1d = "csl.get_mem_dsd"(%arr, %scalar) : (memref<10xf32>, i32) -> !csl<dsd mem1d_dsd>
// CHECK-GENERIC-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"strides" = [3 : i64, 4 : i64], "offsets" = [1 : i64, 2 : i64]}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-GENERIC-NEXT: %dsd_2d = "csl.get_mem_dsd"(%arr, %scalar, %scalar) <{"tensor_access" = affine_map<(d0, d1) -> (((d0 * 3) + 1), ((d1 * 4) + 2))>}> : (memref<10xf32>, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-GENERIC-NEXT: %dsd_3d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-GENERIC-NEXT: %dsd_4d = "csl.get_mem_dsd"(%arr, %scalar, %scalar, %scalar, %scalar) : (memref<10xf32>, i32, i32, i32, i32) -> !csl<dsd mem4d_dsd>
// CHECK-GENERIC-NEXT: %dsd_1d1 = "csl.set_dsd_base_addr"(%dsd_1d, %many_arr_ptr) : (!csl<dsd mem1d_dsd>, !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const const>>) -> !csl<dsd mem1d_dsd>
4 changes: 2 additions & 2 deletions tests/filecheck/transforms/lower-csl-stencil.mlir
@@ -108,7 +108,7 @@ builtin.module {
// CHECK-NEXT: %offset_1 = arith.index_cast %offset : i16 to index
// CHECK-NEXT: %42 = memref.subview %accumulator[%offset_1] [255] [1] : memref<510xf32> to memref<255xf32, strided<[1], offset: ?>>
// CHECK-NEXT: %43 = arith.constant 4 : i16
// CHECK-NEXT: %44 = "csl.get_mem_dsd"(%accumulator, %43, %29, %31) <{"strides" = [0 : i16, 0 : i16, 1 : i16]}> : (memref<510xf32>, i16, i16, i16) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %44 = "csl.get_mem_dsd"(%accumulator, %43, %29, %31) <{"tensor_access" = affine_map<(d0, d1, d2) -> (d2)>}> : (memref<510xf32>, i16, i16, i16) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %45 = arith.index_cast %offset_1 : index to si16
// CHECK-NEXT: %46 = "csl.increment_dsd_offset"(%44, %45) <{"elem_type" = f32}> : (!csl<dsd mem4d_dsd>, si16) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %47 = "csl.member_call"(%34) <{"field" = "getRecvBufDsd"}> : (!csl.imported_module) -> !csl<dsd mem4d_dsd>
@@ -308,7 +308,7 @@ builtin.module {
// CHECK-NEXT: %offset_3 = arith.index_cast %offset_2 : i16 to index
// CHECK-NEXT: %88 = memref.subview %accumulator_1[%offset_3] [510] [1] : memref<510xf32> to memref<510xf32, strided<[1], offset: ?>>
// CHECK-NEXT: %89 = arith.constant 4 : i16
// CHECK-NEXT: %90 = "csl.get_mem_dsd"(%accumulator_1, %89, %arg3_1, %arg5_1) <{"strides" = [0 : i16, 0 : i16, 1 : i16]}> : (memref<510xf32>, i16, i16, i16) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %90 = "csl.get_mem_dsd"(%accumulator_1, %89, %arg3_1, %arg5_1) <{"tensor_access" = affine_map<(d0, d1, d2) -> (d2)>}> : (memref<510xf32>, i16, i16, i16) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %91 = arith.index_cast %offset_3 : index to si16
// CHECK-NEXT: %92 = "csl.increment_dsd_offset"(%90, %91) <{"elem_type" = f32}> : (!csl<dsd mem4d_dsd>, si16) -> !csl<dsd mem4d_dsd>
// CHECK-NEXT: %93 = "csl.member_call"(%69) <{"field" = "getRecvBufDsd"}> : (!csl.imported_module) -> !csl<dsd mem4d_dsd>
31 changes: 9 additions & 22 deletions xdsl/backend/csl/print_csl.py
@@ -29,6 +29,7 @@
i1,
)
from xdsl.ir import Attribute, Block, Operation, OpResult, Region, SSAValue
from xdsl.ir.affine import AffineMap
from xdsl.irdl import Operand
from xdsl.traits import is_side_effect_free
from xdsl.utils.comparisons import to_unsigned
@@ -755,36 +756,22 @@ def print_block(self, body: Block):
inner.print(f"@rpc(@get_data_task_id({id}));")
case csl.GetMemDsdOp(
base_addr=base_addr,
offsets=offsets,
strides=strides,
tensor_access=tensor_access,
sizes=sizes,
result=result,
):
sizes_str = ", ".join(
self._get_variable_name_for(size) for size in sizes
)
t_accesses = (
tensor_access.data
if tensor_access
else AffineMap.identity(len(sizes))
)

ind_vars = ["d" + str(i) for i in range(len(sizes))]
ind_vars_str = ", ".join(ind_vars)
accesses = [
(
f"{str(s)} * "
if strides and (s := strides.data[i].value.data) != 1
else ""
)
+ ind_vars[i]
+ (f" + {str(offsets.data[i].value.data)}" if offsets else "")
for i in range(len(ind_vars))
]
if strides and 0 in (
strides_data := [s.value.data for s in strides.data]
):
non_zero_stride_idx = [
idx for idx, sd in enumerate(strides_data) if sd != 0
]
# if all except one strides are 0, print only the non-0 part (default to printing all dims)
if len(non_zero_stride_idx) == 1:
accesses = [accesses[non_zero_stride_idx[0]]]
accesses_str = ", ".join(accesses)
accesses_str = ", ".join(str(expr) for expr in t_accesses.results)
self.print(
f"{self._var_use(result)} = @get_dsd( {self.mlir_type_to_csl_type(result.type)}, .{{"
)
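
With the access map stored on the op, the backend no longer reconstructs index expressions from offsets and strides: it falls back to an identity map when the property is absent and otherwise just stringifies each result expression of the map. A rough sketch of that join, assuming str() on an affine expression yields the parenthesised form seen in the CHECK lines:

```python
from xdsl.ir.affine import AffineMap

t_accesses = AffineMap.from_callable(lambda d0, d1: (d0 * 3 + 1, d1 * 4 + 2))
accesses_str = ", ".join(str(expr) for expr in t_accesses.results)
# accesses_str is something like "((d0 * 3) + 1), ((d1 * 4) + 2)", which the
# printer wraps as A[ ((d0 * 3) + 1), ((d1 * 4) + 2) ] in the CSL output.
```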
19 changes: 9 additions & 10 deletions xdsl/dialects/csl/csl.py
@@ -16,6 +16,7 @@

from xdsl.dialects import builtin
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyFloatAttr,
AnyFloatAttrConstr,
AnyIntegerAttr,
@@ -1143,8 +1144,7 @@ class GetMemDsdOp(_GetDsdOp):

name = "csl.get_mem_dsd"
base_addr = operand_def(base(MemRefType[Attribute]) | base(TensorType[Attribute]))
offsets = opt_prop_def(ArrayAttr[AnyIntegerAttr])
strides = opt_prop_def(ArrayAttr[AnyIntegerAttr])
tensor_access = opt_prop_def(AffineMapAttr)

traits = traits_def(
Pure(),
@@ -1166,14 +1166,13 @@ def verify_(self) -> None:
raise VerifyException(
"DSD of type mem4d_dsd must have between 1 and 4 dimensions"
)
if self.offsets is not None and len(self.offsets) != len(self.sizes):
raise VerifyException(
"Dimensions of offsets must match dimensions of sizes"
)
if self.strides is not None and len(self.strides) != len(self.sizes):
raise VerifyException(
"Dimensions of strides must match dimensions of sizes"
)
if self.tensor_access:
if len(self.sizes) != self.tensor_access.data.num_dims:
raise VerifyException(
"Dsd must have sizes specified for each dimension of the affine map"
)
if self.tensor_access.data.num_symbols != 0:
raise VerifyException("Symbols on affine map not supported")
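
A small illustration of what the new checks accept, using the zero-stride map from the tests (variable names here are only for the example):

```python
from xdsl.ir.affine import AffineMap

# Three size operands, but only the innermost dimension addresses memory;
# this is the affine-map spelling of the old strides = [0, 0, 1].
zero_stride = AffineMap.from_callable(lambda d0, d1, d2: (d2,))
assert zero_stride.num_dims == 3     # must equal len(sizes) on the op
assert zero_stride.num_symbols == 0  # symbolic maps are rejected
```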


@irdl_op_definition
36 changes: 27 additions & 9 deletions xdsl/transforms/canonicalization_patterns/csl.py
@@ -1,13 +1,17 @@
from xdsl.dialects import arith
from xdsl.dialects.builtin import AnyIntegerAttrConstr, ArrayAttr
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyIntegerAttr,
)
from xdsl.dialects.csl import csl
from xdsl.ir import OpResult
from xdsl.ir.affine import AffineMap
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.utils.isattr import isattr
from xdsl.utils.hints import isa


class GetDsdAndOffsetFolding(RewritePattern):
@@ -23,20 +27,28 @@ def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter) -> N
):
return
# only works on 1d
if op.offsets and len(op.offsets) > 1:
if len(op.sizes) > 1:
return

# check if we can promote arith.const to property
if (
isinstance(offset_op.offset, OpResult)
and isinstance(cnst := offset_op.offset.op, arith.ConstantOp)
and isattr(cnst.value, AnyIntegerAttrConstr)
and isa(attr_val := cnst.value, AnyIntegerAttr)
):
tensor_access = AffineMap.from_callable(
lambda x: (x + attr_val.value.data,)
)
if op.tensor_access:
tensor_access = tensor_access.compose(op.tensor_access.data)
rewriter.replace_matched_op(
new_op := csl.GetMemDsdOp.build(
operands=[op.base_addr, op.sizes],
result_types=op.result_types,
properties={**op.properties, "offsets": ArrayAttr([cnst.value])},
properties={
**op.properties,
"tensor_access": AffineMapAttr(tensor_access),
},
)
)
rewriter.replace_op(offset_op, [], new_results=[new_op.result])
@@ -81,21 +93,27 @@ def match_and_rewrite(self, op: csl.GetMemDsdOp, rewriter: PatternRewriter) -> N
stride_op := next(iter(op.result.uses)).operation, csl.SetDsdStrideOp
):
return
# only works on 1d
if op.offsets and len(op.offsets) > 1:
# only works on 1d and default (unspecified) tensor_access
if len(op.sizes) > 1 or op.tensor_access:
return

# check if we can promote arith.const to property
if (
isinstance(stride_op.stride, OpResult)
and isinstance(cnst := stride_op.stride.op, arith.ConstantOp)
and isattr(cnst.value, AnyIntegerAttrConstr)
and isa(attr_val := cnst.value, AnyIntegerAttr)
):
tensor_access = AffineMap.from_callable(
lambda x: (x * attr_val.value.data,)
)
rewriter.replace_matched_op(
new_op := csl.GetMemDsdOp.build(
operands=[op.base_addr, op.sizes],
result_types=op.result_types,
properties={**op.properties, "strides": ArrayAttr([cnst.value])},
properties={
**op.properties,
"tensor_access": AffineMapAttr(tensor_access),
},
)
)
rewriter.replace_op(stride_op, [], new_results=[new_op.result])
9 changes: 7 additions & 2 deletions xdsl/transforms/lower_csl_stencil.py
@@ -4,9 +4,9 @@
from xdsl.context import MLContext
from xdsl.dialects import arith, func, memref, stencil
from xdsl.dialects.builtin import (
AffineMapAttr,
AnyFloatAttr,
AnyMemRefType,
ArrayAttr,
DenseIntOrFPElementsAttr,
Float16Type,
Float32Type,
Expand All @@ -28,6 +28,7 @@
Region,
SSAValue,
)
from xdsl.ir.affine import AffineMap
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
@@ -458,7 +459,11 @@ def match_and_rewrite(self, op: csl_stencil.ApplyOp, rewriter: PatternRewriter,
acc_dsd = csl.GetMemDsdOp.build(
operands=[alloc, [direction_count, pattern, chunk_size]],
result_types=[dsd_t],
properties={"strides": ArrayAttr([IntegerAttr(i, 16) for i in [0, 0, 1]])},
properties={
"tensor_access": AffineMapAttr(
AffineMap.from_callable(lambda x, y, z: (z,))
)
},
)
new_acc = acc_dsd
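
For the accumulator DSD built above, the three sizes are the direction count, the pattern, and the chunk size, but only the chunk element selects a memory location; the other two dimensions revisit the same slice, which is what the old zero strides expressed. A small sketch of the map this constructs (the lambda parameter names are only illustrative):

```python
from xdsl.ir.affine import AffineMap

# (direction, pattern step, chunk element): only the chunk element indexes memory
acc_map = AffineMap.from_callable(lambda d, p, c: (c,))
assert acc_map.num_dims == 3
assert len(acc_map.results) == 1
```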
