diff --git a/inconspiquous/transforms/randomized_comp.py b/inconspiquous/transforms/randomized_comp.py index ee071f4..27ad430 100644 --- a/inconspiquous/transforms/randomized_comp.py +++ b/inconspiquous/transforms/randomized_comp.py @@ -1,6 +1,6 @@ from xdsl.context import MLContext from xdsl.dialects import builtin -from xdsl.dialects.arith import SelectOp +from xdsl.dialects.arith import AddiOp, SelectOp from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -24,7 +24,7 @@ XGate, ZGate, ) -from inconspiquous.dialects.qssa import DynGateOp, GateOp +from inconspiquous.dialects.qssa import DynGateOp, GateOp, MeasureOp from inconspiquous.dialects.prob import UniformOp @@ -252,6 +252,50 @@ def match_and_rewrite(self, op: GateOp, rewriter: PatternRewriter): ) +class PadMeasure(RewritePattern): + """ + Places randomized dynamic pauli gates before a measurement. + """ + + @op_type_rewrite_pattern + def match_and_rewrite(self, op: MeasureOp, rewriter: PatternRewriter): + x_rand = UniformOp(i1) + z_rand = UniformOp(i1) + + id_gate = ConstantGateOp(IdentityGate()) + x_gate = ConstantGateOp(XGate()) + z_gate = ConstantGateOp(ZGate()) + + pre_x_sel = SelectOp(x_rand, x_gate, id_gate) + pre_x = DynGateOp(pre_x_sel, op.in_qubit) + pre_z_sel = SelectOp(z_rand, z_gate, id_gate) + pre_z = DynGateOp(pre_z_sel, pre_x) + + new_measure = MeasureOp(pre_z) + + corrected_measure = AddiOp(x_rand, new_measure.out) + + rewriter.insert_op( + ( + x_rand, + z_rand, + id_gate, + x_gate, + z_gate, + pre_x_sel, + pre_x, + pre_z_sel, + pre_z, + ), + InsertPoint.before(op), + ) + + rewriter.replace_matched_op( + (new_measure, corrected_measure), + (corrected_measure.result, new_measure.out_qubit), + ) + + class RandomizedComp(ModulePass): """ Pads all "difficult" gates in a circuit. @@ -262,7 +306,13 @@ class RandomizedComp(ModulePass): def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: PatternRewriteWalker( GreedyRewritePatternApplier( - [PadTGate(), PadTDaggerGate(), PadHadamardGate(), PadCNotGate()] + [ + PadTGate(), + PadTDaggerGate(), + PadHadamardGate(), + PadCNotGate(), + PadMeasure(), + ] ), apply_recursively=False, # Do not reapply ).rewrite_module(op) diff --git a/tests/filecheck/transforms/randomized_compilation.mlir b/tests/filecheck/transforms/randomized_compilation.mlir index 2188819..87cf00f 100644 --- a/tests/filecheck/transforms/randomized_compilation.mlir +++ b/tests/filecheck/transforms/randomized_compilation.mlir @@ -99,3 +99,22 @@ func.func @cnot_gate(%q1: !qubit.bit, %q2: !qubit.bit) -> (!qubit.bit, !qubit.bi %q1_1, %q2_1 = qssa.gate<#gate.cnot> %q1, %q2 func.return %q1_1, %q2_1 : !qubit.bit, !qubit.bit } + +// CHECK: func.func @measure(%q : !qubit.bit) -> i1 { +// CHECK-NEXT: %0 = prob.uniform : i1 +// CHECK-NEXT: %1 = prob.uniform : i1 +// CHECK-NEXT: %2 = gate.constant #gate.id +// CHECK-NEXT: %3 = gate.constant #gate.x +// CHECK-NEXT: %4 = gate.constant #gate.z +// CHECK-NEXT: %5 = arith.select %0, %3, %2 : !gate.type<1> +// CHECK-NEXT: %6 = qssa.dyn_gate<%5> %q : !qubit.bit +// CHECK-NEXT: %7 = arith.select %1, %4, %2 : !gate.type<1> +// CHECK-NEXT: %8 = qssa.dyn_gate<%7> %6 : !qubit.bit +// CHECK-NEXT: %9, %q_1 = qssa.measure %8 +// CHECK-NEXT: %10 = arith.addi %0, %9 : i1 +// CHECK-NEXT: func.return %10 : i1 +// CHECK-NEXT: } +func.func @measure(%q: !qubit.bit) -> i1 { + %0, %q_1 = qssa.measure %q + func.return %0 : i1 +}