Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (measurement) add dialect and attributes for measurement #41

Merged
merged 1 commit into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions inconspiquous/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def get_linalg():

return Linalg

def get_measurement():
from inconspiquous.dialects.measurement import Measurement

return Measurement

def get_prob():
from inconspiquous.dialects.prob import Prob

Expand Down Expand Up @@ -93,6 +98,7 @@ def get_varith():
"func": get_func,
"gate": get_gate,
"linalg": get_linalg,
"measurement": get_measurement,
"prob": get_prob,
"qec": get_qec,
"qref": get_qref,
Expand Down
62 changes: 62 additions & 0 deletions inconspiquous/dialects/measurement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

from xdsl.ir import Dialect
from xdsl.irdl import (
ParameterDef,
irdl_attr_definition,
)
from xdsl.parser import AttrParser
from xdsl.printer import Printer
from inconspiquous.dialects.gate import AngleAttr
from inconspiquous.measurement import MeasurementAttr


@irdl_attr_definition
class CompBasisMeasurementAttr(MeasurementAttr):
"""
A computational basis measurement attribute.
"""

name = "measurement.comp_basis"

@property
def num_qubits(self) -> int:
return 1


@irdl_attr_definition
class XYMeasurementAttr(MeasurementAttr):
"""
An XY plane measurement attribute with specified angle.
"""

name = "measurement.xy"

angle: ParameterDef[AngleAttr]

def __init__(self, angle: float | AngleAttr):
if not isinstance(angle, AngleAttr):
angle = AngleAttr(angle)

super().__init__((angle,))

@classmethod
def parse_parameters(cls, parser: AttrParser) -> tuple[AngleAttr]:
return (AngleAttr.new(AngleAttr.parse_parameters(parser)),)

def print_parameters(self, printer: Printer) -> None:
return self.angle.print_parameters(printer)

@property
def num_qubits(self) -> int:
return 1


Measurement = Dialect(
"measurement",
[],
[
CompBasisMeasurementAttr,
XYMeasurementAttr,
],
)
30 changes: 23 additions & 7 deletions inconspiquous/dialects/qref.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,20 @@
irdl_op_definition,
operand_def,
prop_def,
result_def,
traits_def,
var_operand_def,
eq,
var_result_def,
)
from xdsl.pattern_rewriter import RewritePattern
from xdsl.traits import HasCanonicalizationPatternsTrait

from inconspiquous.dialects.gate import GateType
from inconspiquous.dialects.measurement import CompBasisMeasurementAttr
from inconspiquous.gates import GateAttr
from inconspiquous.dialects.qubit import BitType
from inconspiquous.constraints import SizedAttributeConstraint
from inconspiquous.measurement import MeasurementAttr


@irdl_op_definition
Expand Down Expand Up @@ -76,16 +78,30 @@ def __init__(self, gate: SSAValue | Operation, *ins: SSAValue | Operation):
class MeasureOp(IRDLOperation):
name = "qref.measure"

in_qubit = operand_def(BitType())
_I: ClassVar = IntVarConstraint("I", AnyInt())

measurement = prop_def(
SizedAttributeConstraint(MeasurementAttr, _I),
default_value=CompBasisMeasurementAttr(),
)

in_qubits = var_operand_def(RangeOf(eq(BitType()), length=_I))

out = result_def(i1)
out = var_result_def(RangeOf(eq(i1), length=_I))

assembly_format = "$in_qubit attr-dict"
assembly_format = "(`` `<` $measurement^ `>`)? $in_qubits attr-dict"

def __init__(self, in_qubit: SSAValue | Operation):
def __init__(
self,
*in_qubits: SSAValue | Operation,
measurement: MeasurementAttr = CompBasisMeasurementAttr(),
):
super().__init__(
operands=(in_qubit,),
result_types=(i1,),
properties={
"measurement": measurement,
},
operands=(in_qubits,),
result_types=((i1,) * len(in_qubits)),
)


