Skip to content

Commit

Permalink
Resolver Resolution Test Utils
Browse files Browse the repository at this point in the history
  • Loading branch information
vvancak committed Jul 24, 2024
1 parent d2ab2bb commit 4999416
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 3 deletions.
46 changes: 45 additions & 1 deletion rialto/jobs/decorators/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import importlib
import typing
from contextlib import contextmanager
from unittest.mock import patch
from unittest.mock import patch, create_autospec, MagicMock
from rialto.jobs.decorators.resolver import Resolver, ResolverException
from rialto.jobs.decorators.job_base import JobBase


def _passthrough_decorator(*args, **kwargs) -> typing.Callable:
Expand Down Expand Up @@ -58,3 +60,45 @@ def disable_job_decorators(module) -> None:
yield

importlib.reload(module)


def resolver_resolves(spark, job: JobBase) -> bool:
"""
Checker method for your dependency resoultion.
If your job's dependencies are all defined and resolvable, returns true.
Otherwise, throws an exception.
:param spark: SparkSession object.
:param job: Job to try and resolve.
:return: bool, True if job can be resolved
"""

class SmartStorage:
def __init__(self):
self._storage = Resolver._storage.copy()
self._call_stack = []

def __setitem__(self, key, value):
self._storage[key] = value

def keys(self):
return self._storage.keys()

def __getitem__(self, func_name):
if func_name in self._call_stack:
raise ResolverException(f"Circular Dependence on {func_name}!")

self._call_stack.append(func_name)

real_method = self._storage[func_name]
fake_method = create_autospec(real_method)
fake_method.side_effect = lambda *args, **kwargs: self._call_stack.remove(func_name)

return fake_method

with patch("rialto.jobs.decorators.resolver.Resolver._storage", SmartStorage()):
job().run(MagicMock(), MagicMock(), MagicMock(), MagicMock(), spark=spark)

return True
51 changes: 51 additions & 0 deletions tests/jobs/test_job/dependency_tests_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from rialto.jobs.decorators import job, datasource


@datasource
def a():
return 1


@datasource
def b(a):
return a + 1


@datasource
def c(a, b):
return a + b


@job
def ok_dependency_job(c):
return c + 1


@datasource
def d(a, circle_1):
return circle_1 + a


@datasource
def circle_1(circle_2):
return circle_2 + 1


@datasource
def circle_2(circle_1):
return circle_1 + 1


@job
def circular_dependency_job(d):
return d + 1


@job
def missing_dependency_job(a, x):
return x + a


@job
def default_dependency_job(run_date, spark, config, dependencies, table_reader, feature_loader):
return 1
29 changes: 27 additions & 2 deletions tests/jobs/test_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import rialto.jobs.decorators as decorators
import tests.jobs.test_job.test_job as test_job
import tests.jobs.test_job.dependency_tests_job as dependency_tests_job
from rialto.jobs.decorators.resolver import Resolver
from rialto.jobs.decorators.test_utils import disable_job_decorators
from rialto.jobs.decorators.test_utils import disable_job_decorators, resolver_resolves


def test_raw_dataset_patch(mocker):
Expand Down Expand Up @@ -46,3 +47,27 @@ def test_custom_name_job_function_patch(mocker):
assert test_job.custom_name_job_function() == "custom_job_name_return"

spy_dec.assert_not_called()


def test_resolver_resolves_ok_job(spark):
assert resolver_resolves(spark, dependency_tests_job.ok_dependency_job)


def test_resolver_resolves_default_dependency(spark):
assert resolver_resolves(spark, dependency_tests_job.default_dependency_job)


def test_resolver_resolves_fails_circular_dependency(spark):
with pytest.raises(Exception) as exc_info:
assert resolver_resolves(spark, dependency_tests_job.circular_dependency_job)

assert exc_info is not None
assert str(exc_info.value) == "Circular Dependence on circle_1!"


def test_resolver_resolves_fails_missing_dependency(spark):
with pytest.raises(Exception) as exc_info:
assert resolver_resolves(spark, dependency_tests_job.missing_dependency_job)

assert exc_info is not None
assert str(exc_info.value) == "x declaration not found!"

0 comments on commit 4999416

Please sign in to comment.