add gnn for RL
Patrick Hopf committed Apr 19, 2024
1 parent 7900db5 commit da11497
Showing 7 changed files with 96 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -131,7 +131,7 @@ implicit_reexport = true
# recent versions of `gym` are typed, but stable-baselines3 pins a very old version of gym.
# qiskit is not yet marked as typed, but is mostly typed.
# the other libraries do not have type stubs.
module = ["qiskit.*", "joblib.*", "sklearn.*", "matplotlib.*", "gymnasium.*", "mqt.bench.*", "sb3_contrib.*", "bqskit.*", "qiskit_ibm_runtime.*"]
module = ["qiskit.*", "joblib.*", "sklearn.*", "matplotlib.*", "gymnasium.*", "mqt.bench.*", "sb3_contrib.*", "bqskit.*", "qiskit_ibm_runtime.*", "torch.*", "torch_geometric.*", "stable_baselines3.*" ]
ignore_missing_imports = true


2 changes: 1 addition & 1 deletion src/mqt/predictor/ml/GNN.py
@@ -132,7 +132,7 @@ def __init__(self, **kwargs: object) -> None:
dropout=self.dropout,
act=self.activation_func,
norm=self.batch_norm_layer if self.batch_norm else None,
edge_dim=edge_dim,
# edge_dim=edge_dim,
)
]
last_hidden_dim = self.output_dim
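
With edge_dim commented out, the underlying model in GNN.py is constructed without an edge_dim argument, so edge attributes are not consumed by its convolutions. For reference, a minimal sketch of how a PyTorch Geometric convolution that does accept edge features takes this argument; it uses GATConv as a stand-in and is not the repository's Net class:

    # Minimal sketch, assuming a GAT-style convolution; not the repository's Net.
    import torch
    from torch_geometric.nn import GATConv

    conv = GATConv(in_channels=16, out_channels=32, edge_dim=1)  # 1-dim edge attributes

    x = torch.randn(4, 16)                             # 4 nodes, 16 features each
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # 3 directed edges
    edge_attr = torch.randn(3, 1)                       # one attribute per edge
    out = conv(x, edge_index, edge_attr)                # -> shape [4, 32]
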
38 changes: 34 additions & 4 deletions src/mqt/predictor/rl/helper.py
@@ -400,7 +400,7 @@ def encode_circuit(qc: QuantumCircuit) -> NDArray[np.int_]:
# Create a look-up table for qubit indices (needed for multiple registers)
q_idx_LUT = {qubit: idx for idx, qubit in enumerate(dag.qubits)}

num_qubits, _max_depth = 11, 10000
num_qubits, _max_depth = 11, 10000 # Adjust according to the device

matrix = [] # np.zeros((num_qubits, num_qubits, max_depth), dtype=np.int_)
for _i, tensor_op in enumerate(layers[1:-1]):
@@ -453,10 +453,40 @@ def create_feature_dict(qc: QuantumCircuit, features: list[str] | str = "all") -
feature_dict["directed_program_communication"] = np.array(
[supermarq_features.directed_program_communication], dtype=np.float32
)
feature_dict["singleQ_gates_per_layer"] = np.array([supermarq_features.singleQ_gates_per_layer], dtype=np.float32)
feature_dict["multiQ_gates_per_layer"] = np.array([supermarq_features.multiQ_gates_per_layer], dtype=np.float32)
feature_dict["single_qubit_gates_per_layer"] = np.array(
[supermarq_features.single_qubit_gates_per_layer], dtype=np.float32
)
feature_dict["multi_qubit_gates_per_layer"] = np.array(
[supermarq_features.multi_qubit_gates_per_layer], dtype=np.float32
)
feature_dict["circuit"] = encode_circuit(qc) if ("all" in features or "circuit" in features) else None

# graph feature creation
if "all" in features or "graph" in features:
try:
ops_list = qc.count_ops()
ops_list_dict = ml.helper.dict_to_featurevector(ops_list)
# operations/gates encoding for graph feature creation
ops_list_encoding = ops_list_dict.copy()
ops_list_encoding["measure"] = len(ops_list_encoding) # add extra gate
# unique number for each gate {'measure': 0, 'cx': 1, ...}
for i, key in enumerate(ops_list_dict):
ops_list_encoding[key] = i
graph = ml.helper.circuit_to_graph(qc, ops_list_encoding)
# convert to lists of numpy arrays
x_np = [x.numpy() for x in graph.x]
edge_index_np = [edge.numpy() for edge in graph.edge_index]
edge_attr_np = [attr.numpy() for attr in graph.edge_attr]
feature_dict["graph_x"] = x_np
feature_dict["graph_edge_index"] = edge_index_np
feature_dict["graph_edge_attr"] = edge_attr_np
except Exception:
feature_dict["graph_x"] = None
feature_dict["graph_edge_index"] = None
feature_dict["graph_edge_attr"] = None
else:
feature_dict["graph_x"] = None
feature_dict["graph_edge_index"] = None
feature_dict["graph_edge_attr"] = None
return {k: v for k, v in feature_dict.items() if ("all" in features or k in features)}
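
The graph observation is stored as plain lists of per-node and per-edge NumPy arrays so that it fits the variable-length Sequence spaces defined in predictorenv.py. A minimal sketch of that round trip, under the assumption that circuit_to_graph labels each node with an integer gate id and stores one [source, target] row per edge (the real layout is produced by ml.helper.circuit_to_graph and is not shown in this diff):

    # Minimal sketch with made-up data; the attribute layout of circuit_to_graph is assumed.
    import torch
    from torch_geometric.data import Data

    graph = Data(
        x=torch.tensor([[0], [1]]),          # 2 nodes, one integer gate label each
        edge_index=torch.tensor([[0, 1]]),   # 1 edge, stored as a [source, target] row
        edge_attr=torch.tensor([[1]]),       # 1 integer attribute per edge
    )

    # Same unpacking as above: one small NumPy array per node/edge.
    graph_x = [x.numpy() for x in graph.x]                    # [array([0]), array([1])]
    graph_edge_index = [e.numpy() for e in graph.edge_index]  # [array([0, 1])]
    graph_edge_attr = [a.numpy() for a in graph.edge_attr]    # [array([1])]
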


4 changes: 2 additions & 2 deletions src/mqt/predictor/rl/predictor.py
@@ -95,10 +95,10 @@ def train_model(

if test:
n_steps = 100
progress_bar = True
progress_bar = False
else:
n_steps = 2048
progress_bar = True
progress_bar = False

logger.debug("Start training for: " + self.figure_of_merit + " on " + self.device_name)
rl.PredictorEnv(reward_function=self.figure_of_merit, device_name=self.device_name, features=self.features)
20 changes: 10 additions & 10 deletions src/mqt/predictor/rl/predictorenv.py
@@ -5,11 +5,11 @@

if TYPE_CHECKING:
from pathlib import Path

import numpy as np
from bqskit.ext import bqskit_to_qiskit, qiskit_to_bqskit
from gymnasium import Env
from gymnasium.spaces import Box, Dict, Discrete, Sequence
from pytket.circuit import Qubit
from pytket.extensions.qiskit import qiskit_to_tk, tk_to_qiskit
from qiskit import QuantumCircuit
from qiskit.passmanager.flow_controllers import DoWhileController
@@ -83,6 +83,8 @@ def __init__(
self.has_parametrized_gates = False

qubit_num, _max_depth = self.device.num_qubits, 10000
max_num_nodes = 10000
max_num_node_labels = 50

spaces = {
"num_qubits": Discrete(128),
@@ -93,8 +95,8 @@
"parallelism": Box(low=0, high=1, shape=(1,), dtype=np.float32),
"liveness": Box(low=0, high=1, shape=(1,), dtype=np.float32),
"directed_program_communication": Box(low=0, high=1, shape=(1,), dtype=np.float32),
"singleQ_gates_per_layer": Box(low=0, high=1, shape=(1,), dtype=np.float32),
"multiQ_gates_per_layer": Box(low=0, high=1, shape=(1,), dtype=np.float32),
"single_qubit_gates_per_layer": Box(low=0, high=1, shape=(1,), dtype=np.float32),
"multi_qubit_gates_per_layer": Box(low=0, high=1, shape=(1,), dtype=np.float32),
"circuit": Sequence(
Box(
low=0,
@@ -107,6 +109,9 @@
dtype=np.int_,
),
),
"graph_edge_index": Sequence(Box(low=0, high=max_num_nodes, shape=(2,), dtype=np.int_)),
"graph_x": Sequence(Box(low=0, high=max_num_node_labels, shape=(1,), dtype=np.int_)),
"graph_edge_attr": Sequence(Box(low=0, high=1, shape=(1,), dtype=np.int_)),
}
self.observation_space = Dict({k: v for k, v in spaces.items() if ("all" in features or k in features)})
self.features = features
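
The three graph entries use gymnasium Sequence spaces, so each observation may carry a different number of nodes and edges while every element keeps a fixed shape. A self-contained sketch of how such a space accepts the lists built in create_feature_dict (the bounds mirror the constants above; the sample graph is made up):

    # Minimal sketch; the two-node sample observation is made up.
    import numpy as np
    from gymnasium.spaces import Box, Dict, Sequence

    max_num_nodes, max_num_node_labels = 10000, 50
    graph_spaces = Dict({
        "graph_x": Sequence(Box(low=0, high=max_num_node_labels, shape=(1,), dtype=np.int_)),
        "graph_edge_index": Sequence(Box(low=0, high=max_num_nodes, shape=(2,), dtype=np.int_)),
        "graph_edge_attr": Sequence(Box(low=0, high=1, shape=(1,), dtype=np.int_)),
    })

    observation = {
        "graph_x": (np.array([3]), np.array([7])),   # two nodes with gate labels 3 and 7
        "graph_edge_index": (np.array([0, 1]),),     # one edge from node 0 to node 1
        "graph_edge_attr": (np.array([1]),),         # one binary edge attribute
    }
    assert graph_spaces.contains(observation)
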
@@ -228,13 +233,8 @@ def apply_action(self, action_index: int) -> QuantumCircuit | None:
action = self.action_set[action_index]
if action["name"] == "terminate":
return self.state
if (
action_index
in self.actions_layout_indices + self.actions_routing_indices + self.actions_mapping_indices
):
transpile_pass = action["transpile_pass"](self.device.coupling_map)
elif action_index in self.actions_synthesis_indices:
transpile_pass = action["transpile_pass"](self.device.basis_gates)
if action_index in self.actions_opt_indices:
transpile_pass = action["transpile_pass"]
else:
transpile_pass = action["transpile_pass"](self.device)
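
After this simplification only two cases remain: optimisation actions already carry a ready-to-use pass, while every other action stores a factory that is called with the full device description. A stand-alone sketch of that pattern (the two example actions are made up, not the predictor's real action set):

    # Minimal sketch of the dispatch above; the action entries are made up.
    from typing import Any

    actions: dict[int, dict[str, Any]] = {
        0: {"name": "optimize", "transpile_pass": ["<ready-made optimisation pass>"]},
        1: {"name": "map", "transpile_pass": lambda device: [f"<pass built for {device}>"]},
    }
    actions_opt_indices = [0]

    def resolve_transpile_pass(action_index: int, device: str) -> list[Any]:
        action = actions[action_index]
        if action_index in actions_opt_indices:
            return action["transpile_pass"]      # device-independent, used as-is
        return action["transpile_pass"](device)  # device-dependent, built on demand
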

48 changes: 47 additions & 1 deletion src/mqt/predictor/rl/torch_layers.py
@@ -4,6 +4,9 @@
from typing import TYPE_CHECKING

from stable_baselines3.common.preprocessing import get_flattened_obs_dim
from torch_geometric.data import Batch, Data

from mqt.predictor.ml.GNN import Net

if TYPE_CHECKING:
from gymnasium import spaces
@@ -46,14 +49,20 @@ def __init__(
extractors: Dict[str, nn.Module] = {}

total_concat_size = 0
graph_observation_space = []
for key, subspace in observation_space.spaces.items():
if key == "circuit":
extractors[key] = CustomCNN(subspace, features_dim=cnn_output_dim, normalized_image=normalized_image)
total_concat_size += cnn_output_dim
elif key.startswith("graph"):
graph_observation_space.append(subspace)
else:
# The observation key is a vector, flatten it if needed
extractors[key] = nn.Flatten()
total_concat_size += get_flattened_obs_dim(subspace)
if graph_observation_space:
extractors["graph"] = CustomGNN(graph_observation_space, features_dim=cnn_output_dim)
total_concat_size += cnn_output_dim

self.extractors = nn.ModuleDict(extractors)

@@ -64,7 +73,11 @@ def forward(self, observations: TensorDict) -> th.Tensor:
encoded_tensor_list = []

for key, extractor in self.extractors.items():
encoded_tensor_list.append(extractor(observations[key]))
if key == "graph":
obs = [v for k, v in observations.items() if k.startswith("graph")]
encoded_tensor_list.append(extractor(obs))
else:
encoded_tensor_list.append(extractor(observations[key]))
return th.cat(encoded_tensor_list, dim=1)


@@ -128,3 +141,36 @@ def forward(self, x: list[th.Tensor] | th.Tensor) -> th.Tensor:
lstm_out, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

return self.linear(lstm_out[:, -1, :])


class CustomGNN(BaseFeaturesExtractor): # type: ignore[misc]
"""
GNN
:param observation_space:
:param features_dim: Number of features extracted.
This corresponds to the number of unit for the last layer.
:param normalized_image: Whether to assume that the image is already normalized
or not (this disables dtype and bounds checks): when True, it only checks that
the space is a Box and has 3 dimensions.
Otherwise, it checks that it has expected dtype (uint8) and bounds (values in [0, 255]).
"""

def __init__(self, observation_space: spaces.Box, features_dim: int = 512) -> None:
super().__init__(observation_space, features_dim)
self.gnn = Net(output_dim=features_dim)

def forward(self, input: list[list[th.Tensor]] | list[th.Tensor]) -> th.Tensor:
data_list = []
if isinstance(input[0], th.Tensor):
input = [[i] for i in input]
for i in range(len(input[0])):
x, edge_index, edge_attr = input[0][i], input[1][i], input[2][i]
x = x.squeeze(0).long()
edge_index = edge_index.squeeze(0).long()
edge_attr = edge_attr.squeeze(0).long()
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
data_list.append(data)

batch = Batch.from_data_list(data_list)
return self.gnn(batch)
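
CustomGNN rebuilds one Data object per observation and merges them with Batch.from_data_list, which concatenates node features and offsets edge indices so a single forward pass of the GNN covers the whole batch. A self-contained illustration with two toy graphs (not the predictor's circuit encoding):

    # Minimal sketch of the batching used above; the toy graphs are made up.
    import torch
    from torch_geometric.data import Batch, Data

    g1 = Data(x=torch.tensor([[0], [1]]), edge_index=torch.tensor([[0], [1]]))
    g2 = Data(x=torch.tensor([[2], [3], [4]]), edge_index=torch.tensor([[0, 1], [1, 2]]))

    batch = Batch.from_data_list([g1, g2])
    print(batch.x.shape)     # torch.Size([5, 1])  - node features concatenated
    print(batch.edge_index)  # g2's edges shifted by the two nodes of g1
    print(batch.batch)       # tensor([0, 0, 1, 1, 1]) - graph id for every node
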
2 changes: 1 addition & 1 deletion tests/compilation/test_predictor_rl.py
@@ -35,7 +35,7 @@ def test_qcompile_with_newly_trained_models(figure_of_merit: reward.figure_of_me
"""Test the qcompile function with a newly trained model."""
""" Important: Those trained models are used in later tests and must not be deleted. """

device = "ibm_montreal"
device = "ionq_harmony"
predictor = rl.Predictor(figure_of_merit=figure_of_merit, device_name=device)
predictor.train_model(
timesteps=100,
