Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: (linalg-to-csl) Lower generic to fmac(h|s) #3345

Merged
merged 4 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions tests/filecheck/transforms/linalg-to-csl.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// RUN: xdsl-opt %s -p linalg-to-csl | filecheck %s

#map = affine_map<(d0) -> (d0)>

builtin.module {
%0, %1, %2, %3, %4 = "test.op"() : () -> (memref<16xf32>, memref<16xf32>, memref<16xf32>, memref<16xf32>, memref<16xf32>)
linalg.add ins(%1, %2 : memref<16xf32>, memref<16xf32>) outs(%0 : memref<16xf32>)
Expand All @@ -14,6 +16,14 @@ builtin.module {
%10 = arith.constant dense<1.123400e-01> : memref<16xf32>
linalg.add ins(%0, %10 : memref<16xf32>, memref<16xf32>) outs(%0 : memref<16xf32>)
linalg.mul ins(%10, %0 : memref<16xf32>, memref<16xf32>) outs(%0 : memref<16xf32>)

%c = arith.constant dense<2.99792458e+08> : memref<16xf32>
linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = ["parallel"]} ins(%0, %c, %2 : memref<16xf32>, memref<16xf32>, memref<16xf32>) outs(%0 : memref<16xf32>) {
^bb0(%in: f32, %in_0: f32, %in_1: f32, %out: f32):
%11 = arith.mulf %in, %in_0 : f32
%12 = arith.addf %11, %in_1 : f32
linalg.yield %12 : f32
}
}

//CHECK-NEXT: builtin.module {
Expand All @@ -30,4 +40,7 @@ builtin.module {
//CHECK-NEXT: "csl.fadds"(%0, %0, %11) : (memref<16xf32>, memref<16xf32>, f32) -> ()
//CHECK-NEXT: %12 = arith.constant 1.123400e-01 : f32
//CHECK-NEXT: "csl.fmuls"(%0, %12, %0) : (memref<16xf32>, f32, memref<16xf32>) -> ()
//CHECK-NEXT: %c = arith.constant dense<2.997925e+08> : memref<16xf32>
//CHECK-NEXT: %13 = arith.constant 2.997925e+08 : f32
//CHECK-NEXT: "csl.fmacs"(%0, %0, %2, %13) : (memref<16xf32>, memref<16xf32>, memref<16xf32>, f32) -> ()
//CHECK-NEXT: }
102 changes: 80 additions & 22 deletions xdsl/transforms/linalg_to_csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
ModuleOp,
)
from xdsl.dialects.csl import csl
from xdsl.ir import OpResult, SSAValue
from xdsl.ir import Attribute, OpResult, SSAValue
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand All @@ -25,6 +25,31 @@
from xdsl.utils.hints import isa


def match_op_for_precision(
    prec: Attribute, f16: type[csl.BuiltinDsdOp], f32: type[csl.BuiltinDsdOp]
) -> type[csl.BuiltinDsdOp]:
    """Pick the csl builtin op type corresponding to the element precision `prec`.

    Raises ValueError for any precision other than f16/f32.
    """
    # todo support mixed-precision
    if isinstance(prec, Float16Type):
        return f16
    if isinstance(prec, Float32Type):
        return f32
    raise ValueError(f"Unsupported element type {prec}")


def get_scalar_const(op: SSAValue) -> AnyFloatAttr | AnyIntegerAttr | None:
    """Returns the value of a scalar arith.constant, or None if not a constant or not scalar.

    A dense constant counts as "scalar" when every element equals the first
    (i.e. it is a splat), in which case that single element value is returned.
    """
    if (
        isinstance(op, OpResult)
        and isinstance(op.op, arith.Constant)
        and isa(val := op.op.value, DenseIntOrFPElementsAttr)
        # guard against empty dense data before indexing element 0
        and len(data := val.data.data) > 0
        and data.count(data[0]) == len(data)
    ):
        return data[0]


