From 946cb30a815b201aa66be8e13110fd67bea59493 Mon Sep 17 00:00:00 2001 From: Amethyst Reese Date: Fri, 17 May 2024 12:28:47 -0400 Subject: [PATCH] Add support for installing project with extras ghstack-source-id: 0b6103fbcd2d497fab2497386eebbff9497f0ba6 Pull Request resolved: https://github.com/omnilib/thx/pull/85 --- thx/config.py | 2 ++ thx/context.py | 19 +++++++++---- thx/tests/config.py | 4 +++ thx/tests/context.py | 65 ++++++++++++++++++++++++++++++++++++++++++-- thx/types.py | 1 + 5 files changed, 82 insertions(+), 9 deletions(-) diff --git a/thx/config.py b/thx/config.py index d234bc3..7f4cacd 100644 --- a/thx/config.py +++ b/thx/config.py @@ -148,6 +148,7 @@ def load_config(path: Optional[Path] = None) -> Config: requirements: List[str] = ensure_listish( data.pop("requirements", None), "tool.thx.requirements" ) + extras: List[str] = ensure_listish(data.pop("extras", None), "tool.thx.extras") watch_paths: Set[Path] = { Path(p) for p in ensure_listish( @@ -167,6 +168,7 @@ def load_config(path: Optional[Path] = None) -> Config: values=values, versions=versions, requirements=requirements, + extras=extras, watch_paths=watch_paths, ) ) diff --git a/thx/context.py b/thx/context.py index e4cb2d0..fa9cfe4 100644 --- a/thx/context.py +++ b/thx/context.py @@ -7,6 +7,7 @@ import shutil import subprocess import time +from itertools import chain from pathlib import Path from typing import AsyncIterator, Dict, List, Optional, Sequence, Tuple @@ -160,10 +161,12 @@ def needs_update(context: Context, config: Config) -> bool: if timestamp.exists(): base = timestamp.stat().st_mtime_ns newest = 0 - reqs = project_requirements(config) - for req in reqs: - if req.exists(): - mod_time = req.stat().st_mtime_ns + for path in chain( + [config.root / "pyproject.toml"], + project_requirements(config), + ): + if path.exists(): + mod_time = path.stat().st_mtime_ns newest = max(newest, mod_time) return newest > base @@ -219,9 +222,9 @@ async def prepare_virtualenv(context: Context, config: Config) -> AsyncIterator[ pip = which("pip", context) # install requirements.txt - yield VenvCreate(context, message="installing requirements") requirements = project_requirements(config) if requirements: + yield VenvCreate(context, message="installing requirements") LOG.debug("installing deps from %s", requirements) cmd: List[StrPath] = [pip, "install", "-U"] for requirement in requirements: @@ -230,7 +233,11 @@ async def prepare_virtualenv(context: Context, config: Config) -> AsyncIterator[ # install local project yield VenvCreate(context, message="installing project") - await check_command([pip, "install", "-U", config.root]) + if config.extras: + proj = f"{config.root}[{','.join(config.extras)}]" + else: + proj = str(config.root) + await check_command([pip, "install", "-U", proj]) # timestamp marker content = f"{time.time_ns()}\n" diff --git a/thx/tests/config.py b/thx/tests/config.py index 7df3561..d98ba05 100644 --- a/thx/tests/config.py +++ b/thx/tests/config.py @@ -162,6 +162,8 @@ def test_complex_config(self) -> None: [tool.thx] default = ["test", "lint"] module = "foobar" + requirements = "requirements/dev.txt" + extras = "docs" watch_paths = ["foobar", "pyproject.toml"] [tool.thx.values] @@ -204,6 +206,8 @@ def test_complex_config(self) -> None: ), }, values={"module": "foobar", "something": "else"}, + requirements=["requirements/dev.txt"], + extras=["docs"], watch_paths={Path("foobar"), Path("pyproject.toml")}, ) result = load_config(td) diff --git a/thx/tests/context.py b/thx/tests/context.py index 63a89e9..999e58f 100644 --- a/thx/tests/context.py +++ b/thx/tests/context.py @@ -1,6 +1,7 @@ # Copyright 2022 Amethyst Reese # Licensed under the MIT License +import asyncio import platform import subprocess from pathlib import Path @@ -116,8 +117,8 @@ def test_find_runtime_no_venv_binary_found( tdp = Path(td).resolve() config = Config(root=tdp) - which_mock.side_effect = ( - lambda b: f"/fake/bin/{b}" if "." not in b else None + which_mock.side_effect = lambda b: ( + f"/fake/bin/{b}" if "." not in b else None ) for version in TEST_VERSIONS: @@ -340,6 +341,8 @@ async def test_needs_update(self) -> None: with TemporaryDirectory() as td: tdp = Path(td).resolve() + pyproj = tdp / "pyproject.toml" + pyproj.write_text("\n") reqs = tdp / "requirements.txt" reqs.write_text("\n") @@ -355,6 +358,62 @@ async def test_needs_update(self) -> None: (venv / context.TIMESTAMP).write_text("0\n") self.assertFalse(context.needs_update(ctx, config)) + with self.subTest("touch pyproject.toml"): + await asyncio.sleep(0.01) + pyproj.write_text("\n\n") + self.assertTrue(context.needs_update(ctx, config)) + + @patch("thx.context.check_command") + @patch("thx.context.which") + @async_test + async def test_prepare_virtualenv_extras( + self, which_mock: Mock, run_mock: Mock + ) -> None: + self.maxDiff = None + + async def fake_check_command(cmd: Sequence[StrPath]) -> CommandResult: + return CommandResult(0, "", "") + + run_mock.side_effect = fake_check_command + which_mock.side_effect = lambda b, ctx: f"{ctx.venv / 'bin'}/{b}" + + with TemporaryDirectory() as td: + tdp = Path(td).resolve() + venv = tdp / ".thx" / "venv" / "3.9" + venv.mkdir(parents=True) + + config = Config(root=tdp, extras=["more"]) + ctx = Context(Version("3.9"), venv / "bin" / "python", venv) + pip = which_mock("pip", ctx) + reqs = context.project_requirements(config) + self.assertEqual([], reqs) + + events = [event async for event in context.prepare_virtualenv(ctx, config)] + expected = [ + VenvCreate(ctx, "creating virtualenv"), + VenvCreate(ctx, "upgrading pip"), + VenvCreate(ctx, "installing project"), + VenvReady(ctx), + ] + self.assertEqual(expected, events) + + run_mock.assert_has_calls( + [ + call( + [ + ctx.python_path, + "-m", + "pip", + "install", + "-U", + "pip", + "setuptools", + ] + ), + call([pip, "install", "-U", str(config.root) + "[more]"]), + ], + ) + @patch("thx.context.check_command") @patch("thx.context.which") @async_test @@ -397,7 +456,7 @@ async def fake_check_command(cmd: Sequence[StrPath]) -> CommandResult: ] ), call([pip, "install", "-U", "-r", reqs]), - call([pip, "install", "-U", config.root]), + call([pip, "install", "-U", str(config.root)]), ] ) diff --git a/thx/types.py b/thx/types.py index 497b94a..16083bd 100644 --- a/thx/types.py +++ b/thx/types.py @@ -74,6 +74,7 @@ class Config: values: Mapping[str, str] = field(default_factory=dict) versions: Sequence[Version] = field(default_factory=list) requirements: Sequence[str] = field(default_factory=list) + extras: Sequence[str] = field(default_factory=list) watch_paths: Set[Path] = field(default_factory=set) def __post_init__(self) -> None: