Skip to content

Commit

Permalink
transforms: lower-to-fin-supp
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Nov 8, 2024
1 parent ae7cd0a commit 37f809a
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 1 deletion.
2 changes: 1 addition & 1 deletion inconspiquous/dialects/prob.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class FinSuppOp(IRDLOperation):
def __init__(
self,
probabilities: Sequence[float] | DenseArrayBase,
default_value: SSAValue,
default_value: SSAValue | Operation,
*ins: SSAValue | Operation,
attr_dict: dict[str, Attribute] | None = None,
):
Expand Down
6 changes: 6 additions & 0 deletions inconspiquous/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ def get_lower_dyn_gate_to_scf():

return lower_dyn_gate_to_scf.LowerDynGateToScf

def get_lower_to_fin_supp():
from inconspiquous.transforms import lower_to_fin_supp

return lower_to_fin_supp.LowerToFinSupp

def get_lower_xs_to_select():
from inconspiquous.transforms.xs import lower

Expand Down Expand Up @@ -74,6 +79,7 @@ def get_xs_select():
"cse": get_cse,
"dce": get_dce,
"lower-dyn-gate-to-scf": get_lower_dyn_gate_to_scf,
"lower-to-fin-supp": get_lower_to_fin_supp,
"lower-xs-to-select": get_lower_xs_to_select,
"merge-xs": get_merge_xs,
"mlir-opt": get_mlir_opt,
Expand Down
65 changes: 65 additions & 0 deletions inconspiquous/transforms/lower_to_fin_supp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from xdsl.context import MLContext
from xdsl.dialects import builtin
from xdsl.dialects.builtin import BoolAttr
from xdsl.ir import Operation, dataclass
from xdsl.parser import IntegerType
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriteWalker,
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.dialects import arith
from xdsl.passes import ModulePass
from inconspiquous.dialects.prob import BernoulliOp, FinSuppOp, UniformOp


class LowerBernoulli(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: BernoulliOp, rewriter: PatternRewriter):
zero = arith.Constant(BoolAttr.from_bool(False))
one = arith.Constant(BoolAttr.from_bool(True))
rewriter.replace_matched_op(
(zero, one, FinSuppOp((op.prob.value.data,), zero, one))
)


@dataclass(frozen=True)
class LowerUniform(RewritePattern):
max_size: int

@op_type_rewrite_pattern
def match_and_rewrite(self, op: UniformOp, rewriter: PatternRewriter):
ty = op.out.type
if not isinstance(ty, IntegerType):
return

if ty.bitwidth > self.max_size:
return

zero = arith.Constant.from_int_and_width(0, ty.bitwidth)
ops: list[Operation] = []
for i in range(1, 2**ty.bitwidth):
ops.append(arith.Constant.from_int_and_width(i, ty.bitwidth))

fin_supp = FinSuppOp(
tuple(1.0 / (2**ty.bitwidth) for _ in range(1, 2**ty.bitwidth)), zero, *ops
)

ops.append(zero)
ops.append(fin_supp)

rewriter.replace_matched_op(ops)


@dataclass(frozen=True)
class LowerToFinSupp(ModulePass):
max_size: int

name = "lower-to-fin-supp"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier([LowerBernoulli(), LowerUniform(self.max_size)])
).rewrite_op(op)
21 changes: 21 additions & 0 deletions tests/filecheck/transforms/lower_to_fin_supp.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: quopt -p lower-to-fin-supp{max_size=2} %s | filecheck %s

// CHECK: %0 = arith.constant false
// CHECK-NEXT: %1 = arith.constant true
// CHECK-NEXT: %2 = prob.fin_supp [ 0.1 of %1, else %0 ] : i1
%0 = prob.bernoulli 0.1 : f64

// CHECK-NEXT: %3 = arith.constant true
// CHECK-NEXT: %4 = arith.constant false
// CHECK-NEXT: %5 = prob.fin_supp [ 0.5 of %3, else %4 ] : i1
%1 = prob.uniform : i1

// CHECK-NEXT: %6 = arith.constant 1 : i2
// CHECK-NEXT: %7 = arith.constant 2 : i2
// CHECK-NEXT: %8 = arith.constant 3 : i2
// CHECK-NEXT: %9 = arith.constant 0 : i2
// CHECK-NEXT: %10 = prob.fin_supp [ 0.25 of %6, 0.25 of %7, 0.25 of %8, else %9 ] : i2
%2 = prob.uniform : i2

// CHECK-NEXT: %11 = prob.uniform : i3
%3 = prob.uniform : i3

0 comments on commit 37f809a

Please sign in to comment.