diff --git a/inconspiquous/dialects/prob.py b/inconspiquous/dialects/prob.py index f144775..4ffe6ff 100644 --- a/inconspiquous/dialects/prob.py +++ b/inconspiquous/dialects/prob.py @@ -121,7 +121,7 @@ class FinSuppOp(IRDLOperation): def __init__( self, probabilities: Sequence[float] | DenseArrayBase, - default_value: SSAValue, + default_value: SSAValue | Operation, *ins: SSAValue | Operation, attr_dict: dict[str, Attribute] | None = None, ): diff --git a/inconspiquous/transforms/__init__.py b/inconspiquous/transforms/__init__.py index c8de6bb..c9d0139 100644 --- a/inconspiquous/transforms/__init__.py +++ b/inconspiquous/transforms/__init__.py @@ -41,6 +41,11 @@ def get_lower_dyn_gate_to_scf(): return lower_dyn_gate_to_scf.LowerDynGateToScf + def get_lower_to_fin_supp(): + from inconspiquous.transforms import lower_to_fin_supp + + return lower_to_fin_supp.LowerToFinSupp + def get_lower_xs_to_select(): from inconspiquous.transforms.xs import lower @@ -74,6 +79,7 @@ def get_xs_select(): "cse": get_cse, "dce": get_dce, "lower-dyn-gate-to-scf": get_lower_dyn_gate_to_scf, + "lower-to-fin-supp": get_lower_to_fin_supp, "lower-xs-to-select": get_lower_xs_to_select, "merge-xs": get_merge_xs, "mlir-opt": get_mlir_opt, diff --git a/inconspiquous/transforms/lower_to_fin_supp.py b/inconspiquous/transforms/lower_to_fin_supp.py new file mode 100644 index 0000000..ee53114 --- /dev/null +++ b/inconspiquous/transforms/lower_to_fin_supp.py @@ -0,0 +1,65 @@ +from xdsl.context import MLContext +from xdsl.dialects import builtin +from xdsl.dialects.builtin import BoolAttr +from xdsl.ir import Operation, dataclass +from xdsl.parser import IntegerType +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriteWalker, + PatternRewriter, + RewritePattern, + op_type_rewrite_pattern, +) +from xdsl.dialects import arith +from xdsl.passes import ModulePass +from inconspiquous.dialects.prob import BernoulliOp, FinSuppOp, UniformOp + + +class LowerBernoulli(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: BernoulliOp, rewriter: PatternRewriter): + zero = arith.Constant(BoolAttr.from_bool(False)) + one = arith.Constant(BoolAttr.from_bool(True)) + rewriter.replace_matched_op( + (zero, one, FinSuppOp((op.prob.value.data,), zero, one)) + ) + + +@dataclass(frozen=True) +class LowerUniform(RewritePattern): + max_size: int + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: UniformOp, rewriter: PatternRewriter): + ty = op.out.type + if not isinstance(ty, IntegerType): + return + + if ty.bitwidth > self.max_size: + return + + zero = arith.Constant.from_int_and_width(0, ty.bitwidth) + ops: list[Operation] = [] + for i in range(1, 2**ty.bitwidth): + ops.append(arith.Constant.from_int_and_width(i, ty.bitwidth)) + + fin_supp = FinSuppOp( + tuple(1.0 / (2**ty.bitwidth) for _ in range(1, 2**ty.bitwidth)), zero, *ops + ) + + ops.append(zero) + ops.append(fin_supp) + + rewriter.replace_matched_op(ops) + + +@dataclass(frozen=True) +class LowerToFinSupp(ModulePass): + max_size: int + + name = "lower-to-fin-supp" + + def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: + PatternRewriteWalker( + GreedyRewritePatternApplier([LowerBernoulli(), LowerUniform(self.max_size)]) + ).rewrite_op(op) diff --git a/tests/filecheck/transforms/lower_to_fin_supp.mlir b/tests/filecheck/transforms/lower_to_fin_supp.mlir new file mode 100644 index 0000000..1d66626 --- /dev/null +++ b/tests/filecheck/transforms/lower_to_fin_supp.mlir @@ -0,0 +1,21 @@ +// RUN: quopt -p lower-to-fin-supp{max_size=2} %s | filecheck %s + +// CHECK: %0 = arith.constant false +// CHECK-NEXT: %1 = arith.constant true +// CHECK-NEXT: %2 = prob.fin_supp [ 0.1 of %1, else %0 ] : i1 +%0 = prob.bernoulli 0.1 : f64 + +// CHECK-NEXT: %3 = arith.constant true +// CHECK-NEXT: %4 = arith.constant false +// CHECK-NEXT: %5 = prob.fin_supp [ 0.5 of %3, else %4 ] : i1 +%1 = prob.uniform : i1 + +// CHECK-NEXT: %6 = arith.constant 1 : i2 +// CHECK-NEXT: %7 = arith.constant 2 : i2 +// CHECK-NEXT: %8 = arith.constant 3 : i2 +// CHECK-NEXT: %9 = arith.constant 0 : i2 +// CHECK-NEXT: %10 = prob.fin_supp [ 0.25 of %6, 0.25 of %7, 0.25 of %8, else %9 ] : i2 +%2 = prob.uniform : i2 + +// CHECK-NEXT: %11 = prob.uniform : i3 +%3 = prob.uniform : i3