From 94b58a0549af3b47f27aceac3c3f63bd9fcb1272 Mon Sep 17 00:00:00 2001 From: dk949 <56653556+dk949@users.noreply.github.com> Date: Tue, 29 Oct 2024 15:19:25 +0000 Subject: [PATCH] fix: (csl) Use `u32` as the counter in async loops Previously, `i16` was used, this significantly limited the maximum loop iterations. --- .../csl-stencil-handle-async-flow.mlir | 18 +++++++++--------- .../csl_stencil_handle_async_flow.py | 10 +++++++--- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/tests/filecheck/transforms/csl-stencil-handle-async-flow.mlir b/tests/filecheck/transforms/csl-stencil-handle-async-flow.mlir index 65c2620302..2aeb8a52cf 100644 --- a/tests/filecheck/transforms/csl-stencil-handle-async-flow.mlir +++ b/tests/filecheck/transforms/csl-stencil-handle-async-flow.mlir @@ -99,7 +99,7 @@ // CHECK-NEXT: "csl.export"(%36) <{"var_name" = "arg1", "type" = !csl.ptr, #csl>}> : (!csl.ptr, #csl>) -> () // CHECK-NEXT: "csl.export"() <{"var_name" = @gauss_seidel_func, "type" = () -> ()}> : () -> () // CHECK-NEXT: %37 = memref.alloc() {"alignment" = 64 : i64} : memref<510xf32> -// CHECK-NEXT: %38 = "csl.variable"() <{"default" = 0 : i16}> : () -> !csl.var +// CHECK-NEXT: %38 = "csl.variable"() <{"default" = 0 : ui32}> : () -> !csl.var // CHECK-NEXT: %39 = "csl.variable"() : () -> !csl.var> // CHECK-NEXT: %40 = "csl.variable"() : () -> !csl.var> // CHECK-NEXT: csl.func @gauss_seidel_func() { @@ -112,9 +112,9 @@ // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.task @for_cond0() attributes {"kind" = #csl, "id" = 1 : i5}{ -// CHECK-NEXT: %44 = arith.constant 1000 : i16 -// CHECK-NEXT: %45 = "csl.load_var"(%38) : (!csl.var) -> i16 -// CHECK-NEXT: %46 = arith.cmpi slt, %45, %44 : i16 +// CHECK-NEXT: %44 = arith.constant 1000 : ui32 +// CHECK-NEXT: %45 = "csl.load_var"(%38) : (!csl.var) -> ui32 +// CHECK-NEXT: %46 = arith.cmpi slt, %45, %44 : ui32 // CHECK-NEXT: scf.if %46 { // CHECK-NEXT: "csl.call"() <{"callee" = @for_body0}> : () -> () // CHECK-NEXT: } else { @@ -123,7 +123,7 @@ // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @for_body0() { -// CHECK-NEXT: %arg2 = "csl.load_var"(%38) : (!csl.var) -> i16 +// CHECK-NEXT: %arg2 = "csl.load_var"(%38) : (!csl.var) -> ui32 // CHECK-NEXT: %arg3 = "csl.load_var"(%39) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: %arg4 = "csl.load_var"(%40) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: csl_stencil.apply(%arg3 : memref<512xf32>, %37 : memref<510xf32>) outs (%arg4 : memref<512xf32>) <{"bounds" = #stencil.bounds<[0, 0], [1, 1]>, "num_chunks" = 2 : i64, "operandSegmentSizes" = array, "swaps" = [#csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange, #csl_stencil.exchange], "topo" = #dmp.topo<1022x510>}> ({ @@ -151,10 +151,10 @@ // CHECK-NEXT: csl.return // CHECK-NEXT: } // CHECK-NEXT: csl.func @for_inc0() { -// CHECK-NEXT: %47 = arith.constant 1 : i16 -// CHECK-NEXT: %48 = "csl.load_var"(%38) : (!csl.var) -> i16 -// CHECK-NEXT: %49 = arith.addi %48, %47 : i16 -// CHECK-NEXT: "csl.store_var"(%38, %49) : (!csl.var, i16) -> () +// CHECK-NEXT: %47 = arith.constant 1 : ui32 +// CHECK-NEXT: %48 = "csl.load_var"(%38) : (!csl.var) -> ui32 +// CHECK-NEXT: %49 = arith.addi %48, %47 : ui32 +// CHECK-NEXT: "csl.store_var"(%38, %49) : (!csl.var, ui32) -> () // CHECK-NEXT: %50 = "csl.load_var"(%39) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: %51 = "csl.load_var"(%40) : (!csl.var>) -> memref<512xf32> // CHECK-NEXT: "csl.store_var"(%39, %51) : (!csl.var>, memref<512xf32>) -> () diff --git a/xdsl/transforms/csl_stencil_handle_async_flow.py b/xdsl/transforms/csl_stencil_handle_async_flow.py index 1f77062313..940136be9e 100644 --- a/xdsl/transforms/csl_stencil_handle_async_flow.py +++ b/xdsl/transforms/csl_stencil_handle_async_flow.py @@ -7,8 +7,10 @@ FunctionType, IndexType, IntegerAttr, + IntegerType, MemRefType, ModuleOp, + Signedness, SymbolRefAttr, ) from xdsl.dialects.csl import csl, csl_stencil, csl_wrapper @@ -30,6 +32,8 @@ from xdsl.rewriter import InsertPoint from xdsl.utils.hints import isa +u32 = IntegerType(32, Signedness.UNSIGNED) + @dataclass() class HandleCslStencilApplyAsyncCF(RewritePattern): @@ -162,7 +166,7 @@ def match_and_rewrite(self, op: scf.For, rewriter: PatternRewriter, /): # create csl.vars for loop var and iter_args outside the parent func rewriter.insert_op( - iv := csl.VariableOp.from_value(IntegerAttr(op.lb.op.value.value, 16)), + iv := csl.VariableOp.from_value(IntegerAttr(op.lb.op.value.value, u32)), InsertPoint.before(parent_func), ) iter_vars = [csl.VariableOp.from_type(arg_t) for arg_t in op.iter_args.types] @@ -177,7 +181,7 @@ def match_and_rewrite(self, op: scf.For, rewriter: PatternRewriter, /): # for-loop condition func with ImplicitBuilder(cond_func.body.block): - ub = arith.Constant.from_int_and_width(op.ub.op.value.value, 16) + ub = arith.Constant.from_int_and_width(op.ub.op.value.value, u32) iv_load = csl.LoadVarOp(iv) cond = arith.Cmpi(iv_load, ub, "slt") branch = scf.If(cond, [], Region(Block()), Region(Block())) @@ -191,7 +195,7 @@ def match_and_rewrite(self, op: scf.For, rewriter: PatternRewriter, /): # for-loop inc func with ImplicitBuilder(inc_func.body.block): - step = arith.Constant.from_int_and_width(op.step.op.value.value, 16) + step = arith.Constant.from_int_and_width(op.step.op.value.value, u32) iv_load = csl.LoadVarOp(iv) stepped = arith.Addi(iv_load, step) csl.StoreVarOp(iv, stepped)