Skip to content

Commit

Permalink
dialects: (angle) add cond_negate
Browse files Browse the repository at this point in the history
  • Loading branch information
alexarice committed Jan 31, 2025
1 parent 9d525a6 commit 7d18700
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 6 deletions.
43 changes: 39 additions & 4 deletions inconspiquous/dialects/angle.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
from __future__ import annotations

import math
from xdsl.ir import Dialect, ParametrizedAttribute, TypeAttribute
from xdsl.ir import Dialect, Operation, ParametrizedAttribute, SSAValue, TypeAttribute
from xdsl.irdl import (
IRDLOperation,
ParameterDef,
irdl_attr_definition,
irdl_op_definition,
operand_def,
prop_def,
result_def,
traits_def,
)
from xdsl.dialects.builtin import Float64Type, FloatAttr
from xdsl.dialects.builtin import Float64Type, FloatAttr, i1
from xdsl.parser import AttrParser
from xdsl.pattern_rewriter import RewritePattern
from xdsl.printer import Printer
from xdsl.traits import ConstantLike, Pure
from xdsl.traits import ConstantLike, HasCanonicalizationPatternsTrait, Pure


@irdl_attr_definition
Expand Down Expand Up @@ -113,4 +115,37 @@ def __init__(self, angle: AngleAttr):
)


Angle = Dialect("angle", [ConstantAngleOp], [AngleAttr, AngleType])
class CondNegateAngleOpHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from inconspiquous.transforms.canonicalization import angle

return (
angle.CondNegateAngleOpFoldPattern(),
angle.CondNegateAngleOpAssocPattern(),
)


@irdl_op_definition
class CondNegateAngleOp(IRDLOperation):
"""
Negates an angle if input condition is true.
"""

name = "angle.cond_negate"

cond = operand_def(i1)

angle = operand_def(AngleType)

out = result_def(AngleType)

traits = traits_def(CondNegateAngleOpHasCanonicalizationPatterns(), Pure())

assembly_format = "$cond `,` $angle attr-dict"

def __init__(self, cond: SSAValue | Operation, angle: SSAValue | Operation):
super().__init__(operands=(cond, angle), result_types=(AngleType(),))


Angle = Dialect("angle", [ConstantAngleOp, CondNegateAngleOp], [AngleAttr, AngleType])
39 changes: 39 additions & 0 deletions inconspiquous/transforms/canonicalization/angle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from xdsl.dialects.arith import XOrIOp
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.transforms.canonicalization_patterns.utils import const_evaluate_operand

from inconspiquous.dialects.angle import CondNegateAngleOp, ConstantAngleOp


class CondNegateAngleOpFoldPattern(RewritePattern):
"""
Fold an angle.cond_negate when both arguments are constant.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: CondNegateAngleOp, rewriter: PatternRewriter):
if (cond := const_evaluate_operand(op.cond)) is None:
return
if not cond:
rewriter.replace_matched_op((), (op.angle,))

if isinstance(op.angle.owner, ConstantAngleOp):
rewriter.replace_matched_op(ConstantAngleOp(-op.angle.owner.angle))


class CondNegateAngleOpAssocPattern(RewritePattern):
"""
Reassociate two conditional negations to a conditional negation on
the xor of the conditions.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: CondNegateAngleOp, rewriter: PatternRewriter):
if not isinstance(op.angle.owner, CondNegateAngleOp):
return
xor = XOrIOp(op.cond, op.angle.owner.cond)
rewriter.replace_matched_op((xor, CondNegateAngleOp(xor, op.angle.owner.angle)))
28 changes: 28 additions & 0 deletions tests/filecheck/dialects/angle/canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: quopt %s -p canonicalize | filecheck %s

%c0 = arith.constant false

%a = "test.op"() : () -> !angle.type

%b = angle.cond_negate %c0, %a

// CHECK: "test.op"(%a) {cond_false} : (!angle.type) -> ()
"test.op"(%b) {cond_false} : (!angle.type) -> ()

%c1 = arith.constant true

%c = angle.constant<0.5pi>

%d = angle.cond_negate %c1, %c
// CHECK: [[const:%.*]] = angle.constant<1.5pi>
// CHECK: "test.op"([[const]]) : (!angle.type) -> ()
"test.op"(%d) : (!angle.type) -> ()

%x, %y = "test.op"() : () -> (i1, i1)

%e = angle.cond_negate %x, %a
%f = angle.cond_negate %y, %e
// CHECK: [[xor:%.*]] = arith.xori %y, %x
// CHECK: [[assoc:%.*]] = angle.cond_negate [[xor]], %a
// CHECK: "test.op"([[assoc]]) : (!angle.type) -> ()
"test.op"(%f) : (!angle.type) -> ()
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
// CHECK-NEXT: "test.op"() {angle = #angle.attr<pi>} : () -> ()
"test.op"() {angle = #angle.attr<-pi>} : () -> ()

// CHECK-NEXT: %{{.*}} = angle.constant<pi>
// CHECK-GENERIC: %{{.*}} = "angle.constant"() <{angle = #angle.attr<pi>}> : () -> !angle.type
// CHECK-NEXT: %a = angle.constant<pi>
// CHECK-GENERIC: %a = "angle.constant"() <{angle = #angle.attr<pi>}> : () -> !angle.type
%a = angle.constant<pi>

%0 = "test.op"() : () -> i1

// CHECK: %a2 = angle.cond_negate %0, %a
// CHECK-GENERIC: %a2 = "angle.cond_negate"(%0, %a) : (i1, !angle.type) -> !angle.type
%a2 = angle.cond_negate %0, %a

0 comments on commit 7d18700

Please sign in to comment.