diff --git a/rialto/jobs/__init__.py b/rialto/jobs/__init__.py
index b9f443a..46eb756 100644
--- a/rialto/jobs/__init__.py
+++ b/rialto/jobs/__init__.py
@@ -12,4 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from rialto.jobs.decorators import config_parser, datasource, job, register_module
+from rialto.jobs.decorators import config_parser, datasource, job
+from rialto.jobs.module_register import (
+    register_dependency_callable,
+    register_dependency_module,
+)
diff --git a/rialto/jobs/decorators.py b/rialto/jobs/decorators.py
index 617e8d2..68bb58f 100644
--- a/rialto/jobs/decorators.py
+++ b/rialto/jobs/decorators.py
@@ -14,7 +14,6 @@
 
 __all__ = ["datasource", "job", "config_parser"]
 
-import inspect
 import typing
 
 import importlib_metadata
@@ -25,11 +24,6 @@
 from rialto.jobs.module_register import ModuleRegister
 
 
-def register_module(module):
-    caller_module = get_caller_module()
-    ModuleRegister.register_dependency(caller_module, module)
-
-
 def config_parser(cf_getter: typing.Callable) -> typing.Callable:
     """
     Config parser functions decorator.
diff --git a/rialto/jobs/job_base.py b/rialto/jobs/job_base.py
index e548503..0a4779d 100644
--- a/rialto/jobs/job_base.py
+++ b/rialto/jobs/job_base.py
@@ -17,14 +17,12 @@
 import abc
 import datetime
 import typing
-from contextlib import contextmanager
 
 import pyspark.sql.functions as F
 from loguru import logger
 from pyspark.sql import DataFrame, SparkSession
 
 from rialto.common import TableReader
-from rialto.jobs.module_register import ModuleRegister
 from rialto.jobs.resolver import Resolver
 from rialto.loader import PysparkFeatureLoader
 from rialto.metadata import MetadataManager
@@ -50,8 +48,7 @@ def get_job_name(self) -> str:
         """Job name getter"""
         pass
 
-    @contextmanager
-    def _setup_resolver(
+    def _get_resolver(
         self,
         spark: SparkSession,
         run_date: datetime.date,
@@ -59,30 +56,23 @@
         config: PipelineConfig = None,
         metadata_manager: MetadataManager = None,
         feature_loader: PysparkFeatureLoader = None,
-    ) -> None:
-        # Static Always - Available dependencies
-        Resolver.register_object(spark, "spark")
-        Resolver.register_object(run_date, "run_date")
-        Resolver.register_object(config, "config")
-        Resolver.register_object(table_reader, "table_reader")
+    ) -> Resolver:
+        resolver = Resolver()
 
-        # Datasets & Configs
-        callable_module_name = self.get_custom_callable().__module__
-        for m in ModuleRegister.get_registered_callables(callable_module_name):
-            Resolver.register_callable(m)
+        # Static Always - Available dependencies
+        resolver.register_object(spark, "spark")
+        resolver.register_object(run_date, "run_date")
+        resolver.register_object(config, "config")
+        resolver.register_object(table_reader, "table_reader")
 
         # Optionals
         if feature_loader is not None:
-            Resolver.register_object(feature_loader, "feature_loader")
+            resolver.register_object(feature_loader, "feature_loader")
 
         if metadata_manager is not None:
-            Resolver.register_object(metadata_manager, "metadata_manager")
+            resolver.register_object(metadata_manager, "metadata_manager")
 
-        try:
-            yield
-
-        finally:
-            Resolver.clear()
+        return resolver
 
     def _get_timestamp_holder_result(self, spark) -> DataFrame:
         return spark.createDataFrame(
@@ -116,9 +106,10 @@ ...
         :return: dataframe
         """
         try:
-            with self._setup_resolver(spark, run_date, reader, config, metadata_manager, feature_loader):
-                custom_callable = self.get_custom_callable()
-                raw_result = Resolver.register_resolve(custom_callable)
+            resolver = self._get_resolver(spark, run_date, reader, config, metadata_manager, feature_loader)
+
+            custom_callable = self.get_custom_callable()
+            raw_result = resolver.resolve(custom_callable)
 
             if raw_result is None:
                 raw_result = self._get_timestamp_holder_result(spark)
diff --git a/rialto/jobs/module_register.py b/rialto/jobs/module_register.py
index a7d0e87..afeb2ed 100644
--- a/rialto/jobs/module_register.py
+++ b/rialto/jobs/module_register.py
@@ -12,38 +12,91 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__all__ = ["ModuleRegister"]
+__all__ = ["ModuleRegister", "register_dependency_module", "register_dependency_callable"]
+
+from rialto.common.utils import get_caller_module
 
 
 class ModuleRegister:
+    """
+    Module register. Class which is used by @datasource and @config_parser decorators to register callables / getters.
+    Resolver, when searching for a getter for f() defined in module M, uses find_callable("f", "M").
+    """
+
     _storage = {}
     _dependency_tree = {}
 
     @classmethod
-    def register_callable(cls, callable):
-        callable_module = callable.__module__
+    def add_callable_to_module(cls, callable, module_name):
+        """
+        Adds a callable to the specified module's storage.
 
-        module_callables = cls._storage.get(callable_module, [])
+        :param callable: The callable to be added.
+        :param module_name: The name of the module to which the callable is added.
+        """
+        module_callables = cls._storage.get(module_name, [])
         module_callables.append(callable)
-        cls._storage[callable_module] = module_callables
+        cls._storage[module_name] = module_callables
+
+    @classmethod
+    def register_callable(cls, callable):
+        """
+        Registers a callable by adding it to the module's storage.
+
+        :param callable: The callable to be registered.
+        """
+        callable_module = callable.__module__
+        cls.add_callable_to_module(callable, callable_module)
 
     @classmethod
     def register_dependency(cls, caller_module, module):
-        caller_module_name = caller_module.__name__
-        target_module_name = module.__name__
+        """
+        Registers a module as a dependency of the caller module.
 
-        module_dep_tree = cls._dependency_tree.get(caller_module_name, [])
-        module_dep_tree.append(target_module_name)
+        :param caller_module: The module that is registering the dependency.
+        :param module: The module to be registered as a dependency.
+        """
+        module_dep_tree = cls._dependency_tree.get(caller_module, [])
+        module_dep_tree.append(module)
 
-        cls._dependency_tree[caller_module_name] = module_dep_tree
+        cls._dependency_tree[caller_module] = module_dep_tree
 
     @classmethod
-    def get_registered_callables(cls, module_name):
-        callables = cls._storage.get(module_name, [])
+    def find_callable(cls, callable_name, module_name):
+        """
+        Finds a callable by its name in the specified module and its dependencies.
+
+        :param callable_name: The name of the callable to find.
+        :param module_name: The name of the module to search in.
+        :return: The found callable or None if not found.
+        """
+
+        # Loop through this module, and its dependencies
+        searched_modules = [module_name] + cls._dependency_tree.get(module_name, [])
+        for module in searched_modules:
+            # Loop through all functions registered in the module
+            for func in cls._storage.get(module, []):
+                if func.__name__ == callable_name:
+                    return func
+
+
+def register_dependency_module(module):
+    """
+    Registers a module as a dependency of the caller module.
+
+    :param module: The module to be registered as a dependency.
+ """ + caller_module = get_caller_module().__name__ + ModuleRegister.register_dependency(caller_module, module.__name__) + - for included_module in cls._dependency_tree.get(module_name, []): - included_callables = cls.get_registered_callables(included_module) - callables.extend(included_callables) +def register_dependency_callable(callable): + """ + Registers a callable as a dependency of the caller module. + Note that the function will be added to the module's list of available dependencies. - return callables + :param callable: The callable to be registered as a dependency. + """ + caller_module_name = get_caller_module().__name__ + ModuleRegister.add_callable_to_module(callable, caller_module_name) diff --git a/rialto/jobs/resolver.py b/rialto/jobs/resolver.py index b6e06d2..34b08e8 100644 --- a/rialto/jobs/resolver.py +++ b/rialto/jobs/resolver.py @@ -16,7 +16,8 @@ import inspect import typing -from functools import cache + +from rialto.jobs.module_register import ModuleRegister class ResolverException(Exception): @@ -33,20 +34,10 @@ class Resolver: Calling resolve() we attempt to resolve these dependencies. """ - _storage = {} - - @classmethod - def _get_args_for_call(cls, function: typing.Callable) -> typing.Dict[str, typing.Any]: - result_dict = {} - signature = inspect.signature(function) - - for param in signature.parameters.values(): - result_dict[param.name] = cls.resolve(param.name) - - return result_dict + def __init__(self): + self._storage = {} - @classmethod - def register_object(cls, object: typing.Any, name: str) -> None: + def register_object(self, object: typing.Any, name: str) -> None: """ Register an object with a given name for later resolution. @@ -55,10 +46,9 @@ def register_object(cls, object: typing.Any, name: str) -> None: :return: None """ - cls.register_callable(lambda: object, name) + self.register_getter(lambda: object, name) - @classmethod - def register_callable(cls, callable: typing.Callable, name: str = None) -> str: + def register_getter(self, callable: typing.Callable, name: str = None) -> str: """ Register callable with a given name for later resolution. @@ -70,54 +60,45 @@ def register_callable(cls, callable: typing.Callable, name: str = None) -> str: """ if name is None: name = getattr(callable, "__name__", repr(callable)) - """ - if name in cls._storage: + + if name in self._storage: raise ResolverException(f"Resolver already registered {name}!") - """ - cls._storage[name] = callable + self._storage[name] = callable return name - @classmethod - @cache - def resolve(cls, name: str) -> typing.Any: + def _find_getter(self, name: str, module_name) -> typing.Callable: + if name in self._storage.keys(): + return self._storage[name] + + callable_from_dependencies = ModuleRegister.find_callable(name, module_name) + if callable_from_dependencies is None: + raise ResolverException(f"{name} declaration not found!") + + return callable_from_dependencies + + def resolve(self, callable: typing.Callable) -> typing.Dict[str, typing.Any]: """ - Search for a callable registered prior and attempt to call it with correct arguents. + Take a callable and resolve its dependencies / arguments. 
Arguments can be + a) objects registered via register_object + b) callables registered via register_getter + c) ModuleRegister registered callables via ModuleRegister.register_callable (+ dependencies) Arguments are resolved recursively according to requirements; For example, if we have a(b, c), b(d), and c(), d() registered, then we recursively call resolve() methods until we resolve c, d -> b -> a - :param name: name of the callable to resolve + :param callable: function to resolve :return: result of the callable """ - if name not in cls._storage.keys(): - raise ResolverException(f"{name} declaration not found!") - - getter = cls._storage[name] - args = cls._get_args_for_call(getter) - - return getter(**args) - @classmethod - def register_resolve(cls, callable: typing.Callable) -> typing.Any: - """ - Register and Resolve a callable. + arg_list = {} - Combination of the register() and resolve() methods for a simplified execution. + signature = inspect.signature(callable) + module_name = callable.__module__ - :param callable: callable to register and immediately resolve - :return: result of the callable - """ - name = cls.register_callable(callable) - return cls.resolve(name) - - @classmethod - def clear(cls) -> None: - """ - Clear all registered datasources and jobs. + for param in signature.parameters.values(): + param_getter = self._find_getter(param.name, module_name) + arg_list[param.name] = self.resolve(param_getter) - :return: None - """ - cls.resolve.cache_clear() - cls._storage.clear() + return callable(**arg_list) diff --git a/rialto/jobs/test_utils.py b/rialto/jobs/test_utils.py index e363e80..d8f2945 100644 --- a/rialto/jobs/test_utils.py +++ b/rialto/jobs/test_utils.py @@ -77,40 +77,37 @@ def resolver_resolves(spark, job: JobBase) -> bool: :return: bool, True if job can be resolved """ - - class SmartStorage: - def __init__(self): - self._storage = Resolver._storage.copy() - self._call_stack = [] - - def __setitem__(self, key, value): - self._storage[key] = value - - def keys(self): - return self._storage.keys() - - def clear(self): - self._storage.clear() - - def __getitem__(self, func_name): - if func_name in self._call_stack: - raise ResolverException(f"Circular Dependence on {func_name}!") - - self._call_stack.append(func_name) - - real_method = self._storage[func_name] - fake_method = create_autospec(real_method) - fake_method.side_effect = lambda *args, **kwargs: self._call_stack.remove(func_name) - - return fake_method - - with patch("rialto.jobs.resolver.Resolver._storage", SmartStorage()): - job().run( - reader=MagicMock(), - run_date=MagicMock(), - spark=spark, - metadata_manager=MagicMock(), - feature_loader=MagicMock(), - ) - - return True + call_stack = [] + original_resolve_method = Resolver.resolve + + def stack_watching_resolver_resolve(self, callable): + # Check for cycles + if callable in call_stack: + raise ResolverException(f"Circular Dependence in {callable.__name__}!") + + # Append to call stack + call_stack.append(callable) + + # Create fake method + fake_method = create_autospec(callable) + fake_method.__module__ = callable.__module__ + + # Resolve fake method + result = original_resolve_method(self, fake_method) + + # Remove from call stack + call_stack.remove(callable) + + return result + + with patch(f"rialto.jobs.job_base.Resolver.resolve", stack_watching_resolver_resolve): + with patch(f"rialto.jobs.job_base.JobBase._add_job_version", lambda _, x: x): + job().run( + reader=MagicMock(), + run_date=MagicMock(), + spark=spark, + config=MagicMock(), + 
+                metadata_manager=MagicMock(),
+                feature_loader=MagicMock(),
+            )
+    return True
diff --git a/tests/jobs/dependency_checks_job/complex_dependency_job.py b/tests/jobs/dependency_checks_job/complex_dependency_job.py
new file mode 100644
index 0000000..eabb70a
--- /dev/null
+++ b/tests/jobs/dependency_checks_job/complex_dependency_job.py
@@ -0,0 +1,26 @@
+import tests.jobs.dependency_checks_job.datasources_a as a
+import tests.jobs.dependency_checks_job.datasources_b as b
+from rialto.jobs import job, register_dependency_callable, register_dependency_module
+
+# module "A" has i(), j(), k()
+# module "B" has i(j), and dependency on module C
+# module "C" has j(), k()
+
+register_dependency_module(b)
+register_dependency_callable(a.j)
+
+
+@job
+def complex_dependency_job(i, j):
+    # If we import module B, and A.j, we should not see any conflicts, because:
+    # A.i won't get imported, thus won't clash with B.i
+    # B has no j it only sees C.j as registered dependency
+
+    assert i == "B.i-C.j"
+    assert j == "A.j"
+
+
+@job
+def unimported_dependency_job(k):
+    # k is in both A and C, but it's not imported here, thus won't get resolved
+    pass
diff --git a/tests/jobs/dependency_checks_job/datasources_a.py b/tests/jobs/dependency_checks_job/datasources_a.py
new file mode 100644
index 0000000..f8ff293
--- /dev/null
+++ b/tests/jobs/dependency_checks_job/datasources_a.py
@@ -0,0 +1,16 @@
+from rialto.jobs import datasource
+
+
+@datasource
+def i():
+    return "A.i"
+
+
+@datasource
+def j():
+    return "A.j"
+
+
+@datasource
+def k():
+    return "A.k"
diff --git a/tests/jobs/dependency_checks_job/datasources_b.py b/tests/jobs/dependency_checks_job/datasources_b.py
new file mode 100644
index 0000000..fce58bc
--- /dev/null
+++ b/tests/jobs/dependency_checks_job/datasources_b.py
@@ -0,0 +1,9 @@
+import tests.jobs.dependency_checks_job.datasources_c as c
+from rialto.jobs import datasource, register_dependency_module
+
+register_dependency_module(c)
+
+
+@datasource
+def i(j):
+    return f"B.i-{j}"
diff --git a/tests/jobs/dependency_checks_job/datasources_c.py b/tests/jobs/dependency_checks_job/datasources_c.py
new file mode 100644
index 0000000..5a08eb0
--- /dev/null
+++ b/tests/jobs/dependency_checks_job/datasources_c.py
@@ -0,0 +1,11 @@
+from rialto.jobs import datasource
+
+
+@datasource
+def j():
+    return "C.j"
+
+
+@datasource
+def k():
+    return "C.k"
diff --git a/tests/jobs/dependency_checks_job/dependency_checks_job.py b/tests/jobs/dependency_checks_job/dependency_checks_job.py
new file mode 100644
index 0000000..5952705
--- /dev/null
+++ b/tests/jobs/dependency_checks_job/dependency_checks_job.py
@@ -0,0 +1,32 @@
+import tests.jobs.dependency_checks_job.main_datasources as ds
+from rialto.jobs import job, register_dependency_module
+
+register_dependency_module(ds)
+
+
+@job
+def ok_dependency_job(c):
+    return c + 1
+
+
+@job
+def circular_dependency_job(circle_third):
+    return circle_third + 1
+
+
+@job
+def missing_dependency_job(a, x):
+    return x + a
+
+
+@job
+def self_dependency_job(self_dependency):
+    return self_dependency + 1
+
+
+@job
+def default_dependency_job(run_date, spark, config, table_reader):
+    assert run_date is not None
+    assert spark is not None
+    assert config is not None
+    assert table_reader is not None
diff --git a/tests/jobs/dependency_checks_job/main_datasources.py b/tests/jobs/dependency_checks_job/main_datasources.py
new file mode 100644
index 0000000..8ac2d94
--- /dev/null
+++ b/tests/jobs/dependency_checks_job/main_datasources.py
@@ -0,0 +1,37 @@
+from rialto.jobs import datasource
+
+
+@datasource
+def a():
+    return 1
+
+
+@datasource
+def b(a):
+    return a + 10
+
+
+@datasource
+def c(a, b):
+    # 1 + 11 = 12
+    return a + b
+
+
+@datasource
+def circle_first(circle_second):
+    return circle_second + 1
+
+
+@datasource
+def circle_second(circle_third):
+    return circle_third + 1
+
+
+@datasource
+def circle_third(circle_first):
+    return circle_first + 1
+
+
+@datasource
+def self_dependency(a, b, c, self_dependency):
+    return a
diff --git a/tests/jobs/resolver_dep_checks_job/cross_dep_tests_job_a.py b/tests/jobs/resolver_dep_checks_job/cross_dep_tests_job_a.py
deleted file mode 100644
index a073c7c..0000000
--- a/tests/jobs/resolver_dep_checks_job/cross_dep_tests_job_a.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import tests.jobs.resolver_dep_checks_job.datasources as ds
-from rialto.jobs import job, register_module
-
-register_module(ds)
-
-
-@job
-def ok_dep_job(datasource_pkg, datasource_base):
-    pass
diff --git a/tests/jobs/resolver_dep_checks_job/cross_dep_tests_job_b.py b/tests/jobs/resolver_dep_checks_job/cross_dep_tests_job_b.py
deleted file mode 100644
index 8248272..0000000
--- a/tests/jobs/resolver_dep_checks_job/cross_dep_tests_job_b.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from rialto.jobs import job
-
-
-@job
-def missing_dep_job(datasource_pkg, datasource_base):
-    pass
diff --git a/tests/jobs/resolver_dep_checks_job/datasources.py b/tests/jobs/resolver_dep_checks_job/datasources.py
deleted file mode 100644
index d123b5e..0000000
--- a/tests/jobs/resolver_dep_checks_job/datasources.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import tests.jobs.resolver_dep_checks_job.dep_package.pkg_datasources as pkg_ds
-from rialto.jobs import datasource, register_module
-
-register_module(pkg_ds)
-
-
-@datasource
-def a():
-    return 1
-
-
-@datasource
-def b(a):
-    return a + 1
-
-
-@datasource
-def c(a, b):
-    return a + b
-
-
-@datasource
-def d(a, circle_1):
-    return circle_1 + a
-
-
-@datasource
-def circle_1(circle_2):
-    return circle_2 + 1
-
-
-@datasource
-def circle_2(circle_1):
-    return circle_1 + 1
-
-
-@datasource
-def datasource_base():
-    return "dataset_base_return"
diff --git a/tests/jobs/resolver_dep_checks_job/dep_package/pkg_datasources.py b/tests/jobs/resolver_dep_checks_job/dep_package/pkg_datasources.py
deleted file mode 100644
index 2325265..0000000
--- a/tests/jobs/resolver_dep_checks_job/dep_package/pkg_datasources.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from rialto.jobs import datasource
-
-
-@datasource
-def datasource_pkg():
-    return "datasource_pkg_return"
diff --git a/tests/jobs/resolver_dep_checks_job/dependency_tests_job.py b/tests/jobs/resolver_dep_checks_job/dependency_tests_job.py
deleted file mode 100644
index dffff6e..0000000
--- a/tests/jobs/resolver_dep_checks_job/dependency_tests_job.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import tests.jobs.resolver_dep_checks_job.datasources as ds
-from rialto.jobs import job, register_module
-
-register_module(ds)
-
-
-@job
-def ok_dependency_job(c):
-    return c + 1
-
-
-@job
-def circular_dependency_job(d):
-    return d + 1
-
-
-@job
-def missing_dependency_job(a, x):
-    return x + a
-
-
-@job
-def default_dependency_job(run_date, spark, config, table_reader, feature_loader):
-    return 1
diff --git a/tests/jobs/test_decorators.py b/tests/jobs/test_decorators.py
index a1945e4..d1931c3 100644
--- a/tests/jobs/test_decorators.py
+++ b/tests/jobs/test_decorators.py
@@ -20,20 +20,12 @@
 
 def test_dataset_decorator():
     _ = import_module("tests.jobs.test_job.test_job")
-
-    callables = ModuleRegister.get_registered_callables("tests.jobs.test_job.test_job")
-    callable_names = [f.__name__ for f in callables]
-
-    assert "dataset" in callable_names
+    assert ModuleRegister.find_callable("dataset", "tests.jobs.test_job.test_job") is not None
 
 
 def test_config_decorator():
     _ = import_module("tests.jobs.test_job.test_job")
-
-    callables = ModuleRegister.get_registered_callables("tests.jobs.test_job.test_job")
-    callable_names = [f.__name__ for f in callables]
-
-    assert "custom_config" in callable_names
+    assert ModuleRegister.find_callable("custom_config", "tests.jobs.test_job.test_job") is not None
 
 
 def _rialto_import_stub(module_name, class_name):
diff --git a/tests/jobs/test_resolver.py b/tests/jobs/test_resolver.py
index c6ccdb0..443e27b 100644
--- a/tests/jobs/test_resolver.py
+++ b/tests/jobs/test_resolver.py
@@ -20,46 +20,51 @@
 def test_simple_resolve_custom_name():
     def f():
         return 7
 
-    Resolver.register_callable(f, "hello")
+    resolver = Resolver()
+    resolver.register_getter(f, "hello")
 
-    assert Resolver.resolve("hello") == 7
+    assert resolver.resolve(lambda hello: hello) == 7
 
 
 def test_simple_resolve_infer_f_name():
     def f():
-        return 7
+        return 8
 
-    Resolver.register_callable(f)
+    resolver = Resolver()
+    resolver.register_getter(f)
 
-    assert Resolver.resolve("f") == 7
+    assert resolver.resolve(lambda f: f) == 8
 
 
-def test_dependency_resolve():
-    def f():
-        return 7
-
-    def g(f):
-        return f + 1
+def test_resolve_non_defined():
+    resolver = Resolver()
+    with pytest.raises(ResolverException):
+        resolver.resolve(lambda x: ...)
 
-    Resolver.register_callable(f)
-    Resolver.register_callable(g)
 
-    assert Resolver.resolve("g") == 8
+def test_resolve_multi_dependency():
+    def a(b, c):
+        return b + c
+
+    def b():
+        return 1
 
-def test_resolve_non_defined():
-    with pytest.raises(ResolverException):
-        Resolver.resolve("whatever")
+    def c(d):
+        return d + 10
 
+    def d():
+        return 100
 
-def test_register_resolve(mocker):
-    def f():
-        return 7
+    resolver = Resolver()
+    resolver.register_getter(a)
+    resolver.register_getter(b)
+    resolver.register_getter(c)
+    resolver.register_getter(d)
 
-    mocker.patch("rialto.jobs.resolver.Resolver.register_callable", return_value="f")
-    mocker.patch("rialto.jobs.resolver.Resolver.resolve")
+    assert resolver.resolve(a) == 111
 
-    Resolver.register_resolve(f)
 
-    Resolver.register_callable.assert_called_once_with(f)
-    Resolver.resolve.assert_called_once_with("f")
+def test_register_objects():
+    resolver = Resolver()
+    resolver.register_object(7, "seven")
+    assert resolver.resolve(lambda seven: seven) == 7
diff --git a/tests/jobs/test_test_utils.py b/tests/jobs/test_test_utils.py
index 373f5fa..9b87b97 100644
--- a/tests/jobs/test_test_utils.py
+++ b/tests/jobs/test_test_utils.py
@@ -14,23 +14,19 @@
 import pytest
 
 import rialto.jobs.decorators as decorators
-import tests.jobs.resolver_dep_checks_job.cross_dep_tests_job_a as cross_dep_tests_job_a
-import tests.jobs.resolver_dep_checks_job.cross_dep_tests_job_b as cross_dep_tests_job_b
-import tests.jobs.resolver_dep_checks_job.dependency_tests_job as dependency_tests_job
+import tests.jobs.dependency_checks_job.complex_dependency_job as complex_dependency_job
+import tests.jobs.dependency_checks_job.dependency_checks_job as dependency_checks_job
 import tests.jobs.test_job.test_job as test_job
-from rialto.jobs.resolver import Resolver
 from rialto.jobs.test_utils import disable_job_decorators, resolver_resolves
 
 
 def test_raw_dataset_patch(mocker):
-    spy_rc = mocker.spy(Resolver, "register_callable")
     spy_dec = mocker.spy(decorators, "datasource")
 
     with disable_job_decorators(test_job):
         assert test_job.dataset() == "dataset_return"
 
-        spy_dec.assert_not_called()
-        spy_rc.assert_not_called()
+    spy_dec.assert_not_called()
 
 
 def test_job_function_patch(mocker):
@@ -39,7 +35,7 @@ ...
     with disable_job_decorators(test_job):
         assert test_job.job_function() == "job_function_return"
 
-        spy_dec.assert_not_called()
+    spy_dec.assert_not_called()
 
 
 def test_custom_name_job_function_patch(mocker):
@@ -48,48 +44,48 @@ ...
     with disable_job_decorators(test_job):
         assert test_job.custom_name_job_function() == "custom_job_name_return"
 
-        spy_dec.assert_not_called()
+    spy_dec.assert_not_called()
 
 
 def test_resolver_resolves_ok_job(spark):
-    assert resolver_resolves(spark, dependency_tests_job.ok_dependency_job)
+    assert resolver_resolves(spark, dependency_checks_job.ok_dependency_job)
 
 
 def test_resolver_resolves_default_dependency(spark):
-    assert resolver_resolves(spark, dependency_tests_job.default_dependency_job)
+    assert resolver_resolves(spark, dependency_checks_job.default_dependency_job)
 
 
-def test_resolver_resolves_fails_circular_dependency(spark):
+def test_resolver_fails_circular_dependency(spark):
     with pytest.raises(Exception) as exc_info:
-        assert resolver_resolves(spark, dependency_tests_job.circular_dependency_job)
+        assert resolver_resolves(spark, dependency_checks_job.circular_dependency_job)
 
     assert exc_info is not None
-    assert str(exc_info.value) == "Circular Dependence on circle_1!"
+    assert str(exc_info.value) == "Circular Dependence in circle_third!"
 
 
-def test_resolver_resolves_fails_missing_dependency(spark):
+def test_resolver_fails_missing_dependency(spark):
     with pytest.raises(Exception) as exc_info:
-        assert resolver_resolves(spark, dependency_tests_job.missing_dependency_job)
+        assert resolver_resolves(spark, dependency_checks_job.missing_dependency_job)
 
     assert exc_info is not None
     assert str(exc_info.value) == "x declaration not found!"
 
 
-def test_resolver_dep_separation_correct_load_existing(spark):
-    assert resolver_resolves(spark, cross_dep_tests_job_a.ok_dep_job)
-
-
-def test_resolver_dep_separation_fail_load_missing(spark):
+def tests_resolver_fails_self_dependency(spark):
     with pytest.raises(Exception) as exc_info:
-        assert resolver_resolves(spark, cross_dep_tests_job_b.missing_dep_job)
+        assert resolver_resolves(spark, dependency_checks_job.self_dependency_job)
 
+    assert exc_info is not None
+    assert str(exc_info.value) == "Circular Dependence in self_dependency!"
 
-def test_resolver_wont_cross_pollinate(spark):
-    # This job has imported the dependencies
-    assert resolver_resolves(spark, cross_dep_tests_job_a.ok_dep_job)
+
+def test_complex_dependencies_resolves_correctly(spark):
+    assert resolver_resolves(spark, complex_dependency_job.complex_dependency_job)
 
-    # This job has no imported dependencies
+
+def test_complex_dependencies_fails_on_unimported(spark):
     with pytest.raises(Exception) as exc_info:
-        assert resolver_resolves(spark, cross_dep_tests_job_b.missing_dep_job)
+        assert resolver_resolves(spark, complex_dependency_job.unimported_dependency_job)
+
+    assert exc_info is not None
+    assert str(exc_info.value) == "k declaration not found!"
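
Usage sketch (not part of the diff): how a downstream pipeline would use the new `register_dependency_module` / `register_dependency_callable` API, mirroring the layout of the `tests/jobs/dependency_checks_job` fixtures above. The `my_project.datasources` and `my_project.jobs` module names are hypothetical; only the `rialto.jobs` names come from this change.

```python
# my_project/datasources.py  (hypothetical module)
from rialto.jobs import datasource


@datasource
def accounts(run_date):
    # getters may themselves depend on built-ins such as run_date, spark, table_reader
    return f"accounts as of {run_date}"


# my_project/jobs.py  (hypothetical module)
import my_project.datasources as ds
from rialto.jobs import job, register_dependency_module

# make every @datasource defined in `ds` resolvable for jobs defined in this module
register_dependency_module(ds)


@job
def my_job(accounts):
    # "accounts" is looked up via ModuleRegister.find_callable("accounts", __name__),
    # which searches this module and its registered dependencies
    return accounts
```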
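A minimal sketch of the new instance-based `Resolver`, following `tests/jobs/test_resolver.py` above: instead of the old register-then-resolve-by-name flow, `resolve()` now inspects the callable's signature and resolves each parameter by name, recursively.

```python
from rialto.jobs.resolver import Resolver

resolver = Resolver()
resolver.register_object(7, "seven")                        # stored as a zero-arg getter
resolver.register_getter(lambda seven: seven + 1, "eight")  # getter that depends on "seven"

# resolve() walks the argument names recursively: eight -> seven
assert resolver.resolve(lambda eight: eight) == 8
```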
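Finally, a sketch of how the reworked `resolver_resolves` helper is meant to be used in a project's test suite (the `my_project.jobs` import and the `spark` fixture are assumptions, not part of the diff): it patches `Resolver.resolve` with the stack-watching wrapper shown above, so the job's dependency graph is walked with autospec'd fakes and the test fails on missing names or circular dependencies without executing the real getters.

```python
from rialto.jobs.test_utils import resolver_resolves

import my_project.jobs as jobs  # hypothetical module containing an @job function


def test_my_job_resolves(spark):  # assumes a pyspark `spark` fixture from conftest.py
    assert resolver_resolves(spark, jobs.my_job)
```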