class ConvertBinaryLinalgOp(RewritePattern):
"""
Base class for converting binary linalg operations.
Expand All @@ -37,47 +62,79 @@ def transform_op(
f16: type[csl.BuiltinDsdOp],
f32: type[csl.BuiltinDsdOp],
):
if not isa(op.outputs.types[0], AnyMemRefType):
if not isa(target_t := op.outputs.types[0], AnyMemRefType):
return

match op.outputs.types[0].get_element_type():
case Float16Type():
builtin = f16
case Float32Type():
builtin = f32
case _:
raise ValueError(
f"Unsupported element type {op.outputs.types[0].get_element_type()}"
)
builtin = match_op_for_precision(target_t.get_element_type(), f16, f32)

lhs = op.inputs[0]
rhs = op.inputs[1]

# binary functions translated here support mixing scalar and collection operands
# may need revisiting if more functions are translated
if scalar_const := self._get_scalar_const(lhs):
if scalar_const := get_scalar_const(lhs):
rewriter.insert_op(
const_op := arith.Constant(scalar_const), InsertPoint.before(op)
)
lhs = const_op.result
elif scalar_const := self._get_scalar_const(rhs):
elif scalar_const := get_scalar_const(rhs):
rewriter.insert_op(
const_op := arith.Constant(scalar_const), InsertPoint.before(op)
)
rhs = const_op.result

rewriter.replace_matched_op(builtin(operands=[[op.outputs[0], lhs, rhs]]))


class ConvertLinalgGenericFMAPass(RewritePattern):
    """Lowers `linalg.generic` fused multiply-adds to csl builtin ops."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: linalg.Generic, rewriter: PatternRewriter, /):
        """Rewrite a recognised FMA-shaped `linalg.generic` into `csl.fmac(h|s)`."""
        if not self.is_fma(op) or not isa(op.outputs.types[0], AnyMemRefType):
            return

        # one of the factors must be a scalar const, which the csl function signatures require
        if scalar_const := get_scalar_const(op.inputs[0]):
            rewriter.insert_op(
                const_op := arith.Constant(scalar_const), InsertPoint.before(op)
            )
            non_scalar = op.inputs[1]
        elif scalar_const := get_scalar_const(op.inputs[1]):
            rewriter.insert_op(
                const_op := arith.Constant(scalar_const), InsertPoint.before(op)
            )
            non_scalar = op.inputs[0]
        else:
            # if neither factor is a scalar, return
            return

        # fetch the csl op to build depending on the precision
        csl_op = match_op_for_precision(
            op.outputs.types[0].get_element_type(), f16=csl.FmachOp, f32=csl.FmacsOp
        )

        # pass the constant's SSA result explicitly, consistent with how
        # ConvertBinaryLinalgOp forwards `const_op.result` to the builtin
        rewriter.replace_matched_op(
            csl_op(operands=[[op.outputs[0], non_scalar, op.inputs[2], const_op.result]])
        )

    @staticmethod
    def is_fma(op: linalg.Generic) -> bool:
        """Returns if a given `generic` op is a fused multiply-add.

        Matches exactly: a 3-input, 1-output generic whose body is
        ``mulf(arg0, arg1)`` feeding ``addf(·, arg2)`` feeding ``linalg.yield``.
        """
        return (
            len(op.inputs) == 3
            and len(op.outputs) == 1
            and len((block := op.body.block).args) == 4
            and len(block.ops) == 3
            and isinstance(mul := block.first_op, arith.Mulf)
            and mul.lhs == block.args[0]
            and mul.rhs == block.args[1]
            and isinstance(add := mul.next_op, arith.Addf)
            and add.lhs == mul.result
            and add.rhs == block.args[2]
            and isinstance(yld := add.next_op, linalg.YieldOp)
            and yld.operands[0] == add.result
        )


class ConvertLinalgAddPass(ConvertBinaryLinalgOp):
Expand Down Expand Up @@ -112,6 +169,7 @@ def apply(self, ctx: MLContext, op: ModuleOp) -> None:
module_pass = PatternRewriteWalker(
GreedyRewritePatternApplier(
[
ConvertLinalgGenericFMAPass(),
ConvertLinalgAddPass(),
ConvertLinalgSubPass(),
ConvertLinalgMulPass(),
Expand Down
Loading