Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to (nearly) xdsl main #12

Merged
merged 1 commit into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions inconspiquous/alloc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,44 @@
from abc import ABC
from dataclasses import dataclass
from typing import Sequence
from xdsl.ir import Attribute, ParametrizedAttribute
from xdsl.irdl import WithRangeType
from xdsl.ir import Attribute, ParametrizedAttribute, VerifyException
from xdsl.irdl import (
ConstraintContext,
ConstraintVariableType,
GenericAttrConstraint,
RangeConstraint,
VarExtractor,
)


class AllocAttr(ParametrizedAttribute, WithRangeType, ABC):
class AllocAttr(ParametrizedAttribute, ABC):
def get_types(self) -> Sequence[Attribute]: ...


@dataclass(frozen=True)
class AllocConstraint(GenericAttrConstraint[AllocAttr]):
"""
Put a constraint on the result types of an alloc operation.
"""

type_constraint: RangeConstraint

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
if not isinstance(attr, AllocAttr):
raise VerifyException(
f"attribute {attr} expected to be a allocation attribute"
)
self.type_constraint.verify(attr.get_types(), constraint_context)

@dataclass(frozen=True)
class _Extractor(VarExtractor[AllocAttr]):
inner: VarExtractor[Sequence[Attribute]]

def extract_var(self, a: AllocAttr) -> ConstraintVariableType:
return self.inner.extract_var(a.get_types())

def get_variable_extractors(self) -> dict[str, VarExtractor[AllocAttr]]:
return {
v: self._Extractor(r)
for v, r in self.type_constraint.get_variable_extractors().items()
}
45 changes: 42 additions & 3 deletions inconspiquous/dialects/gate.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
from __future__ import annotations
import math
from typing import ClassVar
from dataclasses import dataclass

