Skip to content

Commit

Permalink
transformations: (randomized-comp) add measurement padding (#28)
Browse files Browse the repository at this point in the history
Adds randomized padding for measurement operations.
  • Loading branch information
alexarice authored Jan 7, 2025
1 parent 3cdbc18 commit cd4da13
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 3 deletions.
56 changes: 53 additions & 3 deletions inconspiquous/transforms/randomized_comp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from xdsl.context import MLContext
from xdsl.dialects import builtin
from xdsl.dialects.arith import SelectOp
from xdsl.dialects.arith import AddiOp, SelectOp
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
Expand All @@ -24,7 +24,7 @@
XGate,
ZGate,
)
from inconspiquous.dialects.qssa import DynGateOp, GateOp
from inconspiquous.dialects.qssa import DynGateOp, GateOp, MeasureOp
from inconspiquous.dialects.prob import UniformOp


Expand Down Expand Up @@ -252,6 +252,50 @@ def match_and_rewrite(self, op: GateOp, rewriter: PatternRewriter):
)


class PadMeasure(RewritePattern):
"""
Places randomized dynamic pauli gates before a measurement.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: MeasureOp, rewriter: PatternRewriter):
x_rand = UniformOp(i1)
z_rand = UniformOp(i1)

id_gate = ConstantGateOp(IdentityGate())
x_gate = ConstantGateOp(XGate())
z_gate = ConstantGateOp(ZGate())

pre_x_sel = SelectOp(x_rand, x_gate, id_gate)
pre_x = DynGateOp(pre_x_sel, op.in_qubit)
pre_z_sel = SelectOp(z_rand, z_gate, id_gate)
pre_z = DynGateOp(pre_z_sel, pre_x)

new_measure = MeasureOp(pre_z)

corrected_measure = AddiOp(x_rand, new_measure.out)

rewriter.insert_op(
(
x_rand,
z_rand,
id_gate,
x_gate,
z_gate,
pre_x_sel,
pre_x,
pre_z_sel,
pre_z,
),
InsertPoint.before(op),
)

rewriter.replace_matched_op(
(new_measure, corrected_measure),
(corrected_measure.result, new_measure.out_qubit),
)


class RandomizedComp(ModulePass):
"""
Pads all "difficult" gates in a circuit.
Expand All @@ -262,7 +306,13 @@ class RandomizedComp(ModulePass):
def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier(
[PadTGate(), PadTDaggerGate(), PadHadamardGate(), PadCNotGate()]
[
PadTGate(),
PadTDaggerGate(),
PadHadamardGate(),
PadCNotGate(),
PadMeasure(),
]
),
apply_recursively=False, # Do not reapply
).rewrite_module(op)
19 changes: 19 additions & 0 deletions tests/filecheck/transforms/randomized_compilation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,22 @@ func.func @cnot_gate(%q1: !qubit.bit, %q2: !qubit.bit) -> (!qubit.bit, !qubit.bi
%q1_1, %q2_1 = qssa.gate<#gate.cnot> %q1, %q2
func.return %q1_1, %q2_1 : !qubit.bit, !qubit.bit
}

// CHECK: func.func @measure(%q : !qubit.bit) -> i1 {
// CHECK-NEXT: %0 = prob.uniform : i1
// CHECK-NEXT: %1 = prob.uniform : i1
// CHECK-NEXT: %2 = gate.constant #gate.id
// CHECK-NEXT: %3 = gate.constant #gate.x
// CHECK-NEXT: %4 = gate.constant #gate.z
// CHECK-NEXT: %5 = arith.select %0, %3, %2 : !gate.type<1>
// CHECK-NEXT: %6 = qssa.dyn_gate<%5> %q : !qubit.bit
// CHECK-NEXT: %7 = arith.select %1, %4, %2 : !gate.type<1>
// CHECK-NEXT: %8 = qssa.dyn_gate<%7> %6 : !qubit.bit
// CHECK-NEXT: %9, %q_1 = qssa.measure %8
// CHECK-NEXT: %10 = arith.addi %0, %9 : i1
// CHECK-NEXT: func.return %10 : i1
// CHECK-NEXT: }
func.func @measure(%q: !qubit.bit) -> i1 {
%0, %q_1 = qssa.measure %q
func.return %0 : i1
}

0 comments on commit cd4da13

Please sign in to comment.