Expand Down
29 changes: 22 additions & 7 deletions inconspiquous/dialects/qssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
irdl_op_definition,
operand_def,
prop_def,
result_def,
traits_def,
var_operand_def,
var_result_def,
Expand All @@ -19,9 +18,11 @@
from xdsl.traits import HasCanonicalizationPatternsTrait

from inconspiquous.dialects.gate import GateType
from inconspiquous.dialects.measurement import CompBasisMeasurementAttr
from inconspiquous.gates import GateAttr
from inconspiquous.dialects.qubit import BitType
from inconspiquous.constraints import SizedAttributeConstraint
from inconspiquous.measurement import MeasurementAttr


class GateOpHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait):
Expand Down Expand Up @@ -96,16 +97,30 @@ def __init__(self, gate: SSAValue | Operation, *ins: SSAValue | Operation):
class MeasureOp(IRDLOperation):
name = "qssa.measure"

in_qubit = operand_def(BitType())
_I: ClassVar = IntVarConstraint("I", AnyInt())

measurement = prop_def(
SizedAttributeConstraint(MeasurementAttr, _I),
default_value=CompBasisMeasurementAttr(),
)

in_qubits = var_operand_def(RangeOf(eq(BitType()), length=_I))

out = result_def(i1)
out = var_result_def(RangeOf(eq(i1), length=_I))

assembly_format = "$in_qubit attr-dict"
assembly_format = "(`` `<` $measurement^ `>`)? $in_qubits attr-dict"

def __init__(self, in_qubit: SSAValue | Operation):
def __init__(
self,
*in_qubits: SSAValue | Operation,
measurement: MeasurementAttr = CompBasisMeasurementAttr(),
):
super().__init__(
operands=(in_qubit,),
result_types=(i1,),
properties={
"measurement": measurement,
},
operands=(in_qubits,),
result_types=((i1,) * len(in_qubits)),
)


Expand Down
14 changes: 14 additions & 0 deletions inconspiquous/measurement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from abc import ABC, abstractmethod
from xdsl.ir import ParametrizedAttribute

from inconspiquous.constraints import SizedAttribute


class MeasurementAttr(ParametrizedAttribute, SizedAttribute, ABC):
@property
@abstractmethod
def num_qubits(self) -> int: ...

@property
def size(self) -> int:
return self.num_qubits
17 changes: 9 additions & 8 deletions inconspiquous/transforms/convert_qref_to_qssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,18 @@ def match_and_rewrite(self, op: qref.MeasureOp, rewriter: PatternRewriter):
# Don't rewrite if uses live in different blocks
if op.parent_block() is None:
return
for use in op.in_qubit.uses:
if use.operation.parent_block() != op.parent_block():
return
for operand in op.in_qubits:
for use in operand.uses:
if use.operation.parent_block() != op.parent_block():
return

# Don't rewrite if there are further uses of the measured qubit
if len(op.in_qubit.uses) != 1:
return
# Don't rewrite if there are further uses of the measured qubit
if len(operand.uses) != 1:
return

new_op = qssa.MeasureOp(op.in_qubit)
new_op = qssa.MeasureOp(*op.in_qubits, measurement=op.measurement)

rewriter.replace_matched_op(new_op, (new_op.out,))
rewriter.replace_matched_op(new_op, new_op.out)


class ConvertQrefToQssa(ModulePass):
Expand Down
2 changes: 1 addition & 1 deletion inconspiquous/transforms/convert_qssa_to_qref.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ConvertQssaMeasureToQrefMeasure(RewritePattern):

@op_type_rewrite_pattern
def match_and_rewrite(self, op: qssa.MeasureOp, rewriter: PatternRewriter):
new_measure = qref.MeasureOp(op.in_qubit)
new_measure = qref.MeasureOp(*op.in_qubits, measurement=op.measurement)
rewriter.replace_matched_op(new_measure)


Expand Down
8 changes: 6 additions & 2 deletions inconspiquous/transforms/randomized_comp.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
XGate,
ZGate,
)
from inconspiquous.dialects.measurement import CompBasisMeasurementAttr
from inconspiquous.dialects.qssa import DynGateOp, GateOp, MeasureOp
from inconspiquous.dialects.prob import UniformOp