from xdsl.dialects.builtin import FloatAttr, Float64Type
from xdsl.ir import Dialect, Operation, ParametrizedAttribute, SSAValue, TypeAttribute
from xdsl.ir import (
Dialect,
Operation,
ParametrizedAttribute,
SSAValue,
TypeAttribute,
VerifyException,
Attribute,
)
from xdsl.irdl import (
AttrConstraint,
ConstraintContext,
ConstraintVariableType,
GenericAttrConstraint,
IRDLOperation,
ParameterDef,
VarConstraint,
WithTypeConstraint,
VarExtractor,
base,
irdl_attr_definition,
irdl_op_definition,
Expand Down Expand Up @@ -184,6 +196,33 @@ def print_parameters(self, printer: Printer) -> None:
printer.print_string(str(self.num_qubits.value.data))


@dataclass(frozen=True)
class GateTypeConstraint(GenericAttrConstraint[GateAttr]):
"""
Put a constraint on the gate type of a gate.
"""

type_constraint: GenericAttrConstraint[GateType]

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
if not isinstance(attr, GateAttr):
raise VerifyException(f"attribute {attr} expected to be a gate")
self.type_constraint.verify(GateType(attr.num_qubits), constraint_context)

@dataclass(frozen=True)
class _Extractor(VarExtractor[GateAttr]):
inner: VarExtractor[GateType]

def extract_var(self, a: GateAttr) -> ConstraintVariableType:
return self.inner.extract_var(GateType(a.num_qubits))

def get_variable_extractors(self) -> dict[str, VarExtractor[GateAttr]]:
return {
v: self._Extractor(r)
for v, r in self.type_constraint.get_variable_extractors().items()
}


@irdl_op_definition
class ConstantGateOp(IRDLOperation):
"""
Expand All @@ -194,7 +233,7 @@ class ConstantGateOp(IRDLOperation):

name = "gate.constant"

gate = prop_def(WithTypeConstraint(base(GateAttr), _T))
gate = prop_def(GateTypeConstraint(_T))

out = result_def(_T)

Expand Down
2 changes: 1 addition & 1 deletion inconspiquous/dialects/qssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class GateOp(IRDLOperation):

outs = var_result_def(_Q)

assembly_format = "`<` $gate `>` $ins attr-dict `:` type($ins)"
assembly_format = "`<` $gate `>` $ins attr-dict `:` type($outs)"

traits = traits_def(GateOpHasCanonicalizationPatterns())

Expand Down
8 changes: 2 additions & 6 deletions inconspiquous/dialects/qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,17 @@
from xdsl.ir import Attribute, Dialect, ParametrizedAttribute, TypeAttribute
from xdsl.irdl import (
AnyAttr,
BaseAttr,
IRDLOperation,
RangeConstraint,
RangeOf,
RangeVarConstraint,
WithRangeTypeConstraint,
irdl_attr_definition,
irdl_op_definition,
prop_def,
var_result_def,
)

from inconspiquous.alloc import AllocAttr
from inconspiquous.alloc import AllocAttr, AllocConstraint


@irdl_attr_definition
Expand Down Expand Up @@ -44,9 +42,7 @@ class AllocOp(IRDLOperation):

_T: ClassVar[RangeConstraint] = RangeVarConstraint("T", RangeOf(AnyAttr()))

alloc = prop_def(
WithRangeTypeConstraint(BaseAttr(AllocAttr), _T), default_value=AllocZeroAttr()
)
alloc = prop_def(AllocConstraint(_T), default_value=AllocZeroAttr())

outs = var_result_def(_T)

Expand Down
83 changes: 34 additions & 49 deletions inconspiquous/gates/constraints.py
Original file line number Diff line number Diff line change
@@ -1,70 +1,55 @@
from collections.abc import Set
from dataclasses import dataclass
from xdsl.ir import Attribute, VerifyException
from xdsl.irdl import ConstraintContext, GenericAttrConstraint, RangeVarConstraint, base

from xdsl.irdl import (
ConstraintContext,
GenericAttrConstraint,
GenericRangeConstraint,
InferenceContext,
RangeVarConstraint,
)

from inconspiquous.dialects import qubit
from inconspiquous.dialects.gate import GateType
from inconspiquous.gates import GateAttr


@dataclass(frozen=True)
class GateConstraint(GenericAttrConstraint[GateAttr]):
"""
Constrains a given range variable to have the correct size for the gate.
"""

range_var: str
gate_constraint: GenericAttrConstraint[GateAttr]

def __init__(
self,
range_constraint: str | RangeVarConstraint[Attribute],
gate_constraint: GenericAttrConstraint[GateAttr] = base(GateAttr),
):
if isinstance(range_constraint, str):
self.range_var = range_constraint
else:
self.range_var = range_constraint.name
self.gate_constraint = gate_constraint
range_constraint: GenericRangeConstraint[qubit.BitType]

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
self.gate_constraint.verify(attr, constraint_context)
if self.range_var in constraint_context.range_variables:
attrs = constraint_context.get_range_variable(self.range_var)
assert isinstance(attr, GateAttr)
if attr.num_qubits != len(attrs):
raise VerifyException(
f"Gate {attr} expected {attr.num_qubits} qubits but got {len(attrs)}"
)
if not isinstance(attr, GateAttr):
raise VerifyException(f"attribute {attr} expected to be a gate")
self.range_constraint.verify(
(qubit.BitType(),) * attr.num_qubits, constraint_context
)


@dataclass(frozen=True)
class DynGateConstraint(GenericAttrConstraint[GateType]):
"""
Constrains a given range variable to have the correct size for the dynamic gate.
"""

range_var: str

def __init__(
self,
range_constraint: str | RangeVarConstraint[Attribute],
):
if isinstance(range_constraint, str):
self.range_var = range_constraint
else:
self.range_var = range_constraint.name
range_constraint: RangeVarConstraint[qubit.BitType]

def verify(self, attr: Attribute, constraint_context: ConstraintContext) -> None:
base(GateType).verify(attr, constraint_context)
if self.range_var in constraint_context.range_variables:
attrs = constraint_context.get_range_variable(self.range_var)
assert isinstance(attr, GateType)
num_qubits = attr.num_qubits.value.data
if num_qubits != len(attrs):
raise VerifyException(
f"Gate input expected {num_qubits} qubits but got {len(attrs)}"
)

def can_infer(self, constraint_names: set[str]) -> bool:
return self.range_var in constraint_names

def infer(self, constraint_context: ConstraintContext) -> Attribute:
types = constraint_context.get_range_variable(self.range_var)
return GateType(len(types))
if not isinstance(attr, GateType):
raise VerifyException(f"type {attr} expected to be a gate type")
self.range_constraint.verify(
(qubit.BitType(),) * attr.num_qubits.value.data, constraint_context
)

def can_infer(self, var_constraint_names: Set[str]) -> bool:
return self.range_constraint.name in var_constraint_names

def infer(self, context: InferenceContext) -> GateType:
range_type = self.range_constraint.infer(
0, context
) # We know a range constraint does not use the input length
return GateType(len(range_type))
5 changes: 1 addition & 4 deletions inconspiquous/gates/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,9 @@
Attribute,
ParametrizedAttribute,
)
from xdsl.irdl import (
WithType,
)


class GateAttr(ParametrizedAttribute, WithType, ABC):
class GateAttr(ParametrizedAttribute, ABC):
"""
In general most gate operations are not operationally different, so differentiating between them
may actually be better done via an attribute that can be attached to a gate operation.
Expand Down
2 changes: 1 addition & 1 deletion inconspiquous/transforms/convert_qssa_to_qref.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,4 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
]
),
apply_recursively=False,
).rewrite_op(op)
).rewrite_module(op)
2 changes: 1 addition & 1 deletion inconspiquous/transforms/lower_dyn_gate_to_scf.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ class LowerDynGateToScf(ModulePass):
name = "lower-dyn-gate-to-scf"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(LowerDynGateToScfPattern()).rewrite_op(op)
PatternRewriteWalker(LowerDynGateToScfPattern()).rewrite_module(op)
2 changes: 1 addition & 1 deletion inconspiquous/transforms/lower_to_fin_supp.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,4 @@ class LowerToFinSupp(ModulePass):
def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier([LowerBernoulli(), LowerUniform(self.max_size)])
).rewrite_op(op)
).rewrite_module(op)
2 changes: 1 addition & 1 deletion inconspiquous/transforms/randomized_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,4 @@ def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier([PadTGate(), PadHadamardGate(), PadCNotGate()]),
apply_recursively=False, # Do not reapply
).rewrite_op(op)
).rewrite_module(op)
2 changes: 1 addition & 1 deletion inconspiquous/transforms/xs/convert_to_xs.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,4 @@ class ConvertToXS(ModulePass):
def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier([ToDynGate(), ToXSGate()])
).rewrite_op(op)
).rewrite_module(op)
2 changes: 1 addition & 1 deletion inconspiquous/transforms/xs/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ class LowerXSToSelect(ModulePass):
name = "lower-xs-to-select"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(LowerXSToSelectPattern()).rewrite_op(op)
PatternRewriteWalker(LowerXSToSelectPattern()).rewrite_module(op)
2 changes: 1 addition & 1 deletion inconspiquous/transforms/xs/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ class MergeXSGates(ModulePass):
name = "merge-xs"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(MergeXSGatesPattern()).rewrite_op(op)
PatternRewriteWalker(MergeXSGatesPattern()).rewrite_module(op)
2 changes: 1 addition & 1 deletion inconspiquous/transforms/xs/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ class XSSelect(ModulePass):
name = "xs-select"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(XSSelectPattern()).rewrite_op(op)
PatternRewriteWalker(XSSelectPattern()).rewrite_module(op)
2 changes: 1 addition & 1 deletion tests/filecheck/dialects/prob/ops.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: QUOPT_ROUNDTRIP
// RUN: QUOPT_GENERIC_ROUNDTRIP

