diff --git a/rialto/common/utils.py b/rialto/common/utils.py
index 6f5ed1f..296cba8 100644
--- a/rialto/common/utils.py
+++ b/rialto/common/utils.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__all__ = ["load_yaml"]
+__all__ = ["load_yaml", "cast_decimals_to_floats", "get_caller_module"]
 
+import inspect
 import os
-from typing import Any
+from typing import Any, List
 
 import pyspark.sql.functions as F
 import yaml
@@ -51,3 +52,22 @@ def cast_decimals_to_floats(df: DataFrame) -> DataFrame:
         df = df.withColumn(c, F.col(c).cast(FloatType()))
 
     return df
+
+
+def get_caller_module() -> Any:
+    """
+    Get the module containing the function which is calling your function.
+
+    Inspects the call stack, where:
+    0th entry is this function
+    1st entry is the function which needs to know who called it
+    2nd entry is the calling function
+
+    Therefore, we return the module which contains the function at the 2nd place on the stack.
+
+    :return: Python module containing the calling function.
+    """
+
+    stack = inspect.stack()
+    last_stack = stack[2]
+    return inspect.getmodule(last_stack[0])
diff --git a/rialto/jobs/__init__.py b/rialto/jobs/__init__.py
index 0c3e01c..b9f443a 100644
--- a/rialto/jobs/__init__.py
+++ b/rialto/jobs/__init__.py
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from rialto.jobs.decorators import config_parser, datasource, job
+from rialto.jobs.decorators import config_parser, datasource, job, register_module
diff --git a/rialto/jobs/decorators.py b/rialto/jobs/decorators.py
index dd79bdd..617e8d2 100644
--- a/rialto/jobs/decorators.py
+++ b/rialto/jobs/decorators.py
@@ -20,8 +20,14 @@
 import importlib_metadata
 from loguru import logger
 
+from rialto.common.utils import get_caller_module
 from rialto.jobs.job_base import JobBase
-from rialto.jobs.resolver import Resolver
+from rialto.jobs.module_register import ModuleRegister
+
+
+def register_module(module):
+    caller_module = get_caller_module()
+    ModuleRegister.register_dependency(caller_module, module)
 
 
 def config_parser(cf_getter: typing.Callable) -> typing.Callable:
@@ -34,7 +40,7 @@ def config_parser(cf_getter: typing.Callable) -> typing.Callable:
     :param cf_getter: dataset reader function
     :return: raw function, unchanged
     """
-    Resolver.register_callable(cf_getter)
+    ModuleRegister.register_callable(cf_getter)
 
     return cf_getter
 
@@ -48,16 +54,10 @@ def datasource(ds_getter: typing.Callable) -> typing.Callable:
     :param ds_getter: dataset reader function
     :return: raw reader function, unchanged
     """
-    Resolver.register_callable(ds_getter)
+    ModuleRegister.register_callable(ds_getter)
 
     return ds_getter
 
 
-def _get_module(stack: typing.List) -> typing.Any:
-    last_stack = stack[1]
-    mod = inspect.getmodule(last_stack[0])
-    return mod
-
-
 def _get_version(module: typing.Any) -> str:
     try:
         package_name, _, _ = module.__name__.partition(".")