Expand Down Expand Up @@ -259,6 +260,9 @@ class PadMeasure(RewritePattern):

@op_type_rewrite_pattern
def match_and_rewrite(self, op: MeasureOp, rewriter: PatternRewriter):
if op.measurement != CompBasisMeasurementAttr():
# Only try to pad computation basis measurements
return
x_rand = UniformOp(i1)
z_rand = UniformOp(i1)

Expand All @@ -267,13 +271,13 @@ def match_and_rewrite(self, op: MeasureOp, rewriter: PatternRewriter):
z_gate = ConstantGateOp(ZGate())

pre_x_sel = SelectOp(x_rand, x_gate, id_gate)
pre_x = DynGateOp(pre_x_sel, op.in_qubit)
pre_x = DynGateOp(pre_x_sel, *op.in_qubits)
pre_z_sel = SelectOp(z_rand, z_gate, id_gate)
pre_z = DynGateOp(pre_z_sel, pre_x)

new_measure = MeasureOp(pre_z)

corrected_measure = AddiOp(x_rand, new_measure.out)
corrected_measure = AddiOp(x_rand, new_measure.out[0])

rewriter.insert_op(
(
Expand Down
5 changes: 4 additions & 1 deletion inconspiquous/transforms/xzs/commute.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
HadamardGate,
XZOp,
)
from inconspiquous.dialects.measurement import CompBasisMeasurementAttr
from inconspiquous.transforms.xzs.merge import MergeXZGatesPattern
from inconspiquous.utils.linear_walker import LinearWalker

Expand All @@ -34,8 +35,10 @@ def match_and_rewrite(self, op1: qssa.DynGateOp, rewriter: PatternRewriter):
op2 = use.operation

if isinstance(op2, qssa.MeasureOp):
if not isinstance(op2.measurement, CompBasisMeasurementAttr):
return
new_op2 = qssa.MeasureOp(op1.ins[0])
new_op1 = arith.AddiOp(new_op2.out, gate.x)
new_op1 = arith.AddiOp(new_op2.out[0], gate.x)

rewriter.replace_op(op2, (new_op2, new_op1))
rewriter.erase_op(op1)
Expand Down
2 changes: 1 addition & 1 deletion inconspiquous/utils/qssa_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ def measure(self, ref: QubitRef, *, name_hint: str | None = None) -> SSAValue:
if ImplicitBuilder.get() is None:
self.insert(new_op)
ref.qubit = None
out = new_op.out
out = new_op.out[0]
out.name_hint = name_hint
return out
2 changes: 1 addition & 1 deletion tests/filecheck/dialects/qref/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ qref.gate<#gate.cnot> %q0, %q1
qref.dyn_gate<%g1> %q1

// CHECK: %{{.*}} = qref.measure %q0
// CHECK-GENERIC: %{{.*}} = "qref.measure"(%q0) : (!qubit.bit) -> i1
// CHECK-GENERIC: %{{.*}} = "qref.measure"(%q0) <{measurement = #measurement.comp_basis}> : (!qubit.bit) -> i1
%0 = qref.measure %q0
6 changes: 5 additions & 1 deletion tests/filecheck/dialects/qssa/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,9 @@
%q6 = qssa.dyn_gate<%g1> %q5

// CHECK: %{{.*}} = qssa.measure %q4
// CHECK-GENERIC: %{{.*}} = "qssa.measure"(%q4) : (!qubit.bit) -> i1
// CHECK-GENERIC: %{{.*}} = "qssa.measure"(%q4) <{measurement = #measurement.comp_basis}> : (!qubit.bit) -> i1
%0 = qssa.measure %q4

// CHECK: %{{.*}} = qssa.measure<#measurement.xy<0.5pi>> %q6
// CHECK-GENERIC: %{{.*}} = "qssa.measure"(%q6) <{measurement = #measurement.xy<0.5pi>}> : (!qubit.bit) -> i1
%1 = qssa.measure<#measurement.xy<0.5pi>> %q6