Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Decomposition] Decomposition of controlled operations (WIP) #7045

Draft
wants to merge 22 commits into
base: decomp-graph-djikstra
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
d247233
[wip] controlled decompositions
astralcai Feb 11, 2025
20877f7
Merge branch 'decomp-graph-djikstra' of https://github.com/PennyLaneA…
astralcai Feb 18, 2025
d6b04ad
Merge branch 'decomp-graph-djikstra' of https://github.com/PennyLaneA…
astralcai Feb 18, 2025
c73e4ce
Merge branch 'decomp-graph-djikstra' of https://github.com/PennyLaneA…
astralcai Feb 18, 2025
0106e89
wip
astralcai Feb 24, 2025
dbad1ed
Merge branch 'decomp-graph-djikstra' of https://github.com/PennyLaneA…
astralcai Mar 4, 2025
e56d869
Merge branch 'decomp-graph-djikstra' of https://github.com/PennyLaneA…
astralcai Mar 4, 2025
a933e39
wip
astralcai Mar 4, 2025
1126053
Merge branch 'decomp-graph-djikstra' of https://github.com/PennyLaneA…
astralcai Mar 5, 2025
6be7e09
wip
astralcai Mar 5, 2025
e0c0283
implement controlled decompositions
astralcai Mar 5, 2025
ea06f46
Merge branch 'decomp-graph-djikstra' of https://github.com/PennyLaneA…
astralcai Mar 6, 2025
597b783
Merge branch 'decomp-graph-djikstra' of https://github.com/PennyLaneA…
astralcai Mar 6, 2025
c9c3048
Merge branch 'decomp-graph-djikstra' of https://github.com/PennyLaneA…
astralcai Mar 6, 2025
590d2be
remove duplicate
astralcai Mar 6, 2025
bf71691
remove duplicate
astralcai Mar 6, 2025
26677a8
fix name
astralcai Mar 6, 2025
c66f441
Merge branch 'decomp-graph-djikstra' of https://github.com/PennyLaneA…
astralcai Mar 6, 2025
26ed1b0
Merge branch 'decomp-graph-djikstra' of https://github.com/PennyLaneA…
astralcai Mar 6, 2025
73b975d
fix bug
astralcai Mar 6, 2025
238cb7f
move something
astralcai Mar 6, 2025
72f14d8
Merge branch 'decomp-graph-djikstra' of https://github.com/PennyLaneA…
astralcai Mar 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions pennylane/decomposition/controlled_decomposition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# Copyright 2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains special logic of decomposing controlled operations."""

from typing import Callable

import pennylane as qml

from .decomposition_rule import DecompositionRule, register_resources
from .resources import Resources, controlled_resource_rep, resource_rep


class CustomControlledDecomposition(DecompositionRule):
"""A decomposition rule applicable to an operator with a custom controlled decomposition."""

def __init__(self, custom_op_type):
self.custom_op_type = custom_op_type
super().__init__(self._get_impl())

def _get_impl(self):
"""The implementation of a controlled op that decomposes to a custom controlled op."""

def _impl(*params, wires, control_wires, control_values, **_):
for w, val in zip(control_wires, control_values):
if not val:
qml.PauliX(w)
self.custom_op_type(*params, wires=wires)
for w, val in zip(control_wires, control_values):
if not val:
qml.PauliX(w)

return _impl

def compute_resources(
self, base_params, num_control_wires, num_zero_control_values, num_work_wires
) -> Resources:
return Resources(
num_zero_control_values * 2 + 1,
{
resource_rep(self.custom_op_type): 1,
resource_rep(qml.X): num_zero_control_values * 2,
},
)


class GeneralControlledDecomposition(DecompositionRule):
"""A decomposition rule for a controlled operation with a decomposition."""

def __init__(self, base_decomposition: DecompositionRule):
self._base_decomposition = base_decomposition
super().__init__(self._get_impl())

def compute_resources(
self, base_params, num_control_wires, num_zero_control_values, num_work_wires
) -> Resources:
base_resource_decomp = self._base_decomposition.compute_resources(**base_params)
controlled_resources = {
controlled_resource_rep(
base_class=base_op_rep.op_type,
base_params=base_op_rep.params,
num_control_wires=num_control_wires,
num_zero_control_values=0,
num_work_wires=num_work_wires,
): count
for base_op_rep, count in base_resource_decomp.gate_counts.items()
if count > 0
}
controlled_resources[resource_rep(qml.X)] = num_zero_control_values * 2
gate_count = sum(controlled_resources.values())
return Resources(gate_count, controlled_resources)

