Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse search results when given a partially filled directory #98

Merged
merged 8 commits into from
Aug 23, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.

## [Unreleased]

### Added

- Reuse search results when given a partially filled directory ([#98](https://github.com/microsoft/syntheseus/pull/98)) ([@kmaziarz])

### Fixed

- Shift the `pandas` dependency to the external model packages ([#94](https://github.com/microsoft/syntheseus/pull/94)) ([@kmaziarz])
Expand Down
48 changes: 43 additions & 5 deletions syntheseus/cli/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from omegaconf import MISSING, DictConfig, OmegaConf
from tqdm import tqdm

from syntheseus.interface.molecule import Molecule
from syntheseus import Molecule
from syntheseus.reaction_prediction.inference.config import BackwardModelConfig
from syntheseus.reaction_prediction.utils.config import get_config as cli_get_config
from syntheseus.reaction_prediction.utils.misc import set_random_seed
Expand All @@ -46,7 +46,12 @@
from syntheseus.search.mol_inventory import SmilesListInventory
from syntheseus.search.node_evaluation import common as node_evaluation_common
from syntheseus.search.utils.misc import lookup_by_name
from syntheseus.search.visualization import visualize_andor, visualize_molset

try:
# Try to import the visualization code, which will work only if `graphviz` is installed.
from syntheseus.search.visualization import visualize_andor, visualize_molset
except ModuleNotFoundError:
kmaziarz marked this conversation as resolved.
Show resolved Hide resolved
pass

logger = logging.getLogger(__file__)

Expand Down Expand Up @@ -108,6 +113,7 @@ class BaseSearchConfig:

inventory_smiles_file: str = MISSING # Purchasable molecules
results_dir: str = "." # Directory to save the results in
append_timestamp_to_dir: bool = True # Whether to append the current time to directory name

# By default limit search time (but set very high iteration limits just in case)
time_limit_s: float = 600
Expand Down Expand Up @@ -242,8 +248,13 @@ def build_node_evaluator(key: str) -> None:

# Prepare the output directory
results_dir_top_level = Path(config.results_dir)
timestamp = datetime.datetime.now().isoformat(timespec="seconds")
results_dir_current_run = results_dir_top_level / f"{config.model_class.name}_{str(timestamp)}"

dirname = config.model_class.name
if config.append_timestamp_to_dir:
timestamp = datetime.datetime.now().isoformat(timespec="seconds")
dirname += f"_{str(timestamp)}"

results_dir_current_run = results_dir_top_level / dirname

logger.info("Setup completed")
num_targets = len(search_targets)
Expand All @@ -260,6 +271,32 @@ def build_node_evaluator(key: str) -> None:
results_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"Outputs will be saved under {results_dir}")

results_lock_path = results_dir / ".lock"
results_stats_path = results_dir / "stats.json"

if results_lock_path.exists():
paths = [path for path in results_dir.iterdir() if path.is_file()]
logger.warning(
f"Lockfile was found which means the last run failed, purging {len(paths)} files"
)

for path in paths:
path.unlink()
elif results_stats_path.exists():
with open(results_stats_path, "rt") as f_stats:
stats = json.load(f_stats)
if stats.get("index") != idx or stats.get("smiles") != smiles:
raise RuntimeError(
f"Data present under {results_dir} does not match the current run"
)

all_stats.append(stats)

logger.info("Search results already exist, skipping")
continue

results_lock_path.touch()

alg.reset()
output_graph, _ = alg.run_from_mol(Molecule(smiles))
logger.info(f"Finished search for target {smiles}")
Expand Down Expand Up @@ -288,7 +325,7 @@ def build_node_evaluator(key: str) -> None:
all_stats.append(stats)
logger.info(pformat(stats))

with open(results_dir / "stats.json", "wt") as f_stats:
with open(results_stats_path, "wt") as f_stats:
f_stats.write(json.dumps(stats, indent=2))

if config.save_graph:
Expand Down Expand Up @@ -321,6 +358,7 @@ def build_node_evaluator(key: str) -> None:
else:
assert False

results_lock_path.unlink()
del results_dir

if num_targets > 1:
Expand Down
Empty file.
103 changes: 103 additions & 0 deletions syntheseus/tests/cli/test_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations

import json
from pathlib import Path
from types import SimpleNamespace
from typing import Sequence

import pytest
from omegaconf import OmegaConf

from syntheseus import BackwardReactionModel, Bag, Molecule, SingleProductReaction
from syntheseus.cli.search import SearchConfig, run_from_config
from syntheseus.reaction_prediction.inference.config import BackwardModelClass


class FlakyReactionModel(BackwardReactionModel):
"""Dummy reaction model that only works when called for the first time."""

def __init__(self, *args, **kwargs) -> None:
super().__init__()
self._used = False

def _get_reactions(
self, inputs: list[Molecule], num_results: int
) -> list[Sequence[SingleProductReaction]]:
if self._used:
raise RuntimeError()

self._used = True
return [
[
SingleProductReaction(
reactants=Bag([Molecule("C")]), product=product, metadata={"probability": 1.0}
)
]
for product in inputs
]


def test_resume_search(tmpdir: Path) -> None:
search_targets_file_path = tmpdir / "search_targets.smiles"
with open(search_targets_file_path, "wt") as f_search_targets:
f_search_targets.write("CC\nCC\nCC\nCC\n")

inventory_file_path = tmpdir / "inventory.smiles"
with open(inventory_file_path, "wt") as f_inventory:
f_inventory.write("C\n")

# Inject our flaky reaction model into the set of supported model classes.
BackwardModelClass._member_map_["FlakyReactionModel"] = SimpleNamespace( # type: ignore
name="FlakyReactionModel", value=FlakyReactionModel
)

config = OmegaConf.create( # type: ignore
SearchConfig(
model_class="FlakyReactionModel", # type: ignore[arg-type]
search_algorithm="retro_star",
search_targets_file=str(search_targets_file_path),
inventory_smiles_file=str(inventory_file_path),
results_dir=str(tmpdir),
append_timestamp_to_dir=False,
limit_iterations=1,
num_routes_to_plot=0,
)
)

results_dir = tmpdir / "FlakyReactionModel"

def file_exist(idx: int, name: str) -> bool:
return (results_dir / str(idx) / name).exists()

# Try to run search three times; each time we will succeed solving one target (which requires one
# call) and then fail on the next one.
for trial_idx in range(3):
with pytest.raises(RuntimeError):
run_from_config(config)

for idx in range(trial_idx + 1):
assert file_exist(idx, "stats.json")
assert not file_exist(idx, ".lock")

assert not file_exist(trial_idx + 1, "stats.json")
assert file_exist(trial_idx + 1, ".lock")

run_from_config(config)

# The last search needs to solve one final target so it will succeed.
for idx in range(4):
assert file_exist(idx, "stats.json")
assert not file_exist(idx, ".lock")

with open(results_dir / "stats.json", "rt") as f_stats:
stats = json.load(f_stats)

# Even though each search only solved a single target, final stats should include everything.
assert stats["num_targets"] == stats["num_solved_targets"] == 4

# Finally change the targets and verify that the discrepancy will be detected.
with open(search_targets_file_path, "wt") as f_search_targets:
f_search_targets.write("CC\nCCCC\nCC\nCC\n")

with pytest.raises(RuntimeError):
run_from_config(config)
Loading