Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: (csl) Use u32 as the counter in async loops #3353

Merged
merged 2 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions tests/filecheck/transforms/csl-stencil-handle-async-flow.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@
// CHECK-NEXT: "csl.export"(%36) <{"var_name" = "arg1", "type" = !csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>}> : (!csl.ptr<f32, #csl<ptr_kind many>, #csl<ptr_const var>>) -> ()
// CHECK-NEXT: "csl.export"() <{"var_name" = @gauss_seidel_func, "type" = () -> ()}> : () -> ()
// CHECK-NEXT: %37 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32>
// CHECK-NEXT: %38 = "csl.variable"() <{"default" = 0 : i16}> : () -> !csl.var<i16>
// CHECK-NEXT: %38 = "csl.variable"() <{"default" = 0 : ui32}> : () -> !csl.var<ui32>
// CHECK-NEXT: %39 = "csl.variable"() : () -> !csl.var<memref<512xf32>>
// CHECK-NEXT: %40 = "csl.variable"() : () -> !csl.var<memref<512xf32>>
// CHECK-NEXT: csl.func @gauss_seidel_func() {
Expand All @@ -112,9 +112,9 @@
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
// CHECK-NEXT: csl.task @for_cond0() attributes {"kind" = #csl<task_kind local>, "id" = 1 : i5}{
// CHECK-NEXT: %44 = arith.constant 1000 : i16
// CHECK-NEXT: %45 = "csl.load_var"(%38) : (!csl.var<i16>) -> i16
// CHECK-NEXT: %46 = arith.cmpi slt, %45, %44 : i16
// CHECK-NEXT: %44 = arith.constant 1000 : ui32
// CHECK-NEXT: %45 = "csl.load_var"(%38) : (!csl.var<ui32>) -> ui32
// CHECK-NEXT: %46 = arith.cmpi slt, %45, %44 : ui32
// CHECK-NEXT: scf.if %46 {
// CHECK-NEXT: "csl.call"() <{"callee" = @for_body0}> : () -> ()
// CHECK-NEXT: } else {
Expand All @@ -123,7 +123,7 @@
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
// CHECK-NEXT: csl.func @for_body0() {
// CHECK-NEXT: %arg2 = "csl.load_var"(%38) : (!csl.var<i16>) -> i16
// CHECK-NEXT: %arg2 = "csl.load_var"(%38) : (!csl.var<ui32>) -> ui32
// CHECK-NEXT: %arg3 = "csl.load_var"(%39) : (!csl.var<memref<512xf32>>) -> memref<512xf32>
// CHECK-NEXT: %arg4 = "csl.load_var"(%40) : (!csl.var<memref<512xf32>>) -> memref<512xf32>
// CHECK-NEXT: csl_stencil.apply(%arg3 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg4 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array<i32: 1, 1, 0, 0, 1>, "swaps" = [#csl_stencil.exchange<to [1, 0]>, #csl_stencil.exchange<to [-1, 0]>, #csl_stencil.exchange<to [0, 1]>, #csl_stencil.exchange<to [0, -1]>], "topo" = #dmp.topo<1022x510>}> ({
Expand Down Expand Up @@ -151,10 +151,10 @@
// CHECK-NEXT: csl.return
// CHECK-NEXT: }
// CHECK-NEXT: csl.func @for_inc0() {
// CHECK-NEXT: %47 = arith.constant 1 : i16
// CHECK-NEXT: %48 = "csl.load_var"(%38) : (!csl.var<i16>) -> i16
// CHECK-NEXT: %49 = arith.addi %48, %47 : i16
// CHECK-NEXT: "csl.store_var"(%38, %49) : (!csl.var<i16>, i16) -> ()
// CHECK-NEXT: %47 = arith.constant 1 : ui32
// CHECK-NEXT: %48 = "csl.load_var"(%38) : (!csl.var<ui32>) -> ui32
// CHECK-NEXT: %49 = arith.addi %48, %47 : ui32
// CHECK-NEXT: "csl.store_var"(%38, %49) : (!csl.var<ui32>, ui32) -> ()
// CHECK-NEXT: %50 = "csl.load_var"(%39) : (!csl.var<memref<512xf32>>) -> memref<512xf32>
// CHECK-NEXT: %51 = "csl.load_var"(%40) : (!csl.var<memref<512xf32>>) -> memref<512xf32>
// CHECK-NEXT: "csl.store_var"(%39, %51) : (!csl.var<memref<512xf32>>, memref<512xf32>) -> ()
Expand Down
10 changes: 7 additions & 3 deletions xdsl/transforms/csl_stencil_handle_async_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
FunctionType,
IndexType,
IntegerAttr,
IntegerType,
MemRefType,
ModuleOp,
Signedness,
SymbolRefAttr,
)
from xdsl.dialects.csl import csl, csl_stencil, csl_wrapper
Expand All @@ -30,6 +32,8 @@
from xdsl.rewriter import InsertPoint
from xdsl.utils.hints import isa

u32 = IntegerType(32, Signedness.UNSIGNED)


@dataclass()
class HandleCslStencilApplyAsyncCF(RewritePattern):
Expand Down Expand Up @@ -162,7 +166,7 @@ def match_and_rewrite(self, op: scf.For, rewriter: PatternRewriter, /):

# create csl.vars for loop var and iter_args outside the parent func
rewriter.insert_op(
iv := csl.VariableOp.from_value(IntegerAttr(op.lb.op.value.value, 16)),
iv := csl.VariableOp.from_value(IntegerAttr(op.lb.op.value.value, u32)),
InsertPoint.before(parent_func),
)
iter_vars = [csl.VariableOp.from_type(arg_t) for arg_t in op.iter_args.types]
Expand All @@ -177,7 +181,7 @@ def match_and_rewrite(self, op: scf.For, rewriter: PatternRewriter, /):

# for-loop condition func
with ImplicitBuilder(cond_func.body.block):
ub = arith.Constant.from_int_and_width(op.ub.op.value.value, 16)
ub = arith.Constant.from_int_and_width(op.ub.op.value.value, u32)
iv_load = csl.LoadVarOp(iv)
cond = arith.Cmpi(iv_load, ub, "slt")
branch = scf.If(cond, [], Region(Block()), Region(Block()))
Expand All @@ -191,7 +195,7 @@ def match_and_rewrite(self, op: scf.For, rewriter: PatternRewriter, /):

# for-loop inc func
with ImplicitBuilder(inc_func.body.block):
step = arith.Constant.from_int_and_width(op.step.op.value.value, 16)
step = arith.Constant.from_int_and_width(op.step.op.value.value, u32)
iv_load = csl.LoadVarOp(iv)
stepped = arith.Addi(iv_load, step)
csl.StoreVarOp(iv, stepped)
Expand Down
Loading