Skip to content

Commit

Permalink
upgrade light-the-torch requirements (#5)
Browse files Browse the repository at this point in the history
* upgrade light-the-torch requirements

* extract force_cpu help from light-the-torch
  • Loading branch information
pmeier authored Jul 16, 2020
1 parent 993aa58 commit a943140
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 32 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ packages = find:
include_package_data = True
python_requires = >=3.6
install_requires =
light-the-torch>=0.1.1
light-the-torch>=0.2
tox

[options.packages.find]
Expand Down
81 changes: 56 additions & 25 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,35 @@
from light_the_torch.computation_backend import CPUBackend


@pytest.mark.slow
def test_help_ini(cmd):
result = cmd("--help-ini")
result.assert_success(is_run_test_env=False)
assert "disable_light_the_torch" in result.out
assert "force_cpu" in result.out
@pytest.fixture
def patch_extract_dists(mocker):
def patch_extract_dists_(return_value=None):
if return_value is None:
return_value = []
return mocker.patch(
"tox_ltt.plugin.ltt.extract_dists", return_value=return_value
)
return mocker.patch()

return patch_extract_dists_


@pytest.fixture
def patch_find_links(mocker):
def patch_find_links_(return_value=None):
if return_value is None:
return_value = []
return mocker.patch(
"tox_ltt.plugin.ltt.find_links", return_value=return_value
)
return mocker.patch()

return patch_find_links_


@pytest.fixture
def install_mock(mocker):
return mocker.patch("tox.venv.VirtualEnv.run_install_command")


def get_pyproject_toml():
Expand Down Expand Up @@ -100,14 +123,16 @@ def tox_ltt_initproj_(
return tox_ltt_initproj_


@pytest.fixture
def install_mock(mocker):
return mocker.patch("tox.venv.VirtualEnv.run_install_command")
def test_help_ini(cmd):
result = cmd("--help-ini")
result.assert_success(is_run_test_env=False)
assert "disable_light_the_torch" in result.out
assert "force_cpu" in result.out


@pytest.mark.slow
def test_tox_ltt_disabled(mocker, tox_ltt_initproj, cmd):
mock = mocker.patch("tox_ltt.plugin.ltt.resolve_dists")
def test_tox_ltt_disabled(patch_extract_dists, tox_ltt_initproj, cmd):
mock = patch_extract_dists()
tox_ltt_initproj(disable_light_the_torch=True)

result = cmd()
Expand All @@ -117,9 +142,8 @@ def test_tox_ltt_disabled(mocker, tox_ltt_initproj, cmd):


@pytest.mark.slow
def test_tox_ltt_force_cpu(mocker, tox_ltt_initproj, cmd, install_mock):
mock = mocker.patch("tox_ltt.plugin.ltt.find_links", return_value=[])

def test_tox_ltt_force_cpu(patch_find_links, tox_ltt_initproj, cmd, install_mock):
mock = patch_find_links()
tox_ltt_initproj(deps=("torch",), force_cpu=True)

result = cmd()
Expand All @@ -130,9 +154,10 @@ def test_tox_ltt_force_cpu(mocker, tox_ltt_initproj, cmd, install_mock):
assert kwargs["computation_backend"] == CPUBackend()


def test_tox_ltt_no_requirements(mocker, tox_ltt_initproj, cmd, install_mock):
mock = mocker.patch("tox_ltt.plugin.ltt.resolve_dists")

def test_tox_ltt_no_requirements(
patch_extract_dists, tox_ltt_initproj, cmd, install_mock
):
mock = patch_extract_dists()
tox_ltt_initproj(skip_install=True)

result = cmd()
Expand All @@ -142,8 +167,10 @@ def test_tox_ltt_no_requirements(mocker, tox_ltt_initproj, cmd, install_mock):


@pytest.mark.slow
def test_tox_ltt_no_pytorch_dists(mocker, tox_ltt_initproj, cmd, install_mock):
mock = mocker.patch("tox_ltt.plugin.ltt.find_links")
def test_tox_ltt_no_pytorch_dists(
patch_find_links, tox_ltt_initproj, cmd, install_mock
):
mock = patch_find_links()

deps = ("light-the-torch",)
tox_ltt_initproj(deps=deps)
Expand All @@ -155,8 +182,10 @@ def test_tox_ltt_no_pytorch_dists(mocker, tox_ltt_initproj, cmd, install_mock):


@pytest.mark.slow
def test_tox_ltt_direct_pytorch_dists(mocker, tox_ltt_initproj, cmd, install_mock):
mock = mocker.patch("tox_ltt.plugin.ltt.find_links", return_value=[])
def test_tox_ltt_direct_pytorch_dists(
patch_find_links, tox_ltt_initproj, cmd, install_mock
):
mock = patch_find_links()

deps = ("torch", "torchaudio", "torchtext", "torchvision")
dists = set(deps)
Expand All @@ -171,8 +200,10 @@ def test_tox_ltt_direct_pytorch_dists(mocker, tox_ltt_initproj, cmd, install_moc


@pytest.mark.slow
def test_tox_ltt_indirect_pytorch_dists(mocker, tox_ltt_initproj, cmd, install_mock):
mock = mocker.patch("tox_ltt.plugin.ltt.find_links", return_value=[])
def test_tox_ltt_indirect_pytorch_dists(
patch_find_links, tox_ltt_initproj, cmd, install_mock
):
mock = patch_find_links()

deps = ("git+https://github.com/pmeier/[email protected]",)
dists = {"torch>=1.5.0", "torchvision>=0.6.0"}
Expand All @@ -187,9 +218,9 @@ def test_tox_ltt_indirect_pytorch_dists(mocker, tox_ltt_initproj, cmd, install_m


def test_tox_ltt_project_pytorch_dists(
subtests, mocker, tox_ltt_initproj, cmd, install_mock
subtests, patch_find_links, tox_ltt_initproj, cmd, install_mock
):
mock = mocker.patch("tox_ltt.plugin.ltt.find_links", return_value=[])
mock = patch_find_links()

install_requires = ("torch>=1.5.0", "torchvision>=0.6.0")
dists = set(install_requires)
Expand Down
26 changes: 20 additions & 6 deletions tox_ltt/plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Optional, Sequence, cast

import tox
from tox import reporter
Expand All @@ -7,9 +7,25 @@
from tox.venv import VirtualEnv

import light_the_torch as ltt
from light_the_torch.cli import make_ltt_parser
from light_the_torch.computation_backend import CPUBackend


def extract_force_cpu_help() -> str:
def extract(seq: Sequence, attr: str, eq_cond: Any) -> Any:
reduced_seq = [item for item in seq if getattr(item, attr) == eq_cond]
assert len(reduced_seq) == 1
return reduced_seq[0]

ltt_parser = make_ltt_parser()

argument_group = extract(ltt_parser._action_groups, "title", "subcommands")
sub_parsers = extract(argument_group._actions, "dest", "subcommand")
install_parser = sub_parsers.choices["install"]
force_cpu = extract(install_parser._actions, "dest", "force_cpu")
return cast(str, force_cpu.help)


@tox.hookimpl
def tox_addoption(parser: Parser) -> None:
parser.add_testenv_attribute(
Expand All @@ -18,11 +34,9 @@ def tox_addoption(parser: Parser) -> None:
help="disable installing PyTorch distributions with light-the-torch",
default=False,
)

parser.add_testenv_attribute(
name="force_cpu",
type="bool",
help="force CPU as computation backend",
default=False,
name="force_cpu", type="bool", help=extract_force_cpu_help(), default=False,
)


Expand Down Expand Up @@ -54,7 +68,7 @@ def tox_testenv_install_deps(venv: VirtualEnv, action: Action) -> None:

action.setactivity("finddeps-light-the-torch", "")

dists = ltt.resolve_dists(requirements)
dists = ltt.extract_dists(requirements)

if not dists:
reporter.verbosity1(
Expand Down

0 comments on commit a943140

Please sign in to comment.