diff --git a/docs/api_reference.md b/docs/api_reference.md index 3a2f5db8..de4f9f4a 100644 --- a/docs/api_reference.md +++ b/docs/api_reference.md @@ -274,7 +274,7 @@ CaseType = Union[Callable, Type, ModuleRef] A decorator for test functions or fixtures, to parametrize them based on test cases. It works similarly to [`@pytest.mark.parametrize`](https://docs.pytest.org/en/stable/parametrize.html): argnames represent a coma-separated string of arguments to inject in the decorated test function or fixture. The argument values (`argvalues` in [`@pytest.mark.parametrize`](https://docs.pytest.org/en/stable/parametrize.html)) are collected from the various case functions found according to `cases`, and injected as lazy values so that the case functions are called just before the test or fixture is executed. -By default (`cases=AUTO`) the list of test cases is automatically drawn from the python module file named `test__cases.py` or if not found, `cases_.py`, where `test_` is the current module name. +By default (`cases=AUTO`) the list of test cases is automatically drawn from the python module file named `test__cases.py` or if not found, `cases_.py`, where `test_` is the current module name. Also works for `tests.py` (`tests_cases.py`). Finally, the `cases` argument also accepts an explicit case function, cases-containing class, module or module name; or a list containing any mix of these elements. Note that both absolute and relative module names are supported. diff --git a/docs/changelog.md b/docs/changelog.md index 269fd663..7d58b24b 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,10 @@ # Changelog + +### 3.8.7 (in progress) + +- Allow `tests.py` to find tests from `tests_cases.py`. + ### 3.8.6 - compatibility fix - Fixed issue with legacy python 2.7 and 3.5. Fixes [#352](https://github.com/smarie/python-pytest-cases/issues/352). diff --git a/src/pytest_cases/case_parametrizer_new.py b/src/pytest_cases/case_parametrizer_new.py index 9796f062..47c0f17e 100644 --- a/src/pytest_cases/case_parametrizer_new.py +++ b/src/pytest_cases/case_parametrizer_new.py @@ -304,7 +304,7 @@ def get_all_cases(parametrization_target=None, # type: Callable # as we don't know what to look for. We complain here # rather than raising AssertionError in the call to # import_default_cases_module. See #309. - if not caller_module_name.split('.')[-1].startswith('test_'): + if not _has_test_prefix(caller_module_name.split('.')[-1]): raise ValueError( 'Cannot use `cases=AUTO` in file "%s". `cases=AUTO` is ' 'only allowed in files whose name starts with "test_" ' @@ -324,6 +324,11 @@ def get_all_cases(parametrization_target=None, # type: Callable if matches_tag_query(c, has_tag=has_tag, filter=filters)] +def _has_test_prefix(module_name): # type: (str) -> bool + prefixes = ('test_', 'tests') + return any(module_name.startswith(p) for p in prefixes) + + def get_parametrize_args(host_class_or_module, # type: Union[Type, ModuleType] cases_funs, # type: List[Callable] prefix, # type: str @@ -693,7 +698,7 @@ def import_default_cases_module(test_module_name): except ModuleNotFoundError: # Then try `cases_.py` parts = test_module_name.split('.') - assert parts[-1][0:5] == 'test_' + assert _has_test_prefix(parts[-1]) cases_module_name2 = "%s.cases_%s" % ('.'.join(parts[:-1]), parts[-1][5:]) try: cases_module = import_module(cases_module_name2) diff --git a/tests/cases/issues/issue_309/tests.py b/tests/cases/issues/issue_309/tests.py new file mode 100644 index 00000000..1d3e8b4b --- /dev/null +++ b/tests/cases/issues/issue_309/tests.py @@ -0,0 +1,10 @@ +from pytest_cases import get_all_cases +from pytest_cases.common_others import AUTO + + +def mock_parameterization_target(): + """A callable to use as parametrization target.""" + + +def test_get_all_cases_auto_works_in_tests_py(): + get_all_cases(mock_parameterization_target, cases=AUTO) diff --git a/tests/cases/issues/issue_309/tests_cases.py b/tests/cases/issues/issue_309/tests_cases.py new file mode 100644 index 00000000..ce5818d7 --- /dev/null +++ b/tests/cases/issues/issue_309/tests_cases.py @@ -0,0 +1,2 @@ +def case_one(): + return 1