// CHECK: %{{.*}} = prob.bernoulli 5.000000e-01 : f64
// CHECK: %{{.*}} = prob.bernoulli 5.000000e-01
// CHECK-GENERIC: %{{.*}} = "prob.bernoulli"() <{"prob" = 5.000000e-01 : f64}> : () -> i1
%0 = prob.bernoulli 0.5

Expand Down
4 changes: 2 additions & 2 deletions tests/filecheck/dialects/qssa/gate_counts.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@

%q0 = qubit.alloc

// CHECK: Gate #gate.cnot expected 2 qubits but got 1
// CHECK: attributes ('!qubit.bit',) expected from range variable 'Q', but got ('!qubit.bit', '!qubit.bit')
%q1 = "qssa.gate"(%q0) <{"gate" = #gate.cnot}> : (!qubit.bit) -> !qubit.bit

// -----

%g = "test.op"() : () -> !gate.type<2>
%q0 = qubit.alloc

// CHECK: Gate input expected 2 qubits but got 1
// CHECK: attributes ('!qubit.bit',) expected from range variable 'Q', but got ('!qubit.bit', '!qubit.bit')
%q1 = "qssa.dyn_gate"(%q0, %g) : (!qubit.bit, !gate.type<2>) -> !qubit.bit
6 changes: 3 additions & 3 deletions tests/filecheck/examples/depolarising.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: QUOPT_ROUNDTRIP

// CHECK: func.func @depolarising_dyn(%q : !qubit.bit) -> !qubit.bit {
// CHECK-NEXT: %p = prob.bernoulli 1.000000e-01 : f64
// CHECK-NEXT: %p = prob.bernoulli 1.000000e-01
// CHECK-NEXT: %id = gate.constant #gate.id
// CHECK-NEXT: %p2 = prob.uniform : i2
// CHECK-NEXT: %x = gate.constant #gate.x
Expand Down Expand Up @@ -36,7 +36,7 @@ func.func @depolarising_dyn(%q : !qubit.bit) -> !qubit.bit {
}

// CHECK: func.func @depolarising_scf(%q : !qubit.bit) -> !qubit.bit {
// CHECK-NEXT: %p = prob.bernoulli 1.000000e-01 : f64
// CHECK-NEXT: %p = prob.bernoulli 1.000000e-01
// CHECK-NEXT: %q3 = scf.if %p -> (!qubit.bit) {
// CHECK-NEXT: %p2 = prob.uniform : i4
// CHECK-NEXT: %p3 = arith.index_cast %p2 : i4 to index
Expand Down Expand Up @@ -91,7 +91,7 @@ func.func @depolarising_scf(%q : !qubit.bit) -> !qubit.bit {
}

// CHECK: func.func @depolarising_cf(%q : !qubit.bit) -> !qubit.bit {
// CHECK-NEXT: %p = prob.bernoulli 1.000000e-01 : f64
// CHECK-NEXT: %p = prob.bernoulli 1.000000e-01
// CHECK-NEXT: cf.cond_br %p, ^0, ^1(%q : !qubit.bit)
// CHECK-NEXT: ^0:
// CHECK-NEXT: %p2 = prob.uniform : i4
Expand Down
4 changes: 2 additions & 2 deletions tests/filecheck/examples/double_phase.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
// RUN: QUOPT_ROUNDTRIP

// CHECK: func.func @double_phase(%q : !qubit.bit) -> !qubit.bit {
// CHECK-NEXT: %p = prob.bernoulli 1.000000e-01 : f64
// CHECK-NEXT: %p = prob.bernoulli 1.000000e-01
// CHECK-NEXT: %id = gate.constant #gate.id
// CHECK-NEXT: %z = gate.constant #gate.z
// CHECK-NEXT: %g = arith.select %p, %z, %id : !gate.type<1>
// CHECK-NEXT: %q1 = qssa.dyn_gate<%g> %q : !qubit.bit
// CHECK-NEXT: %p2 = prob.bernoulli 1.000000e-01 : f64
// CHECK-NEXT: %p2 = prob.bernoulli 1.000000e-01
// CHECK-NEXT: %id2 = gate.constant #gate.id
// CHECK-NEXT: %z2 = gate.constant #gate.z
// CHECK-NEXT: %g2 = arith.select %p2, %z2, %id2 : !gate.type<1>
Expand Down
Loading
Loading