@@ -102,9 +102,7 @@ def job(*args, custom_name=None, disable_version=False):
     :return: One more job wrapper for run function (if custom name or version override specified).
     Otherwise, generates Rialto Transformation Type and returns it for in-module registration.
     """
-    stack = inspect.stack()
-
-    module = _get_module(stack)
+    module = get_caller_module()
    version = _get_version(module)
 
     # Use case where it's just raw @f. Otherwise, we get [] here.
diff --git a/rialto/jobs/job_base.py b/rialto/jobs/job_base.py
index c65341d..e548503 100644
--- a/rialto/jobs/job_base.py
+++ b/rialto/jobs/job_base.py
@@ -24,6 +24,7 @@
 from pyspark.sql import DataFrame, SparkSession
 
 from rialto.common import TableReader
+from rialto.jobs.module_register import ModuleRegister
 from rialto.jobs.resolver import Resolver
 from rialto.loader import PysparkFeatureLoader
 from rialto.metadata import MetadataManager
@@ -50,54 +51,40 @@ def get_job_name(self) -> str:
         pass
 
     @contextmanager
-    def _setup_resolver(self, run_date: datetime.date) -> None:
-        Resolver.register_callable(lambda: run_date, "run_date")
-
-        Resolver.register_callable(self._get_spark, "spark")
-        Resolver.register_callable(self._get_table_reader, "table_reader")
-        Resolver.register_callable(self._get_config, "config")
-
-        if self._get_feature_loader() is not None:
-            Resolver.register_callable(self._get_feature_loader, "feature_loader")
-        if self._get_metadata_manager() is not None:
-            Resolver.register_callable(self._get_metadata_manager, "metadata_manager")
-
-        try:
-            yield
-        finally:
-            Resolver.cache_clear()
-
-    def _setup(
+    def _setup_resolver(
         self,
         spark: SparkSession,
+        run_date: datetime.date,
         table_reader: TableReader,
         config: PipelineConfig = None,
         metadata_manager: MetadataManager = None,
         feature_loader: PysparkFeatureLoader = None,
     ) -> None:
-        self._spark = spark
-        self._table_rader = table_reader
-        self._config = config
-        self._metadata = metadata_manager
-        self._feature_loader = feature_loader
+        # Static, always-available dependencies
+        Resolver.register_object(spark, "spark")
+        Resolver.register_object(run_date, "run_date")
+        Resolver.register_object(config, "config")
+        Resolver.register_object(table_reader, "table_reader")
 
-    def _get_spark(self) -> SparkSession:
-        return self._spark
+        # Datasets & Configs
+        callable_module_name = self.get_custom_callable().__module__
+        for m in ModuleRegister.get_registered_callables(callable_module_name):
+            Resolver.register_callable(m)
 
-    def _get_table_reader(self) -> TableReader:
-        return self._table_rader
+        # Optionals
+        if feature_loader is not None:
+            Resolver.register_object(feature_loader, "feature_loader")
 
-    def _get_config(self) -> PipelineConfig:
-        return self._config
+        if metadata_manager is not None:
+            Resolver.register_object(metadata_manager, "metadata_manager")
 
-    def _get_feature_loader(self) -> PysparkFeatureLoader:
-        return self._feature_loader
+        try:
+            yield
 
-    def _get_metadata_manager(self) -> MetadataManager:
-        return self._metadata
+        finally:
+            Resolver.clear()
 
-    def _get_timestamp_holder_result(self) -> DataFrame:
-        spark = self._get_spark()
+    def _get_timestamp_holder_result(self, spark) -> DataFrame:
         return spark.createDataFrame(
             [(self.get_job_name(), datetime.datetime.now())], schema="JOB_NAME string, CREATION_TIME timestamp"
         )
@@ -110,17 +97,6 @@ def _add_job_version(self, df: DataFrame) -> DataFrame:
 
         return df
 
-    def _run_main_callable(self, run_date: datetime.date) -> DataFrame:
-        with self._setup_resolver(run_date):
-            custom_callable = self.get_custom_callable()
-            raw_result = Resolver.register_resolve(custom_callable)
-
-            if raw_result is None:
-                raw_result = self._get_timestamp_holder_result()
-
-            result_with_version = self._add_job_version(raw_result)
-            return result_with_version
-
     def run(
         self,
         reader: TableReader,
@@ -140,8 +116,16 @@ def run(
         :return: dataframe
         """
         try:
-            self._setup(spark, reader, config, metadata_manager, feature_loader)
-            return self._run_main_callable(run_date)
+            with self._setup_resolver(spark, run_date, reader, config, metadata_manager, feature_loader):
+                custom_callable = self.get_custom_callable()
+                raw_result = Resolver.register_resolve(custom_callable)
+
+                if raw_result is None:
+                    raw_result = self._get_timestamp_holder_result(spark)
+
+                result_with_version = self._add_job_version(raw_result)
+                return result_with_version
+
         except Exception as e:
             logger.exception(e)
             raise e
