From 6d001d1ebbf76d0bf8bd510f6c98d3233320769f Mon Sep 17 00:00:00 2001 From: n-io Date: Fri, 25 Oct 2024 19:32:13 +0200 Subject: [PATCH 1/3] transformations: (linalg-to-csl) Lower generic to fmac(h|s) --- tests/filecheck/transforms/linalg-to-csl.mlir | 13 +++ xdsl/transforms/linalg_to_csl.py | 102 ++++++++++++++---- 2 files changed, 93 insertions(+), 22 deletions(-) diff --git a/tests/filecheck/transforms/linalg-to-csl.mlir b/tests/filecheck/transforms/linalg-to-csl.mlir index 4c76663461..de4a5f0106 100644 --- a/tests/filecheck/transforms/linalg-to-csl.mlir +++ b/tests/filecheck/transforms/linalg-to-csl.mlir @@ -1,5 +1,7 @@ // RUN: xdsl-opt %s -p linalg-to-csl | filecheck %s +#map = affine_map<(d0) -> (d0)> + builtin.module { %0, %1, %2, %3, %4 = "test.op"() : () -> (memref<16xf32>, memref<16xf32>, memref<16xf32>, memref<16xf32>, memref<16xf32>) linalg.add ins(%1, %2 : memref<16xf32>, memref<16xf32>) outs(%0 : memref<16xf32>) @@ -14,6 +16,14 @@ builtin.module { %10 = arith.constant dense<1.123400e-01> : memref<16xf32> linalg.add ins(%0, %10 : memref<16xf32>, memref<16xf32>) outs(%0 : memref<16xf32>) linalg.mul ins(%10, %0 : memref<16xf32>, memref<16xf32>) outs(%0 : memref<16xf32>) + + %c = arith.constant dense<2.99792458e+08> : memref<16xf32> + linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%0, %c, %2 : memref<16xf32>, memref<16xf32>, memref<16xf32>) outs(%0 : memref<16xf32>) { + ^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32): + %11 = arith.mulf %in, %in_0 : f32 + %12 = arith.addf %11, %in_1 : f32 + linalg.yield %12 : f32 + } } //CHECK-NEXT: builtin.module { @@ -30,4 +40,7 @@ builtin.module { //CHECK-NEXT: "csl.fadds"(%0, %0, %11) : (memref<16xf32>, memref<16xf32>, f32) -> () //CHECK-NEXT: %12 = arith.constant 1.123400e-01 : f32 //CHECK-NEXT: "csl.fmuls"(%0, %12, %0) : (memref<16xf32>, f32, memref<16xf32>) -> () +//CHECK-NEXT: %c = arith.constant dense<2.99792458e+08> : memref<16xf32> +//CHECK-NEXT: %13 = arith.constant 2.997925e+08 : f32 +//CHECK-NEXT: "csl.fmacs"(%0, %0, %2, %13) : (memref<16xf32>, memref<16xf32>, memref<16xf32>, f32) -> () //CHECK-NEXT: } diff --git a/xdsl/transforms/linalg_to_csl.py b/xdsl/transforms/linalg_to_csl.py index 3bcadaad67..065866c887 100644 --- a/xdsl/transforms/linalg_to_csl.py +++ b/xdsl/transforms/linalg_to_csl.py @@ -12,7 +12,7 @@ ModuleOp, ) from xdsl.dialects.csl import csl -from xdsl.ir import OpResult, SSAValue +from xdsl.ir import Attribute, OpResult, SSAValue from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -25,6 +25,31 @@ from xdsl.utils.hints import isa +def match_op_for_precision( + prec: Attribute, f16: type[csl.BuiltinDsdOp], f32: type[csl.BuiltinDsdOp] +) -> type[csl.BuiltinDsdOp]: + """Returns the op type matching a given precision.""" + # todo support mixed-precision + match prec: + case Float16Type(): + return f16 + case Float32Type(): + return f32 + case _: + raise ValueError(f"Unsupported element type {prec}") + + +def get_scalar_const(op: SSAValue) -> AnyFloatAttr | AnyIntegerAttr | None: + """Returns the value of a scalar arith.constant, or None if not a constant or not scalar).""" + if ( + isinstance(op, OpResult) + and isinstance(op.op, arith.Constant) + and isa(val := op.op.value, DenseIntOrFPElementsAttr) + and val.data.data.count(val.data.data[0]) == len(val.data.data) + ): + return val.data.data[0] + + class ConvertBinaryLinalgOp(RewritePattern): """ Base class for converting binary linalg operations. @@ -37,30 +62,22 @@ def transform_op( f16: type[csl.BuiltinDsdOp], f32: type[csl.BuiltinDsdOp], ): - if not isa(op.outputs.types[0], AnyMemRefType): + if not isa(target_t := op.outputs.types[0], AnyMemRefType): return - match op.outputs.types[0].get_element_type(): - case Float16Type(): - builtin = f16 - case Float32Type(): - builtin = f32 - case _: - raise ValueError( - f"Unsupported element type {op.outputs.types[0].get_element_type()}" - ) + builtin = match_op_for_precision(target_t.get_element_type(), f16, f32) lhs = op.inputs[0] rhs = op.inputs[1] # binary functions translated here support mixing scalar and collection operands # may need revisiting if more functions are translated - if scalar_const := self._get_scalar_const(lhs): + if scalar_const := get_scalar_const(lhs): rewriter.insert_op( const_op := arith.Constant(scalar_const), InsertPoint.before(op) ) lhs = const_op.result - elif scalar_const := self._get_scalar_const(rhs): + elif scalar_const := get_scalar_const(rhs): rewriter.insert_op( const_op := arith.Constant(scalar_const), InsertPoint.before(op) ) @@ -68,16 +85,56 @@ def transform_op( rewriter.replace_matched_op(builtin(operands=[[op.outputs[0], lhs, rhs]])) + +class ConvertLinalgGenericFMAPass(RewritePattern): + """Lowers `linalg.generic` fused multiply-adds to csl builtin ops.""" + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: linalg.Generic, rewriter: PatternRewriter, /): + if not self.is_fma(op) or not isa(op.outputs.types[0], AnyMemRefType): + return + + # one of the factors must be a scalar const, which the csl function signatures require + if scalar_const := get_scalar_const(op.inputs[0]): + rewriter.insert_op( + const_op := arith.Constant(scalar_const), InsertPoint.before(op) + ) + non_scalar = op.inputs[1] + elif scalar_const := get_scalar_const(op.inputs[1]): + rewriter.insert_op( + const_op := arith.Constant(scalar_const), InsertPoint.before(op) + ) + non_scalar = op.inputs[0] + else: + # if neither factor is a scalar, return + return + + # fetch the csl op to build depending on the precision + csl_op = match_op_for_precision( + op.outputs.types[0].get_element_type(), f16=csl.FmachOp, f32=csl.FmacsOp + ) + + rewriter.replace_matched_op( + csl_op(operands=[[op.outputs[0], non_scalar, op.inputs[2], const_op]]) + ) + @staticmethod - def _get_scalar_const(op: SSAValue) -> AnyFloatAttr | AnyIntegerAttr | None: - """Returns the value of a scalar arith.constant, or None if not a constant or not scalar).""" - if ( - isinstance(op, OpResult) - and isinstance(op.op, arith.Constant) - and isa(val := op.op.value, DenseIntOrFPElementsAttr) - and val.data.data.count(val.data.data[0]) == len(val.data.data) - ): - return val.data.data[0] + def is_fma(op: linalg.Generic) -> bool: + """Returns if a given `generic` op is a fused multiply-add""" + return ( + len(op.inputs) == 3 + and len(op.outputs) == 1 + and len((block := op.body.block).args) == 4 + and len(block.ops) == 3 + and isinstance(mul := block.first_op, arith.Mulf) + and mul.lhs == block.args[0] + and mul.rhs == block.args[1] + and isinstance(add := mul.next_op, arith.Addf) + and add.lhs == mul.result + and add.rhs == block.args[2] + and isinstance(yld := add.next_op, linalg.YieldOp) + and yld.operands[0] == add.result + ) class ConvertLinalgAddPass(ConvertBinaryLinalgOp): @@ -112,6 +169,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None: module_pass = PatternRewriteWalker( GreedyRewritePatternApplier( [ + ConvertLinalgGenericFMAPass(), ConvertLinalgAddPass(), ConvertLinalgSubPass(), ConvertLinalgMulPass(), From cbed081f00a75beb41d7fe55f5e22b4ed008432b Mon Sep 17 00:00:00 2001 From: n-io Date: Fri, 25 Oct 2024 19:53:59 +0200 Subject: [PATCH 2/3] fix filecheck --- tests/filecheck/transforms/linalg-to-csl.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/filecheck/transforms/linalg-to-csl.mlir b/tests/filecheck/transforms/linalg-to-csl.mlir index de4a5f0106..c33eaa7881 100644 --- a/tests/filecheck/transforms/linalg-to-csl.mlir +++ b/tests/filecheck/transforms/linalg-to-csl.mlir @@ -40,7 +40,7 @@ builtin.module { //CHECK-NEXT: "csl.fadds"(%0, %0, %11) : (memref<16xf32>, memref<16xf32>, f32) -> () //CHECK-NEXT: %12 = arith.constant 1.123400e-01 : f32 //CHECK-NEXT: "csl.fmuls"(%0, %12, %0) : (memref<16xf32>, f32, memref<16xf32>) -> () -//CHECK-NEXT: %c = arith.constant dense<2.99792458e+08> : memref<16xf32> +//CHECK-NEXT: %c = arith.constant dense<2.997925e+08> : memref<16xf32> //CHECK-NEXT: %13 = arith.constant 2.997925e+08 : f32 //CHECK-NEXT: "csl.fmacs"(%0, %0, %2, %13) : (memref<16xf32>, memref<16xf32>, memref<16xf32>, f32) -> () //CHECK-NEXT: } From d14fc2295c2a98bb908a97b224e2247c0c95a7a4 Mon Sep 17 00:00:00 2001 From: n-io Date: Mon, 28 Oct 2024 18:32:24 +0100 Subject: [PATCH 3/3] fix operand order --- tests/filecheck/transforms/linalg-to-csl.mlir | 2 +- xdsl/transforms/linalg_to_csl.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/filecheck/transforms/linalg-to-csl.mlir b/tests/filecheck/transforms/linalg-to-csl.mlir index c33eaa7881..6a1f414421 100644 --- a/tests/filecheck/transforms/linalg-to-csl.mlir +++ b/tests/filecheck/transforms/linalg-to-csl.mlir @@ -42,5 +42,5 @@ builtin.module { //CHECK-NEXT: "csl.fmuls"(%0, %12, %0) : (memref<16xf32>, f32, memref<16xf32>) -> () //CHECK-NEXT: %c = arith.constant dense<2.997925e+08> : memref<16xf32> //CHECK-NEXT: %13 = arith.constant 2.997925e+08 : f32 -//CHECK-NEXT: "csl.fmacs"(%0, %0, %2, %13) : (memref<16xf32>, memref<16xf32>, memref<16xf32>, f32) -> () +//CHECK-NEXT: "csl.fmacs"(%0, %2, %0, %13) : (memref<16xf32>, memref<16xf32>, memref<16xf32>, f32) -> () //CHECK-NEXT: } diff --git a/xdsl/transforms/linalg_to_csl.py b/xdsl/transforms/linalg_to_csl.py index 065866c887..bbfbbee6fc 100644 --- a/xdsl/transforms/linalg_to_csl.py +++ b/xdsl/transforms/linalg_to_csl.py @@ -97,14 +97,14 @@ def match_and_rewrite(self, op: linalg.Generic, rewriter: PatternRewriter, /): # one of the factors must be a scalar const, which the csl function signatures require if scalar_const := get_scalar_const(op.inputs[0]): rewriter.insert_op( - const_op := arith.Constant(scalar_const), InsertPoint.before(op) + a := arith.Constant(scalar_const), InsertPoint.before(op) ) - non_scalar = op.inputs[1] + x = op.inputs[1] elif scalar_const := get_scalar_const(op.inputs[1]): rewriter.insert_op( - const_op := arith.Constant(scalar_const), InsertPoint.before(op) + a := arith.Constant(scalar_const), InsertPoint.before(op) ) - non_scalar = op.inputs[0] + x = op.inputs[0] else: # if neither factor is a scalar, return return @@ -114,9 +114,11 @@ def match_and_rewrite(self, op: linalg.Generic, rewriter: PatternRewriter, /): op.outputs.types[0].get_element_type(), f16=csl.FmachOp, f32=csl.FmacsOp ) - rewriter.replace_matched_op( - csl_op(operands=[[op.outputs[0], non_scalar, op.inputs[2], const_op]]) - ) + r = op.outputs[0] + y = op.inputs[2] + + # builds `r = a * x + y` + rewriter.replace_matched_op(csl_op(operands=[[r, y, x, a]])) @staticmethod def is_fma(op: linalg.Generic) -> bool: