From 5a007ea69b1452f4c772fa1615513abd5c29b1d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcel=20M=C3=BCller?= Date: Fri, 16 Aug 2024 12:45:34 +0200 Subject: [PATCH] Add option to set the xtb_path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Marcel Müller --- mindlessgen.toml | 1 + src/mindlessgen/cli/cli_parser.py | 11 ++++++++--- src/mindlessgen/generator/main.py | 2 +- src/mindlessgen/prog/config.py | 17 +++++++++++++++++ src/mindlessgen/qm/xtb.py | 16 +++++++++++----- 5 files changed, 38 insertions(+), 9 deletions(-) diff --git a/mindlessgen.toml b/mindlessgen.toml index 5fd45cd..fc739ac 100644 --- a/mindlessgen.toml +++ b/mindlessgen.toml @@ -29,6 +29,7 @@ max_num_atoms = 100 max_frag_cycles = 100 [xtb] +xtb_path = "/path/to/xtb" # TODO # Specific configurations for the XTB engine (if needed) # xtb_option_1 = "value1" diff --git a/src/mindlessgen/cli/cli_parser.py b/src/mindlessgen/cli/cli_parser.py index 28174da..c2ed4b7 100644 --- a/src/mindlessgen/cli/cli_parser.py +++ b/src/mindlessgen/cli/cli_parser.py @@ -76,8 +76,13 @@ def cli_parser(argv: Sequence[str] | None = None) -> dict: required=False, help="Maximum number of fragment optimization cycles.", ) - # XTB specific arguments - # TODO: Add XTB specific arguments + # xTB specific arguments + parser.add_argument( + "--xtb-path", + type=str, + required=False, + help="Path to the xTB binary.", + ) # ORCA specific arguments # TODO: Add ORCA specific arguments args = parser.parse_args(argv) @@ -101,7 +106,7 @@ def cli_parser(argv: Sequence[str] | None = None) -> dict: "max_num_atoms": args_dict["max_num_atoms"], } # XTB specific arguments - rev_args_dict["xtb"] = {} + rev_args_dict["xtb"] = {"xtb_path": args_dict["xtb_path"]} # ORCA specific arguments rev_args_dict["orca"] = {} diff --git a/src/mindlessgen/generator/main.py b/src/mindlessgen/generator/main.py index 5080fea..0910217 100644 --- a/src/mindlessgen/generator/main.py +++ b/src/mindlessgen/generator/main.py @@ -33,7 +33,7 @@ def generator(config: ConfigManager) -> tuple[Molecule | None, int]: if config.general.engine == "xtb": try: - xtb_path = get_xtb_path(["xtb_dev", "xtb"]) + xtb_path = get_xtb_path(config.xtb.xtb_path) if not xtb_path: raise ImportError("xtb not found.") except ImportError as e: diff --git a/src/mindlessgen/prog/config.py b/src/mindlessgen/prog/config.py index 72aedc4..d2381fe 100644 --- a/src/mindlessgen/prog/config.py +++ b/src/mindlessgen/prog/config.py @@ -210,6 +210,7 @@ class XTBConfig(BaseConfig): def __init__(self): self._xtb_option: str = "dummy" + self._xtb_path: str | Path = "xtb" def get_identifier(self) -> str: return "xtb" @@ -230,6 +231,22 @@ def xtb_option(self, xtb_option: str): raise TypeError("xtb_option should be a string.") self._xtb_option = xtb_option + @property + def xtb_path(self): + """ + Get the xtb path. + """ + return self._xtb_path + + @xtb_path.setter + def xtb_path(self, xtb_path: str | Path): + """ + Set the xtb path. + """ + if not isinstance(xtb_path, str | Path): + raise TypeError("xtb_path should be a string.") + self._xtb_path = xtb_path + class ORCAConfig(BaseConfig): """ diff --git a/src/mindlessgen/qm/xtb.py b/src/mindlessgen/qm/xtb.py index 95c08de..924e380 100644 --- a/src/mindlessgen/qm/xtb.py +++ b/src/mindlessgen/qm/xtb.py @@ -159,15 +159,21 @@ def run(self, temp_path: Path, arguments: list[str]) -> tuple[str, str, int]: return xtb_log_out, xtb_log_err, e.returncode -def get_xtb_path(binary_names: list[str]) -> Path | None: +def get_xtb_path(binary_name: str | Path | None = None) -> Path: """ Get the path to the xtb binary based on different possible names that are searched for in the PATH. """ + default_xtb_names: list[str | Path] = ["xtb", "xtb_dev"] + # put binary name at the beginning of the lixt to prioritize it + if binary_name is not None: + binary_names = [binary_name] + default_xtb_names + else: + binary_names = default_xtb_names # Get xtb path from 'which xtb' command - for binary_name in binary_names: - which_xtb = shutil.which(binary_name) + for binpath in binary_names: + which_xtb = shutil.which(binpath) if which_xtb: - xtb_path = Path(which_xtb) + xtb_path = Path(which_xtb).resolve() return xtb_path - raise ImportError("'xtb' or 'xtb_dev' not found.") + raise ImportError("'xtb' binary could not be found.")