diff --git a/rialto/jobs/module_register.py b/rialto/jobs/module_register.py
new file mode 100644
index 0000000..a7d0e87
--- /dev/null
+++ b/rialto/jobs/module_register.py
@@ -0,0 +1,49 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["ModuleRegister"]
+
+
+class ModuleRegister:
+    _storage = {}
+    _dependency_tree = {}
+
+    @classmethod
+    def register_callable(cls, callable):
+        callable_module = callable.__module__
+
+        module_callables = cls._storage.get(callable_module, [])
+        module_callables.append(callable)
+
+        cls._storage[callable_module] = module_callables
+
+    @classmethod
+    def register_dependency(cls, caller_module, module):
+        caller_module_name = caller_module.__name__
+        target_module_name = module.__name__
+
+        module_dep_tree = cls._dependency_tree.get(caller_module_name, [])
+        module_dep_tree.append(target_module_name)
+
+        cls._dependency_tree[caller_module_name] = module_dep_tree
+
+    @classmethod
+    def get_registered_callables(cls, module_name):
+        callables = cls._storage.get(module_name, [])
+
+        for included_module in cls._dependency_tree.get(module_name, []):
+            included_callables = cls.get_registered_callables(included_module)
+            callables.extend(included_callables)
+
+        return callables
diff --git a/rialto/jobs/resolver.py b/rialto/jobs/resolver.py
index 26856d1..b6e06d2 100644
--- a/rialto/jobs/resolver.py
+++ b/rialto/jobs/resolver.py
@@ -45,6 +45,18 @@ def _get_args_for_call(cls, function: typing.Callable) -> typing.Dict[str, typin
 
         return result_dict
 
+    @classmethod
+    def register_object(cls, object: typing.Any, name: str) -> None:
+        """
+        Register an object with a given name for later resolution.
+
+        :param object: object to register
+        :param name: str, custom name
+        :return: None
+        """
+
+        cls.register_callable(lambda: object, name)
+
     @classmethod
     def register_callable(cls, callable: typing.Callable, name: str = None) -> str:
         """
@@ -58,6 +70,10 @@ def register_callable(cls, callable: typing.Callable, name: str = None) -> str:
         """
         if name is None:
             name = getattr(callable, "__name__", repr(callable))
+        """
+        if name in cls._storage:
+            raise ResolverException(f"Resolver already registered {name}!")
+        """
         cls._storage[name] = callable
 
         return name
@@ -97,14 +113,11 @@ def register_resolve(cls, callable: typing.Callable) -> typing.Any:
         return cls.resolve(name)
 
     @classmethod
-    def cache_clear(cls) -> None:
+    def clear(cls) -> None:
         """
-        Clear resolver cache.
-
-        The resolve method caches its results to avoid duplication of resolutions.
-        However, in case we re-register some callables, we need to clear cache
-        in order to ensure re-execution of all resolutions.
+        Clear all registered callables and objects.
 
         :return: None
         """
         cls.resolve.cache_clear()
+        cls._storage.clear()
diff --git a/rialto/jobs/test_utils.py b/rialto/jobs/test_utils.py
index 3f6e3e2..e363e80 100644
--- a/rialto/jobs/test_utils.py
+++ b/rialto/jobs/test_utils.py
@@ -89,6 +89,9 @@ def __setitem__(self, key, value):
         def keys(self):
             return self._storage.keys()
 
+        def clear(self):
+            self._storage.clear()
+
         def __getitem__(self, func_name):
             if func_name in self._call_stack:
                 raise ResolverException(f"Circular Dependence on {func_name}!")
@@ -102,6 +105,12 @@ def __getitem__(self, func_name):
             return fake_method
 
     with patch("rialto.jobs.resolver.Resolver._storage", SmartStorage()):
-        job().run(reader=MagicMock(), run_date=MagicMock(), spark=spark)
+        job().run(
+            reader=MagicMock(),
+            run_date=MagicMock(),
+            spark=spark,
+            metadata_manager=MagicMock(),
+            feature_loader=MagicMock(),
+        )
 
     return True
diff --git a/tests/jobs/resolver_dep_checks_job/cross_dep_tests_job_a.py b/tests/jobs/resolver_dep_checks_job/cross_dep_tests_job_a.py
new file mode 100644
index 0000000..a073c7c
--- /dev/null
+++ b/tests/jobs/resolver_dep_checks_job/cross_dep_tests_job_a.py
@@ -0,0 +1,9 @@
+import tests.jobs.resolver_dep_checks_job.datasources as ds
+from rialto.jobs import job, register_module
+
+register_module(ds)
+
+
+@job
+def ok_dep_job(datasource_pkg, datasource_base):
+    pass
diff --git a/tests/jobs/resolver_dep_checks_job/cross_dep_tests_job_b.py b/tests/jobs/resolver_dep_checks_job/cross_dep_tests_job_b.py
new file mode 100644
index 0000000..8248272
--- /dev/null
+++ b/tests/jobs/resolver_dep_checks_job/cross_dep_tests_job_b.py
@@ -0,0 +1,6 @@
+from rialto.jobs import job
+
+
+@job
+def missing_dep_job(datasource_pkg, datasource_base):
+    pass
diff --git a/tests/jobs/test_job/dependency_tests_job.py b/tests/jobs/resolver_dep_checks_job/datasources.py
similarity index 52%
rename from tests/jobs/test_job/dependency_tests_job.py
rename to tests/jobs/resolver_dep_checks_job/datasources.py
index 7452d02..d123b5e 100644
--- a/tests/jobs/test_job/dependency_tests_job.py
+++ b/tests/jobs/resolver_dep_checks_job/datasources.py
@@ -1,4 +1,7 @@
-from rialto.jobs import datasource, job
+import tests.jobs.resolver_dep_checks_job.dep_package.pkg_datasources as pkg_ds
+from rialto.jobs import datasource, register_module
+
+register_module(pkg_ds)
 
 
 @datasource
@@ -16,11 +19,6 @@ def c(a, b):
     return a + b
 
 
-@job
-def ok_dependency_job(c):
-    return c + 1
-
-
 @datasource
 def d(a, circle_1):
     return circle_1 + a
@@ -36,16 +34,6 @@ def circle_2(circle_1):
     return circle_1 + 1
 
 
-@job
-def circular_dependency_job(d):
-    return d + 1
-
-
-@job
-def missing_dependency_job(a, x):
-    return x + a
-
-
-@job
-def default_dependency_job(run_date, spark, config):
-    return 1
+@datasource
+def datasource_base():
+    return "dataset_base_return"
diff --git a/tests/jobs/resolver_dep_checks_job/dep_package/pkg_datasources.py b/tests/jobs/resolver_dep_checks_job/dep_package/pkg_datasources.py
new file mode 100644
index 0000000..2325265
--- /dev/null
+++ b/tests/jobs/resolver_dep_checks_job/dep_package/pkg_datasources.py
@@ -0,0 +1,6 @@
+from rialto.jobs import datasource
+
+
+@datasource
+def datasource_pkg():
+    return "datasource_pkg_return"
diff --git a/tests/jobs/resolver_dep_checks_job/dependency_tests_job.py b/tests/jobs/resolver_dep_checks_job/dependency_tests_job.py
new file mode 100644
index 0000000..dffff6e
--- /dev/null
+++ b/tests/jobs/resolver_dep_checks_job/dependency_tests_job.py
@@ -0,0 +1,24 @@
+import tests.jobs.resolver_dep_checks_job.datasources as ds
+from rialto.jobs import job, register_module
+
+register_module(ds)
+
+
+@job
+def ok_dependency_job(c):
+    return c + 1
+
+
+@job
+def circular_dependency_job(d):
+    return d + 1
+
+
+@job
+def missing_dependency_job(a, x):
+    return x + a
+
+
+@job
+def default_dependency_job(run_date, spark, config, table_reader, feature_loader):
+    return 1
diff --git a/tests/jobs/resources.py b/tests/jobs/resources.py
index 273bf38..ddb8bf8 100644
--- a/tests/jobs/resources.py
+++ b/tests/jobs/resources.py
@@ -16,10 +16,18 @@
 import pandas as pd
 
 from rialto.jobs.job_base import JobBase
+from rialto.jobs.resolver import Resolver
 
 
 def custom_callable():
-    pass
+    return None
+
+
+def asserting_callable():
+    assert Resolver.resolve("run_date")
+    assert Resolver.resolve("config")
+    assert Resolver.resolve("spark")
+    assert Resolver.resolve("table_reader")
 
 
 class CustomJobNoReturnVal(JobBase):
@@ -46,3 +54,8 @@ def f(spark):
 class CustomJobNoVersion(CustomJobNoReturnVal):
     def get_job_version(self) -> str:
         return None
+
+
+class CustomJobAssertResolverSetup(CustomJobNoReturnVal):
+    def get_custom_callable(self):
+        return asserting_callable
diff --git a/tests/jobs/test_decorators.py b/tests/jobs/test_decorators.py
index a09ee69..a1945e4 100644
--- a/tests/jobs/test_decorators.py
+++ b/tests/jobs/test_decorators.py
@@ -15,21 +15,25 @@
 from importlib import import_module
 
 from rialto.jobs.job_base import JobBase
-from rialto.jobs.resolver import Resolver
+from rialto.jobs.module_register import ModuleRegister
 
 
 def test_dataset_decorator():
     _ = import_module("tests.jobs.test_job.test_job")
-    test_dataset = Resolver.resolve("dataset")
-    assert test_dataset == "dataset_return"
+    callables = ModuleRegister.get_registered_callables("tests.jobs.test_job.test_job")
+    callable_names = [f.__name__ for f in callables]
+
+    assert "dataset" in callable_names
 
 
 def test_config_decorator():
     _ = import_module("tests.jobs.test_job.test_job")
-    test_dataset = Resolver.resolve("custom_config")
-    assert test_dataset == "config_return"
+    callables = ModuleRegister.get_registered_callables("tests.jobs.test_job.test_job")
+    callable_names = [f.__name__ for f in callables]
+
+    assert "custom_config" in callable_names
 
 
 def _rialto_import_stub(module_name, class_name):
diff --git a/tests/jobs/test_job_base.py b/tests/jobs/test_job_base.py
index 2fb01ea..1514957 100644
--- a/tests/jobs/test_job_base.py
+++ b/tests/jobs/test_job_base.py
@@ -23,30 +23,13 @@
 from rialto.loader import PysparkFeatureLoader
 
 
-def test_setup_except_feature_loader(spark):
+def test_setup(spark):
     table_reader = MagicMock()
     config = MagicMock()
     date = datetime.date(2023, 1, 1)
 
     resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, config=config)
 
