diff --git a/CHANGELOG.md b/CHANGELOG.md
index 663d5df1..91559c4c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,10 @@ and the project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.
 
 ## [Unreleased]
 
+### Added
+
+- Reuse search results when given a partially filled directory ([#98](https://github.com/microsoft/syntheseus/pull/98)) ([@kmaziarz])
+
 ### Fixed
 
 - Shift the `pandas` dependency to the external model packages ([#94](https://github.com/microsoft/syntheseus/pull/94)) ([@kmaziarz])
diff --git a/syntheseus/cli/search.py b/syntheseus/cli/search.py
index 9e886bd3..11bdfdc5 100644
--- a/syntheseus/cli/search.py
+++ b/syntheseus/cli/search.py
@@ -29,7 +29,7 @@
 from omegaconf import MISSING, DictConfig, OmegaConf
 from tqdm import tqdm
 
-from syntheseus.interface.molecule import Molecule
+from syntheseus import Molecule
 from syntheseus.reaction_prediction.inference.config import BackwardModelConfig
 from syntheseus.reaction_prediction.utils.config import get_config as cli_get_config
 from syntheseus.reaction_prediction.utils.misc import set_random_seed
@@ -46,7 +46,14 @@
 from syntheseus.search.mol_inventory import SmilesListInventory
 from syntheseus.search.node_evaluation import common as node_evaluation_common
 from syntheseus.search.utils.misc import lookup_by_name
-from syntheseus.search.visualization import visualize_andor, visualize_molset
+
+try:
+    # Try to import the visualization code, which will work only if `graphviz` is installed.
+    from syntheseus.search.visualization import visualize_andor, visualize_molset
+
+    VISUALIZATION_CODE_IMPORTED = True
+except ModuleNotFoundError:
+    VISUALIZATION_CODE_IMPORTED = False
 
 logger = logging.getLogger(__file__)
 
@@ -108,6 +115,7 @@ class BaseSearchConfig:
     inventory_smiles_file: str = MISSING  # Purchasable molecules
     results_dir: str = "."  # Directory to save the results in
+    append_timestamp_to_dir: bool = True  # Whether to append the current time to the directory name
 
     # By default limit search time (but set very high iteration limits just in case)
     time_limit_s: float = 600
 
@@ -147,6 +155,12 @@ def run_from_config(config: SearchConfig) -> Path:
     print("Running search with the following config:")
     print(config)
 
+    if config.num_routes_to_plot > 0 and not VISUALIZATION_CODE_IMPORTED:
+        raise ValueError(
+            "Could not import visualization code (likely `viz` dependencies are not installed); "
+            "please install missing dependencies or set `num_routes_to_plot=0`"
+        )
+
     search_target, search_targets_file = [
         cast(DictConfig, config).get(key) for key in ["search_target", "search_targets_file"]
     ]
@@ -242,8 +256,13 @@ def build_node_evaluator(key: str) -> None:
     # Prepare the output directory
     results_dir_top_level = Path(config.results_dir)
-    timestamp = datetime.datetime.now().isoformat(timespec="seconds")
-    results_dir_current_run = results_dir_top_level / f"{config.model_class.name}_{str(timestamp)}"
+
+    dirname = config.model_class.name
+    if config.append_timestamp_to_dir:
+        timestamp = datetime.datetime.now().isoformat(timespec="seconds")
+        dirname += f"_{str(timestamp)}"
+
+    results_dir_current_run = results_dir_top_level / dirname
 
     logger.info("Setup completed")
     num_targets = len(search_targets)
 
@@ -260,6 +279,32 @@ def build_node_evaluator(key: str) -> None:
         results_dir.mkdir(parents=True, exist_ok=True)
         logger.info(f"Outputs will be saved under {results_dir}")
 
+        results_lock_path = results_dir / ".lock"
+        results_stats_path = results_dir / "stats.json"
+
+        if results_lock_path.exists():
+            paths = [path for path in results_dir.iterdir() if path.is_file()]
+            logger.warning(
+                f"Lockfile was found which means the last run failed, purging {len(paths)} files"
+            )
+
+            for path in paths:
+                path.unlink()
+        elif results_stats_path.exists():
+            with open(results_stats_path, "rt") as f_stats:
+                stats = json.load(f_stats)
+
+            if stats.get("index") != idx or stats.get("smiles") != smiles:
+                raise RuntimeError(
+                    f"Data present under {results_dir} does not match the current run"
+                )
+
+            all_stats.append(stats)
+
+            logger.info("Search results already exist, skipping")
+            continue
+
+        results_lock_path.touch()
+
         alg.reset()
         output_graph, _ = alg.run_from_mol(Molecule(smiles))
         logger.info(f"Finished search for target {smiles}")
@@ -288,7 +333,7 @@ def build_node_evaluator(key: str) -> None:
         all_stats.append(stats)
         logger.info(pformat(stats))
 
-        with open(results_dir / "stats.json", "wt") as f_stats:
+        with open(results_stats_path, "wt") as f_stats:
             f_stats.write(json.dumps(stats, indent=2))
 
         if config.save_graph:
@@ -321,6 +366,7 @@ def build_node_evaluator(key: str) -> None:
             else:
                 assert False
 
+        results_lock_path.unlink()
         del results_dir
 
     if num_targets > 1:
diff --git a/syntheseus/reaction_prediction/cli/__init__.py b/syntheseus/reaction_prediction/cli/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/syntheseus/tests/cli/test_search.py b/syntheseus/tests/cli/test_search.py
new file mode 100644
index 00000000..d5002b49
--- /dev/null
+++ b/syntheseus/tests/cli/test_search.py
@@ -0,0 +1,103 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Sequence
+
+import pytest
+from omegaconf import OmegaConf
+
+from syntheseus import BackwardReactionModel, Bag, Molecule, SingleProductReaction
+from syntheseus.cli.search import SearchConfig, run_from_config
+from syntheseus.reaction_prediction.inference.config import BackwardModelClass
+
+
+class FlakyReactionModel(BackwardReactionModel):
+    """Dummy reaction model that only works when called for the first time."""
+
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__()
+        self._used = False
+
+    def _get_reactions(
+        self, inputs: list[Molecule], num_results: int
+    ) -> list[Sequence[SingleProductReaction]]:
+        if self._used:
+            raise RuntimeError()
+
+        self._used = True
+        return [
+            [
+                SingleProductReaction(
+                    reactants=Bag([Molecule("C")]), product=product, metadata={"probability": 1.0}
+                )
+            ]
+            for product in inputs
+        ]
+
+
+def test_resume_search(tmpdir: Path) -> None:
+    search_targets_file_path = tmpdir / "search_targets.smiles"
+    with open(search_targets_file_path, "wt") as f_search_targets:
+        f_search_targets.write("CC\nCC\nCC\nCC\n")
+
+    inventory_file_path = tmpdir / "inventory.smiles"
+    with open(inventory_file_path, "wt") as f_inventory:
+        f_inventory.write("C\n")
+
+    # Inject our flaky reaction model into the set of supported model classes.
+    BackwardModelClass._member_map_["FlakyReactionModel"] = SimpleNamespace(  # type: ignore
+        name="FlakyReactionModel", value=FlakyReactionModel
+    )
+
+    config = OmegaConf.create(  # type: ignore
+        SearchConfig(
+            model_class="FlakyReactionModel",  # type: ignore[arg-type]
+            search_algorithm="retro_star",
+            search_targets_file=str(search_targets_file_path),
+            inventory_smiles_file=str(inventory_file_path),
+            results_dir=str(tmpdir),
+            append_timestamp_to_dir=False,
+            limit_iterations=1,
+            num_routes_to_plot=0,
+        )
+    )
+
+    results_dir = tmpdir / "FlakyReactionModel"
+
+    def file_exist(idx: int, name: str) -> bool:
+        return (results_dir / str(idx) / name).exists()
+
+    # Try to run search three times; each time we will succeed in solving one target (which requires
+    # one call) and then fail on the next one.
+    for trial_idx in range(3):
+        with pytest.raises(RuntimeError):
+            run_from_config(config)
+
+        for idx in range(trial_idx + 1):
+            assert file_exist(idx, "stats.json")
+            assert not file_exist(idx, ".lock")
+
+        assert not file_exist(trial_idx + 1, "stats.json")
+        assert file_exist(trial_idx + 1, ".lock")
+
+    run_from_config(config)
+
+    # The last search only needs to solve one final target, so it will succeed.
+    for idx in range(4):
+        assert file_exist(idx, "stats.json")
+        assert not file_exist(idx, ".lock")
+
+    with open(results_dir / "stats.json", "rt") as f_stats:
+        stats = json.load(f_stats)
+
+    # Even though each search only solved a single target, the final stats should include everything.
+    assert stats["num_targets"] == stats["num_solved_targets"] == 4
+
+    # Finally, change the targets and verify that the discrepancy is detected.
+    with open(search_targets_file_path, "wt") as f_search_targets:
+        f_search_targets.write("CC\nCCCC\nCC\nCC\n")
+
+    with pytest.raises(RuntimeError):
+        run_from_config(config)