diff --git a/conda_subprocess/process.py b/conda_subprocess/process.py index 960eb3f..b5ae52b 100644 --- a/conda_subprocess/process.py +++ b/conda_subprocess/process.py @@ -2,7 +2,14 @@ from subprocess import Popen as subprocess_Popen from conda.auxlib.compat import shlex_split_unicode -from conda.base.context import context, validate_prefix_name +from conda.auxlib.ish import dals +from conda.base.context import ( + context, + _first_writable_envs_dir, + ROOT_ENV_NAME, + PREFIX_NAME_DISALLOWED_CHARS, +) +from conda.exceptions import EnvironmentNameNotFound, CondaValueError from conda.cli.common import validate_prefix from conda.common.compat import encode_arguments, encode_environment, isiterable from conda.common.path import expand @@ -92,7 +99,7 @@ def _check_prefix(prefix_name=None, prefix_path=None): elif prefix_path is not None: return expand(prefix_path) else: - return validate_prefix_name(prefix_name, ctx=context) + return _validate_prefix_name(prefix_name, ctx=context) def _check_args(args): @@ -100,3 +107,54 @@ def _check_args(args): return args.split() else: return args + + +def _locate_prefix_by_name(name, envs_dirs=None): + """Find the location of a prefix given a conda env name. If the location does not exist, an + error is raised. + """ + assert name + if name in (ROOT_ENV_NAME, "root"): + return context.root_prefix + if envs_dirs is None: + envs_dirs = context.envs_dirs + for envs_dir in envs_dirs: + if not os.path.isdir(envs_dir): + continue + prefix = os.path.join(envs_dir, name) + if os.path.isdir(prefix): + return os.path.abspath(prefix) + raise EnvironmentNameNotFound(name) + + +def _validate_prefix_name(prefix_name: str, ctx: context, allow_base=True) -> str: + """Run various validations to make sure prefix_name is valid""" + if PREFIX_NAME_DISALLOWED_CHARS.intersection(prefix_name): + raise CondaValueError( + dals( + f""" + Invalid environment name: {prefix_name!r} + Characters not allowed: {PREFIX_NAME_DISALLOWED_CHARS} + If you are specifying a path to an environment, the `-p` + flag should be used instead. + """ + ) + ) + + if prefix_name in (ROOT_ENV_NAME, "root"): + if allow_base: + return ctx.root_prefix + else: + raise CondaValueError( + "Use of 'base' as environment name is not allowed here." + ) + + else: + envs_dirs = context.envs_dirs + envs_dirs += tuple( + [os.path.abspath(os.path.join(os.environ["CONDA_EXE"], "..", "..", "envs"))] + ) + try: + return _locate_prefix_by_name(name=prefix_name, envs_dirs=envs_dirs) + except EnvironmentNameNotFound: + return os.path.join(_first_writable_envs_dir(), prefix_name) diff --git a/tests/test_conda_subprocess.py b/tests/test_conda_subprocess.py index d3d9c1a..d1f04db 100644 --- a/tests/test_conda_subprocess.py +++ b/tests/test_conda_subprocess.py @@ -7,16 +7,24 @@ class TestCondaSubprocess(TestCase): - def setUp(self): - self.env_path = os.path.join(context.root_prefix, "..", "py312") + @classmethod + def setUpClass(cls): + cls.env_name = "py312" + cls.env_path = os.path.join(context.root_prefix, "..", cls.env_name) - def test_call(self): + def test_call_path(self): self.assertEqual(call("python --version", prefix_path=self.env_path), 0) - def test_check_call(self): + def test_call_name(self): + self.assertEqual(call("python --version", prefix_name=self.env_name), 0) + + def test_check_call_path(self): self.assertEqual(check_call("python --version", prefix_path=self.env_path), 0) - def test_check_output(self): + def test_check_call_name(self): + self.assertEqual(check_call("python --version", prefix_name=self.env_name), 0) + + def test_check_output_path(self): if os.name == "nt": self.assertEqual( check_output("python --version", prefix_path=self.env_path), @@ -28,6 +36,15 @@ def test_check_output(self): b"Python 3.12.1\n", ) + def test_check_output_name(self): + expected_output = ( + b"Python 3.12.1\r\n" if os.name == "nt" else b"Python 3.12.1\n" + ) + self.assertEqual( + check_output("python --version", prefix_name=self.env_name), + expected_output, + ) + def test_check_output_universal_newlines(self): self.assertEqual( check_output( @@ -36,12 +53,17 @@ def test_check_output_universal_newlines(self): "Python 3.12.1\n", ) - def test_run(self): + def test_run_path(self): self.assertEqual( run("python --version", prefix_path=self.env_path).returncode, 0 ) - def test_popen(self): + def test_run_name(self): + self.assertEqual( + run("python --version", prefix_name=self.env_name).returncode, 0 + ) + + def test_popen_path(self): process = Popen("python --version", prefix_path=self.env_path, stdout=PIPE) output = process.communicate() if os.name == "nt": @@ -50,6 +72,15 @@ def test_popen(self): self.assertEqual(output[0], b"Python 3.12.1\n") self.assertIsNone(output[1]) + def test_popen_name(self): + process = Popen("python --version", prefix_name=self.env_name, stdout=PIPE) + output = process.communicate() + if os.name == "nt": + self.assertEqual(output[0], b"Python 3.12.1\r\n") + else: + self.assertEqual(output[0], b"Python 3.12.1\n") + self.assertIsNone(output[1]) + def test_environment_variable(self): self.assertTrue( "TESTVAR=test"