-    assert Resolver.resolve("run_date") == date
-    assert Resolver.resolve("config") == config
-    assert Resolver.resolve("spark") == spark
-    assert Resolver.resolve("table_reader") == table_reader
-
-
-def test_setup_feature_loader(spark):
-    table_reader = MagicMock()
-    date = datetime.date(2023, 1, 1)
-    feature_loader = PysparkFeatureLoader(spark, "", "", "")
-
-    resources.CustomJobNoReturnVal().run(
-        reader=table_reader, run_date=date, spark=spark, config=None, feature_loader=feature_loader
-    )
-
-    assert type(Resolver.resolve("feature_loader")) == PysparkFeatureLoader
-
 
 def test_custom_callable_called(spark, mocker):
     spy_cc = mocker.spy(resources, "custom_callable")
diff --git a/tests/jobs/test_test_utils.py b/tests/jobs/test_test_utils.py
index e6ef9da..373f5fa 100644
--- a/tests/jobs/test_test_utils.py
+++ b/tests/jobs/test_test_utils.py
@@ -14,7 +14,9 @@
 import pytest
 
 import rialto.jobs.decorators as decorators
-import tests.jobs.test_job.dependency_tests_job as dependency_tests_job
+import tests.jobs.resolver_dep_checks_job.cross_dep_tests_job_a as cross_dep_tests_job_a
+import tests.jobs.resolver_dep_checks_job.cross_dep_tests_job_b as cross_dep_tests_job_b
+import tests.jobs.resolver_dep_checks_job.dependency_tests_job as dependency_tests_job
 import tests.jobs.test_job.test_job as test_job
 from rialto.jobs.resolver import Resolver
 from rialto.jobs.test_utils import disable_job_decorators, resolver_resolves
@@ -71,3 +73,23 @@ def test_resolver_resolves_fails_missing_dependency(spark):
 
     assert exc_info is not None
     assert str(exc_info.value) == "x declaration not found!"
+
+
+def test_resolver_dep_separation_correct_load_existing(spark):
+    assert resolver_resolves(spark, cross_dep_tests_job_a.ok_dep_job)
+
+
+def test_resolver_dep_separation_fail_load_missing(spark):
+    with pytest.raises(Exception) as exc_info:
+        assert resolver_resolves(spark, cross_dep_tests_job_b.missing_dep_job)
+    assert exc_info is not None
+
+
+def test_resolver_wont_cross_pollinate(spark):
+    # This job has imported the dependencies
+    assert resolver_resolves(spark, cross_dep_tests_job_a.ok_dep_job)
+
+    # This job has no imported dependencies
+    with pytest.raises(Exception) as exc_info:
+        assert resolver_resolves(spark, cross_dep_tests_job_b.missing_dep_job)
+    assert exc_info is not None