def _get_impl(self) -> Callable:
"""The default implementation of a controlled decomposition."""

def _impl(*_, control_wires, control_values, work_wires, base, **__):
for w, val in zip(control_wires, control_values):
if not val:
qml.PauliX(w)
qml.ctrl(
self._base_decomposition.impl,
control=control_wires,
control_values=control_values,
work_wires=work_wires,
)(*base.params, wires=base.wires, **base.hyperparameters)
for w, val in zip(control_wires, control_values):
if not val:
qml.PauliX(w)

return _impl


def _controlled_g_phase_resource(*_, num_control_wires, num_zero_control_values, num_work_wires):
if num_control_wires == 1:
return {
resource_rep(qml.PauliX): num_zero_control_values * 2,
resource_rep(qml.PhaseShift): 1,
}
else:
return {
resource_rep(qml.PauliX): num_zero_control_values * 2,
controlled_resource_rep(
qml.PhaseShift,
base_params={},
num_control_wires=num_control_wires - 1,
num_zero_control_values=0,
num_work_wires=num_work_wires,
): 1,
}


@register_resources(_controlled_g_phase_resource)
def controlled_global_phase_decomp(*_, control_wires, control_values, work_wires, base, **__):
"""The decomposition rule for a controlled global phase."""

for w, val in zip(control_wires, control_values):
if not val:
qml.PauliX(w)
if len(control_wires) == 1:
qml.PhaseShift(-base.data[0], wires=control_wires[-1])
else:
qml.ctrl(
qml.PhaseShift(-base.data[0], wires=control_wires[-1]),
control=control_wires[:-1],
work_wires=work_wires,
)
for w, val in zip(control_wires, control_values):
if not val:
qml.PauliX(w)


def _controlled_x_resource(*_, num_control_wires, num_zero_control_values, num_work_wires):
if num_control_wires == 1 and num_zero_control_values == 0:
return {resource_rep(qml.CNOT): 1}
if num_control_wires == 2 and num_zero_control_values == 0:
return {resource_rep(qml.Toffoli): 1}
return {
resource_rep(
qml.MultiControlledX,
num_control_wires=num_control_wires,
num_zero_control_values=num_zero_control_values,
num_work_wires=num_work_wires,
): 1,
}


@register_resources(_controlled_x_resource)
def controlled_x_decomp(*_, control_wires, control_values, work_wires, base, **__):
"""The decomposition rule for a controlled PauliX."""

qml.ctrl(
qml.PauliX(base.wires),
control=control_wires,
control_values=control_values,
work_wires=work_wires,
)
85 changes: 79 additions & 6 deletions pennylane/decomposition/decomposition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,20 @@

from __future__ import annotations

import functools
from dataclasses import dataclass

import rustworkx as rx
from rustworkx.visit import DijkstraVisitor, PruneSearch, StopSearch

import pennylane as qml

from .controlled_decomposition import (
CustomControlledDecomposition,
GeneralControlledDecomposition,
controlled_global_phase_decomp,
controlled_x_decomp,
)
from .decomposition_rule import DecompositionRule, list_decomps
from .resources import CompressedResourceOp, Resources, resource_rep

