diff --git a/rialto/jobs/module_register.py b/rialto/jobs/module_register.py index 8283454..27a55ef 100644 --- a/rialto/jobs/module_register.py +++ b/rialto/jobs/module_register.py @@ -56,6 +56,15 @@ def register_callable(cls, callable): callable_module = callable.__module__ cls.add_callable_to_module(callable, callable_module) + @classmethod + def remove_module(cls, module): + """ + Remove a module from the storage. + + :param module: The module to be removed. + """ + cls._storage.pop(module.__name__, None) + @classmethod def register_dependency(cls, module, parent_name): """ diff --git a/rialto/jobs/test_utils.py b/rialto/jobs/test_utils.py index d8f2945..cced2fe 100644 --- a/rialto/jobs/test_utils.py +++ b/rialto/jobs/test_utils.py @@ -20,6 +20,7 @@ from unittest.mock import MagicMock, create_autospec, patch from rialto.jobs.job_base import JobBase +from rialto.jobs.module_register import ModuleRegister from rialto.jobs.resolver import Resolver, ResolverException @@ -59,15 +60,17 @@ def disable_job_decorators(module) -> None: :return: None """ with _disable_job_decorators(): + ModuleRegister.remove_module(module) importlib.reload(module) yield + ModuleRegister.remove_module(module) importlib.reload(module) def resolver_resolves(spark, job: JobBase) -> bool: """ - Checker method for your dependency resoultion. + Checker method for your dependency resolution. If your job's dependencies are all defined and resolvable, returns true. Otherwise, throws an exception. @@ -100,8 +103,8 @@ def stack_watching_resolver_resolve(self, callable): return result - with patch(f"rialto.jobs.job_base.Resolver.resolve", stack_watching_resolver_resolve): - with patch(f"rialto.jobs.job_base.JobBase._add_job_version", lambda _, x: x): + with patch("rialto.jobs.job_base.Resolver.resolve", stack_watching_resolver_resolve): + with patch("rialto.jobs.job_base.JobBase._add_job_version", lambda _, x: x): job().run( reader=MagicMock(), run_date=MagicMock(), diff --git a/tests/jobs/test_job/test_job.py b/tests/jobs/test_job/test_job.py index 4e47364..01c8df9 100644 --- a/tests/jobs/test_job/test_job.py +++ b/tests/jobs/test_job/test_job.py @@ -29,6 +29,11 @@ def job_function(): return "job_function_return" +@job +def job_with_datasource(dataset): + return dataset + + @job(custom_name="custom_job_name") def custom_name_job_function(): return "custom_job_name_return" diff --git a/tests/jobs/test_register.py b/tests/jobs/test_register.py new file mode 100644 index 0000000..1904537 --- /dev/null +++ b/tests/jobs/test_register.py @@ -0,0 +1,8 @@ +from rialto.jobs.test_utils import disable_job_decorators, resolver_resolves +from tests.jobs.test_job import test_job + + +def test_resolve_after_disable(spark): + with disable_job_decorators(test_job): + assert test_job.job_with_datasource("test") == "test" + assert resolver_resolves(spark, test_job.job_with_datasource)