diff --git a/src/isolate/backends/conda.py b/src/isolate/backends/conda.py index 39edfc2..575d317 100644 --- a/src/isolate/backends/conda.py +++ b/src/isolate/backends/conda.py @@ -1,15 +1,15 @@ from __future__ import annotations +import copy import functools import os +import secrets import shutil import subprocess import tempfile from dataclasses import dataclass, field from pathlib import Path -from typing import Any, ClassVar, Dict, List, Optional - -import yaml +from typing import Any, ClassVar, Dict, List, Optional, Union from isolate.backends import BaseEnvironment, EnvironmentCreationError from isolate.backends.common import active_python, logged_io, sha256_digest_of @@ -22,7 +22,7 @@ _ISOLATE_CONDA_HOME = os.getenv("ISOLATE_CONDA_HOME") # Conda accepts the following version specifiers: =, ==, >=, <=, >, <, != -_CONDA_VERSION_IDENTIFIER_CHARS = ( +_POSSIBLE_CONDA_VERSION_IDENTIFIERS = ( "=", "<", ">", @@ -34,9 +34,8 @@ class CondaEnvironment(BaseEnvironment[Path]): BACKEND_NAME: ClassVar[str] = "conda" - packages: List[str] = field(default_factory=list) + environment_definition: Dict[str, Any] = field(default_factory=dict) python_version: Optional[str] = None - env_dict: Optional[Dict[str, Any]] = None @classmethod def from_config( @@ -44,61 +43,64 @@ def from_config( config: Dict[str, Any], settings: IsolateSettings = DEFAULT_SETTINGS, ) -> BaseEnvironment: - if config.get("env_dict") and config.get("env_yml_str"): - raise EnvironmentCreationError( - "Either env_dict or env_yml_str can be provided, not both!" + processing_config = copy.deepcopy(config) + processing_config.setdefault("python_version", active_python()) + + if "env_dict" in processing_config: + definition = processing_config.pop("env_dict") + elif "env_yml_str" in processing_config: + import yaml + + definition = yaml.safe_load(processing_config.pop("env_yml_str")) + elif "packages" in processing_config: + definition = { + "dependencies": processing_config.pop("packages"), + } + else: + raise ValueError( + "Either 'env_dict', 'env_yml_str' or 'packages' must be specified" ) - if config.get("env_yml_str"): - config["env_dict"] = yaml.safe_load(config["env_yml_str"]) - del config["env_yml_str"] - environment = cls(**config) + + dependencies = definition.setdefault("dependencies", []) + if _depends_on(dependencies, "python"): + raise ValueError( + "Python version can not be specified by the environment but rather ", + " it needs to be passed as `python_version` option to the environment.", + ) + + dependencies.append(f"python={processing_config['python_version']}") + + # Extend pip dependencies and channels if they are specified. + if "pip" in processing_config: + if not _depends_on(dependencies, "pip"): + dependencies.append("pip") + + try: + dependency_group = next( + dependency + for dependency in dependencies + if isinstance(dependency, dict) and "pip" in dependency + ) + except StopIteration: + dependency_group = {"pip": []} + dependencies.append(dependency_group) + + dependency_group["pip"].extend(processing_config.pop("pip")) + + if "channels" in processing_config: + definition.setdefault("channels", []) + definition["channels"].extend(processing_config.pop("channels")) + + environment = cls( + environment_definition=definition, + **processing_config, + ) environment.apply_settings(settings) return environment @property def key(self) -> str: - if self.env_dict: - return sha256_digest_of(str(self._compute_dependencies())) - return sha256_digest_of(*self._compute_dependencies()) - - def _compute_dependencies(self) -> List[Any]: - if self.env_dict: - user_dependencies = self.env_dict.get("dependencies", []).copy() - else: - user_dependencies = self.packages.copy() - for raw_requirement in user_dependencies: - # It could be 'pip': [...] - if type(raw_requirement) is dict: - continue - # Get rid of all whitespace characters (python = 3.8 becomes python=3.8) - raw_requirement = raw_requirement.replace(" ", "") - if not raw_requirement.startswith("python"): - continue - - # Ensure that the package is either python or python followed - # by a version specifier. Examples: - # - python # OK - # - python=3.8 # OK - # - python>=3.8 # OK - # - python-user-toolkit # NOT OK - # - pythonhelp!=1.0 # NOT OK - - python_suffix = raw_requirement[len("python") :] - if ( - python_suffix - and python_suffix[0] not in _CONDA_VERSION_IDENTIFIER_CHARS - ): - continue - - raise EnvironmentCreationError( - "Python version can not be specified by packages (it needs to be passed as `python_version` option)" - ) - - # Now that we verified that the user did not specify the Python version - # we can add it by ourselves - target_python = self.python_version or active_python() - user_dependencies.append(f"python={target_python}") - return user_dependencies + return sha256_digest_of(repr(self.environment_definition)) def create(self, *, force: bool = False) -> Path: env_path = self.settings.cache_dir_for(self) @@ -106,37 +108,15 @@ def create(self, *, force: bool = False) -> Path: if env_path.exists() and not force: return env_path - if self.env_dict: - self.env_dict["dependencies"] = self._compute_dependencies() - with tempfile.NamedTemporaryFile(mode="w", suffix=".yml") as tf: - yaml.dump(self.env_dict, tf) - tf.flush() - try: - self._run_conda( - "env", "create", "-f", tf.name, "--prefix", env_path - ) - except subprocess.SubprocessError as exc: - raise EnvironmentCreationError( - f"Failure during 'conda create': {exc}" - ) - - else: - # Since our agent needs Python to be installed (at very least) - # we need to make sure that the base environment is created with - # the same Python version as the one that is used to run the - # isolate agent. - dependencies = self._compute_dependencies() - - self.log(f"Creating the environment at '{env_path}'") - self.log(f"Installing packages: {', '.join(dependencies)}") + self.log(f"Creating the environment at '{env_path}'") + with tempfile.NamedTemporaryFile(mode="w", suffix=".yml") as tf: + import yaml + yaml.dump(self.environment_definition, tf) + tf.flush() try: self._run_conda( - "create", - "--yes", - "--prefix", - env_path, - *dependencies, + "env", "create", "--force", "--prefix", env_path, "-f", tf.name ) except subprocess.SubprocessError as exc: raise EnvironmentCreationError( @@ -191,3 +171,33 @@ def _get_conda_executable() -> Path: "Could not find conda executable. If conda executable is not available by default, please point isolate " " to the path where conda binary is available 'ISOLATE_CONDA_HOME'." ) + + +def _depends_on( + dependencies: List[Union[str, Dict[str, List[str]]]], + package_name: str, +) -> bool: + for dependency in dependencies: + if isinstance(dependency, dict): + # It is a dependency group like pip: [...] + continue + + # Get rid of all whitespace characters (python = 3.8 becomes python=3.8) + package = dependency.replace(" ", "") + if not package.startswith(package_name): + continue + + # Ensure that the package name matches perfectly and not only + # at the prefix level. Examples: + # - python # OK + # - python=3.8 # OK + # - python>=3.8 # OK + # - python-user-toolkit # NOT OK + # - pythonhelp!=1.0 # NOT OK + suffix = package[len(package_name) :] + if suffix and suffix[0] not in _POSSIBLE_CONDA_VERSION_IDENTIFIERS: + continue + + return True + else: + return False diff --git a/tests/test_backends.py b/tests/test_backends.py index 26f8608..2c434d2 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -1,6 +1,7 @@ import re import subprocess import sys +import textwrap from contextlib import contextmanager from functools import partial from os import environ @@ -393,9 +394,11 @@ class TestConda(GenericEnvironmentTests): }, "old-python": { "python_version": "3.7", + "packages": [], }, "new-python": { "python_version": "3.10", + "packages": [], }, "env-dict": { "env_dict": { @@ -442,18 +445,68 @@ def test_conda_binary_execution(self, tmp_path): def test_fail_when_user_overwrites_python( self, tmp_path, user_packages, python_version ): - environment = self.get_environment( - tmp_path, + with pytest.raises( + ValueError, + match="Python version can not be specified by the environment", + ): + self.get_environment( + tmp_path, + { + "packages": user_packages, + "python_version": python_version, + }, + ) + + @pytest.mark.parametrize( + "configuration", + [ + { + "env_dict": { + "name": "test", + "channels": "defaults", + "dependencies": ["a", "b"], + } + }, + { + "env_dict": { + "name": "test", + "channels": "defaults", + "dependencies": ["a", "b", "pip", {"pip": ["c", "d"]}], + } + }, + { + "env_yml_str": textwrap.dedent( + """ + name: test + channels: + - defaults + - conda-forge + """ + ) + }, { - "packages": user_packages, - "python_version": python_version, + "packages": ["a", "piped", "b"], }, + ], + ) + def test_add_pip_dependencies(self, tmp_path, configuration): + environment = self.get_environment( + tmp_path, {**configuration, "pip": ["agent"]} ) - with pytest.raises( - EnvironmentCreationError, - match="Python version can not be specified by packages", - ): - environment.create() + all_deps = environment.environment_definition["dependencies"] + assert "pip" in all_deps # Ensurue pip is added as a dependency + assert ( + all_deps.count("pip") == 1 + ) # And it does not appear twice (when the environment already supplies itr) + + dep_groups = [ + dependency + for dependency in all_deps + if isinstance(dependency, dict) and "pip" in dependency + ] + assert len(dep_groups) == 1 + pip_dep = dep_groups[0]["pip"] + assert "agent" in pip_dep # And pip dependency is added def test_local_python_environment(): @@ -641,6 +694,18 @@ def test_isolate_server_multiple_envs(isolate_server): ] }, ), + ( + "conda", + { + "packages": [ + "pyjokes=1.0.0", + ], + "env_dict": { + "name": "test", + "dependencies": ["pyjokes=2.0.0"], + }, + }, + ), ( "isolate-server", { @@ -657,7 +722,7 @@ def test_isolate_server_multiple_envs(isolate_server): ], ) def test_wrong_options(kind, config): - with pytest.raises(TypeError): + with pytest.raises((TypeError, ValueError)): isolate.prepare_environment(kind, **config)