Expand Down Expand Up @@ -72,10 +81,10 @@ def __init__(
self._target_gate_set = target_gate_set
self._original_ops_indices: set[int] = set()
self._target_gate_indices: set[int] = set()
self._graph = rx.PyDiGraph()
self._op_node_indices: dict[CompressedResourceOp, int] = {}
self._fixed_decomps = fixed_decomps or {}
self._alt_decomps = alt_decomps or {}
self._graph = rx.PyDiGraph()
self._construct_graph()
self._visitor = None

Expand Down Expand Up @@ -107,29 +116,73 @@ def _recursively_add_op_node(self, op_node: CompressedResourceOp) -> int:

op_node_idx = self._graph.add_node(op_node)
self._op_node_indices[op_node] = op_node_idx

if op_node.op_type.__name__ in self._target_gate_set:
self._target_gate_indices.add(op_node_idx)
return op_node_idx

if op_node.op_type is qml.ops.Controlled or op_node.op_type is qml.ops.ControlledOp:
# This branch only applies to general controlled operators
return self._add_controlled_decomp_node(op_node, op_node_idx)

for decomposition in self._get_decompositions(op_node.op_type):
d_node_idx = self._recursively_add_decomposition_node(decomposition, op_node.params)
resource_decomp = decomposition.compute_resources(**op_node.params)
d_node_idx = self._recursively_add_decomposition_node(decomposition, resource_decomp)
self._graph.add_edge(d_node_idx, op_node_idx, 0)

return op_node_idx

def _recursively_add_decomposition_node(self, rule: DecompositionRule, params: dict) -> int:
def _add_special_decomp_rule_to_op(
self, rule: DecompositionRule, op_node: CompressedResourceOp, op_node_idx: int
) -> int:
"""Adds a special decomposition rule to the graph."""
resource_decomp = rule.compute_resources(**op_node.params)
d_node_idx = self._recursively_add_decomposition_node(rule, resource_decomp)
self._graph.add_edge(d_node_idx, op_node_idx, 0)
return op_node_idx

def _add_controlled_decomp_node(self, op_node: CompressedResourceOp, op_node_idx: int) -> int:
"""Adds a controlled decomposition node to the graph."""

base_class = op_node.params["base_class"]
num_control_wires = op_node.params["num_control_wires"]

# Handle controlled global phase
if base_class is qml.GlobalPhase:
rule = controlled_global_phase_decomp
return self._add_special_decomp_rule_to_op(rule, op_node, op_node_idx)

# Handle controlled-X gates
if base_class is qml.X:
rule = controlled_x_decomp
return self._add_special_decomp_rule_to_op(rule, op_node, op_node_idx)

# Handle custom controlled ops
if (base_class, num_control_wires) in base_to_custom_ctrl_op():
custom_op_type = base_to_custom_ctrl_op()[(base_class, num_control_wires)]
rule = CustomControlledDecomposition(custom_op_type)
return self._add_special_decomp_rule_to_op(rule, op_node, op_node_idx)

# General case
for base_decomposition in self._get_decompositions(base_class):
rule = GeneralControlledDecomposition(base_decomposition)
self._add_special_decomp_rule_to_op(rule, op_node, op_node_idx)

return op_node_idx

def _recursively_add_decomposition_node(
self, rule: DecompositionRule, resource_decomp: Resources
) -> int:
"""Recursively adds a decomposition node to the graph.

A decomposition node is defined by a decomposition rule and a first-order resource estimate
of this decomposition as computed with resource params passed from the operator node.

"""

resource_decomp = rule.compute_resources(**params)
d_node = _DecompositionNode(rule, resource_decomp)
d_node_idx = self._graph.add_node(d_node)
all_ops = [op for op, count in resource_decomp.gate_counts.items() if count > 0]
for op in all_ops:
for op in resource_decomp.gate_counts:
op_node_idx = self._recursively_add_op_node(op)
self._graph.add_edge(op_node_idx, d_node_idx, (op_node_idx, d_node_idx))
return d_node_idx
Expand Down Expand Up @@ -241,3 +294,23 @@ def edge_relaxed(self, edge):
elif isinstance(target_node, CompressedResourceOp):
self.p[target_idx] = src_idx
self.d[target_idx] = self.d[src_idx]


@functools.lru_cache()
def base_to_custom_ctrl_op():
"""A dictionary mapping base op types to their custom controlled versions."""

ops_with_custom_ctrl_ops = {
(qml.PauliZ, 1): qml.CZ,
(qml.PauliZ, 2): qml.CCZ,
(qml.PauliY, 1): qml.CY,
(qml.CZ, 1): qml.CCZ,
(qml.SWAP, 1): qml.CSWAP,
(qml.Hadamard, 1): qml.CH,
(qml.RX, 1): qml.CRX,
(qml.RY, 1): qml.CRY,
(qml.RZ, 1): qml.CRZ,
(qml.Rot, 1): qml.CRot,
(qml.PhaseShift, 1): qml.ControlledPhaseShift,
}
return ops_with_custom_ctrl_ops
3 changes: 3 additions & 0 deletions pennylane/decomposition/decomposition_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ def __init__(self, func: Callable, resources: Callable | dict = None):
else:
self._compute_resources = resources

def __call__(self, *args, **kwargs):
return self.impl(*args, **kwargs)

def compute_resources(self, *args, **kwargs) -> Resources:
"""Computes the resources required to implement this decomposition rule."""
if self._compute_resources is None:
Expand Down
2 changes: 1 addition & 1 deletion pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def setter(self, func):
return self


def classproperty(func):
def classproperty(func) -> ClassPropertyDescriptor:
"""The class property decorator"""
if not isinstance(func, (classmethod, staticmethod)):
func = classmethod(func)
Expand Down