Skip to content

Commit

Permalink
fix for register on reload
Browse files Browse the repository at this point in the history
  • Loading branch information
MDobransky committed Sep 27, 2024
1 parent 1020f55 commit 1acc6bd
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 3 deletions.
9 changes: 9 additions & 0 deletions rialto/jobs/module_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ def register_callable(cls, callable):
callable_module = callable.__module__
cls.add_callable_to_module(callable, callable_module)

@classmethod
def remove_module(cls, module):
"""
Remove a module from the storage.
:param module: The module to be removed.
"""
cls._storage.pop(module.__name__, None)

@classmethod
def register_dependency(cls, module, parent_name):
"""
Expand Down
9 changes: 6 additions & 3 deletions rialto/jobs/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from unittest.mock import MagicMock, create_autospec, patch

from rialto.jobs.job_base import JobBase
from rialto.jobs.module_register import ModuleRegister
from rialto.jobs.resolver import Resolver, ResolverException


Expand Down Expand Up @@ -59,15 +60,17 @@ def disable_job_decorators(module) -> None:
:return: None
"""
with _disable_job_decorators():
ModuleRegister.remove_module(module)
importlib.reload(module)
yield

ModuleRegister.remove_module(module)
importlib.reload(module)


def resolver_resolves(spark, job: JobBase) -> bool:
"""
Checker method for your dependency resoultion.
Checker method for your dependency resolution.
If your job's dependencies are all defined and resolvable, returns true.
Otherwise, throws an exception.
Expand Down Expand Up @@ -100,8 +103,8 @@ def stack_watching_resolver_resolve(self, callable):

return result

with patch(f"rialto.jobs.job_base.Resolver.resolve", stack_watching_resolver_resolve):
with patch(f"rialto.jobs.job_base.JobBase._add_job_version", lambda _, x: x):
with patch("rialto.jobs.job_base.Resolver.resolve", stack_watching_resolver_resolve):
with patch("rialto.jobs.job_base.JobBase._add_job_version", lambda _, x: x):
job().run(
reader=MagicMock(),
run_date=MagicMock(),
Expand Down
5 changes: 5 additions & 0 deletions tests/jobs/test_job/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ def job_function():
return "job_function_return"


@job
def job_with_datasource(dataset):
return dataset


@job(custom_name="custom_job_name")
def custom_name_job_function():
return "custom_job_name_return"
Expand Down
8 changes: 8 additions & 0 deletions tests/jobs/test_register.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from rialto.jobs.test_utils import disable_job_decorators, resolver_resolves
from tests.jobs.test_job import test_job


def test_resolve_after_disable(spark):
with disable_job_decorators(test_job):
assert test_job.job_with_datasource("test") == "test"
assert resolver_resolves(spark, test_job.job_with_datasource)

0 comments on commit 1acc6bd

Please sign in to comment.