Skip to content

Commit

Permalink
prepare for rl-zoo
Browse files Browse the repository at this point in the history
  • Loading branch information
flowerthrower committed Apr 19, 2024
1 parent da11497 commit 91dce53
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/mqt/predictor/rl/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def create_feature_dict(qc: QuantumCircuit, features: list[str] | str = "all") -
)
feature_dict["circuit"] = encode_circuit(qc) if ("all" in features or "circuit" in features) else None
# graph feature creation
if "all" in features or "graph" in features:
if "all" in features or any(k.startswith("graph") for k in features):
try:
ops_list = qc.count_ops()
ops_list_dict = ml.helper.dict_to_featurevector(ops_list)
Expand Down
7 changes: 7 additions & 0 deletions src/mqt/predictor/rl/predictorenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
self.num_steps = 0
self.layout: TranspileLayout | None = None
self.num_qubits_uncompiled_circuit = 0
self.init_reward = None

self.has_parametrized_gates = False

Expand Down Expand Up @@ -139,6 +140,8 @@ def step(self, action: int) -> tuple[dict[str, Any], float, bool, bool, dict[Any
raise RuntimeError(msg)

if action == self.action_terminate_index:
if self.init_reward is None:
self.init_reward = self.calculate_reward()
reward_val = self.calculate_reward()
done = True
else:
Expand All @@ -160,6 +163,10 @@ def calculate_reward(self) -> float:
return reward.crit_depth(self.state)
error_msg = f"Reward function {self.reward_function} not supported."
raise ValueError(error_msg)

def calculate_improvement(self) -> float:
"""Calculates and returns the improvement in reward."""
return self.init_reward - reward.expected_fidelity(self.state, self.device)

def render(self) -> None:
"""Renders the current state."""
Expand Down

0 comments on commit 91dce53

Please sign in to comment.