Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create and diagonalize mid-circuit measurements in conditionals #7037

Open
wants to merge 31 commits into
base: move_validation_to_cond
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
f87539c
Add validation and raise error
lillian542 Feb 28, 2025
3b3485c
Add test for error
lillian542 Feb 28, 2025
558aedb
Update pennylane/devices/preprocess.py
lillian542 Feb 28, 2025
017d899
Merge branch 'master' into validate_mcm_in_cond
lillian542 Feb 28, 2025
18d88df
update changelog
lillian542 Feb 28, 2025
d67b42b
Merge branch 'validate_mcm_in_cond' of github.com:PennyLaneAI/pennyla…
lillian542 Feb 28, 2025
346263f
Merge branch 'validate_mcm_in_cond' into meas_in_cond
lillian542 Feb 28, 2025
eab3b61
add rough draft that does the thing
lillian542 Mar 3, 2025
2456436
update existing tests
lillian542 Mar 4, 2025
f9fadfe
remove unnecessary handling for elif clauses
lillian542 Mar 4, 2025
eb608c3
handle case where single MCM is final op
lillian542 Mar 4, 2025
c025816
more test updates
lillian542 Mar 4, 2025
522e6f7
more specific error messages
lillian542 Mar 5, 2025
7fb7d7b
add alternative implementation
lillian542 Mar 5, 2025
00b2c4a
add check that measurement is returned
lillian542 Mar 6, 2025
1875450
save intermediate debugging thing for now
lillian542 Mar 6, 2025
d912ce7
clean up, remove abandoned implementation
lillian542 Mar 6, 2025
0b0a977
merge master
lillian542 Mar 6, 2025
6e4c493
move validation rejecting Conditional(MCM) to qml.cond
lillian542 Mar 6, 2025
678ee75
move corresponding test
lillian542 Mar 6, 2025
5b057b2
delete temp file
lillian542 Mar 6, 2025
72b2e04
merge
lillian542 Mar 6, 2025
fa14fd1
Merge branch 'move_validation_to_cond' into meas_in_cond
lillian542 Mar 6, 2025
e3d69cf
update docstring on diagonalize_mcms
lillian542 Mar 6, 2025
7573bea
fix bug with condition for false_fn diag_gates
lillian542 Mar 7, 2025
ec80023
update cond_meas docstring
lillian542 Mar 7, 2025
565f879
tidy up
lillian542 Mar 7, 2025
359894f
update docstrings
lillian542 Mar 7, 2025
08a7bc9
update diagonalization tests
lillian542 Mar 7, 2025
65f34be
add tests
lillian542 Mar 7, 2025
89aa2d3
add jax tag to test
lillian542 Mar 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pennylane/ftqc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@
measure_z,
diagonalize_mcms,
)
from .conditional_measure import cond_meas
from .lattice import Lattice, generate_lattice

__all__ = [
"Lattice",
"ParametricMidMeasureMP",
"XMidMeasureMP",
"YMidMeasureMP",
"cond_meas",
"measure_arbitrary_basis",
"measure_x",
"measure_y",
Expand Down
157 changes: 157 additions & 0 deletions pennylane/ftqc/conditional_measure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright 2022 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains the condition transform.
"""
from functools import wraps
from typing import Callable

import pennylane as qml
from pennylane.measurements import MeasurementValue, MidMeasureMP
from pennylane.ops.op_math.condition import CondCallable, Conditional


def cond_meas(
condition,
true_fn: Callable,
false_fn: Callable,
):
"""Conditions the basis of mid-circuit qubit measurements on parameters such as the results of
other mid-circuit qubit measurements.

.. note::

This function is currently not compatible with :func:`~.qjit`, or with
:func:`.pennylane.capture.enabled`.

Args:
condition (Union[.MeasurementValue, bool]): a conditional expression that may involve a mid-circuit
measurement value (see :func:`.pennylane.measure`).
true_fn (callable): The quantum function or PennyLane operation to
apply if ``condition`` is ``True``. The callable must create a single mid-circuit measurement.
false_fn (callable): The quantum function or PennyLane operation to
apply if ``condition`` is ``False``. The callable must create a single mid-circuit measurement.

.. note::
The mid-circuit measurements applied the two branches must both be applied to the same
wire, and they must have the same settings for `reset` and `postselection`. The two
branches can differ only in regard to the measurement basis of the applied measurement.


Returns:
function: A new function that applies the conditional measurements. The returned
function takes the same input arguments as ``true_fn`` and ``false_fn``.

**Example**

.. code-block:: python3

import pennylane as qml
from pennylane.ftqc import cond_meas, diagonalize_mcms, measure_x, measure_y

dev = qml.device("default.qubit", wires=3, shots=1000)

@diagonalize_mcms
@qml.qnode(dev, mcm_method="one-shot")
def qnode(x, y):
qml.RY(x, 0)
qml.Hadamard(1)

m0 = qml.measure(0)
m2 = cond_meas(m0, measure_x, measure_y)(1)

qml.Hadamard(2)
qml.cond(m2 == 0, qml.RY)(y, wires=2)
return qml.expval(qml.X(2))


>>> qnode(np.pi/3, np.pi/2)
0.3806

.. note::

If the first argument of ``cond_meas`` is a measurement value (e.g., ``m_0``
in ``qml.cond(m_0, measure_x, measure_y)``), then ``m_0 == 1`` is considered
internally.

.. warning::

Expressions with boolean logic flow using operators like ``and``,
``or`` and ``not`` are not supported as the ``condition`` argument.

While such statements may not result in errors, they may result in
incorrect behaviour.
"""
if qml.capture.enabled():
raise NotImplementedError("The `cond_meas` function is not compatible with program capture")

if not isinstance(condition, MeasurementValue):
# The condition is not a mid-circuit measurement - we can simplify immediately
return CondCallable(condition, true_fn, false_fn)

if callable(true_fn) and callable(false_fn):

# We assume this callable is a measurement function that returns a MeasurementValue
# containing a single mid-circuit measurement. If this isn't the case, getting the
# measurements will return None, and it will be caught in _validate_measurements.

@wraps(true_fn)
def wrapper(*args, **kwargs):

with qml.QueuingManager.stop_recording():
true_meas_return = true_fn(*args, **kwargs)
false_meas_return = false_fn(*args, **kwargs)

true_meas = getattr(true_meas_return, "measurements", [None])[0]
false_meas = getattr(false_meas_return, "measurements", [None])[0]

_validate_measurements(true_meas, false_meas)

Conditional(condition, true_meas)
Conditional(~condition, false_meas)

return MeasurementValue(
[true_meas, false_meas],
processing_fn=lambda v1, v2: qml.math.logical_or( # pylint: disable=unnecessary-lambda
v1, v2
),
)

else:
raise ValueError("Only measurement functions can be applied conditionally by `cond_meas`.")

return wrapper


def _validate_measurements(true_meas, false_meas):
"""Takes a pair of variables that are expected to be mid-circuit measurements
(representing a true and false functions for the conditional) and confirms that
they have the expected type ,and 'match' except for the measurement basis"""

if not (isinstance(true_meas, MidMeasureMP) and isinstance(false_meas, MidMeasureMP)):
raise ValueError(
"Only measurement functions that return a measurement value can be used in `cond_meas`"
)

if not (
true_meas.wires == false_meas.wires
and true_meas.reset == false_meas.reset
and true_meas.postselect == false_meas.postselect
):
raise ValueError(
"When applying a mid-circuit measurement in `cond_meas`, the `wire`, "
"`postselect` and `reset` behaviour must be consistent for both "
"branches of the conditional. Only the basis of the measurement (defined "
"by measurement type or by `plane` and `angle`) can vary."
)
106 changes: 85 additions & 21 deletions pennylane/ftqc/parametric_midmeasure.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,31 +570,75 @@ def diagonalize_mcms(tape):

from pennylane.ftqc import diagonalize_mcms, ParametricMidMeasureMP

dev = qml.device("default.qubit")
dev = qml.device("default.qubit", shots=1000)

@diagonalize_mcms
@qml.qnode(dev)
@qml.qnode(dev, method="one-shot")
def circuit(x):
qml.RY(x[0], wires=0)
ParametricMidMeasureMP(0, angle=x[1], plane="XY")
return qml.expval(qml.Z(0))
qml.RX(x, wires=0)
m = measure_y(0)
qml.cond(m, qml.X)(1)
return qml.expval(qml.Z(1))

Applying the transform inserts the relevant gates before the measurement to allow
measurements to be in the Z basis, so the original circuit

>>> print(qml.draw(circuit, level=0)([np.pi/4, np.pi]))
0: ──RY(0.79)──┤↗ˣʸ(3.14)├─┤ <Z>
>>> print(qml.draw(circuit, level=0)(np.pi/4))
0: ──RX(0.79)──┤↗ʸ├────┤
1: ─────────────║────X─┤ <Z>
╚════╝

becomes

>>> print(qml.draw(circuit)([np.pi/4, np.pi]))
──RY(0.79)──Rϕ(-3.14)──H──┤↗├─┤ <Z>
>>> print(qml.draw(circuit)(np.pi/4))
0: ──RY(0.79)──S†──H──┤↗├────┤
1: ────────────────────║───X─┤ <Z>
╚═══╝


.. details::
:title: Conditional measurements

The transform can also handle diagonalization of conditional measurements created by
:func:`qml.ftqc.cond_meas <pennylane.ftqc.cond_meas>`. This is done by replacing the
measurements for the true and false condition with conditional diagonalizing gates,
and a single measurement in the computational basis:

.. code-block:: python3

from pennylane.ftqc import diagonalize_mcms, measure_x

dev = qml.device("default.qubit")

@diagonalize_mcms
@qml.qnode(dev)
def circuit(x):
qml.RY(x[0], wires=0)
qml.RX(x[1], wires=1)
m = qml.measure(0)
m2 = cond_meas(m, measure_x, measure_y)(1)
qml.cond(m2, qml.X)(1)
return qml.expval(qml.Z(1))

This circuit diagonalizes to:

>>> print(qml.draw(circuit)([np.pi, np.pi/4]))
0: ──RY(3.14)──┤↗├───────────────────┤
1: ──RX(0.79)───║───H──S†──H──┤↗├──X─┤ <Z>
╚═══╩══╩═══╝ ║ ║
╚═══╝
"""

new_operations = []
mps_mapping = {}

for op in tape.operations:
curr_idx = 0

for i, op in enumerate(tape.operations):

if i != curr_idx:
continue

if isinstance(op, ParametricMidMeasureMP):

# add diagonalizing gates to tape
Expand All @@ -612,35 +656,55 @@ def circuit(x):
elif isinstance(op, qml.ops.Conditional):

# from MCM mapping, map any MCMs in the condition if needed
processing_fn = op.meas_val.processing_fn
mps = [mps_mapping.get(op, op) for op in op.meas_val.measurements]
expr = MeasurementValue(mps, processing_fn=processing_fn)

if isinstance(op.base, ParametricMidMeasureMP):
# add conditional diagonalizing gates + conditional MCM to the tape
if isinstance(op.base, MidMeasureMP):
# the only user-facing API for creating Conditionals with MCMs is meas_cond,
# which ensures both and true_fn and false_fn are included, so here we assume the
# expected format (i.e. conditional mcms are found pairwise with opposite conditions)
true_cond, false_cond = (op, tape.operations[i + 1])
# we process both the true_cond and the false_cond together, so we skip an index in the ops
curr_idx += 1

# add conditional diagonalizing gates + computational basis MCM to the tape
expr_true = MeasurementValue(mps, processing_fn=true_cond.meas_val.processing_fn)
expr_false = MeasurementValue(mps, processing_fn=false_cond.meas_val.processing_fn)

with qml.QueuingManager.stop_recording():
diag_gates = [
qml.ops.Conditional(expr=expr, then_op=gate)
for gate in op.diagonalizing_gates()
diag_gates_true = [
qml.ops.Conditional(expr=expr_true, then_op=gate)
for gate in true_cond.diagonalizing_gates()
]

diag_gates_false = [
qml.ops.Conditional(expr=expr_false, then_op=gate)
for gate in false_cond.diagonalizing_gates()
]

new_mp = MidMeasureMP(
op.wires, reset=op.base.reset, postselect=op.base.postselect, id=op.base.id
)
new_cond = qml.ops.Conditional(expr=expr, then_op=new_mp)

new_operations.extend(diag_gates)
new_operations.append(new_cond)
new_operations.extend(diag_gates_true)
new_operations.extend(diag_gates_false)
new_operations.append(new_mp)

# track mapping from original to computational basis MCMs
mps_mapping[op.base] = new_mp
mps_mapping[true_cond.base] = new_mp
mps_mapping[false_cond.base] = new_mp
else:
processing_fn = op.meas_val.processing_fn
expr = MeasurementValue(mps, processing_fn=processing_fn)

with qml.QueuingManager.stop_recording():
new_cond = qml.ops.Conditional(expr=expr, then_op=op.base)
new_operations.append(new_cond)

else:
new_operations.append(op)

curr_idx += 1

new_tape = tape.copy(operations=new_operations)

return (new_tape,), null_postprocessing
Loading