diff --git a/rialto/jobs/decorators/decorators.py b/rialto/jobs/decorators/decorators.py index 5f2d100..b619082 100644 --- a/rialto/jobs/decorators/decorators.py +++ b/rialto/jobs/decorators/decorators.py @@ -14,10 +14,10 @@ __all__ = ["datasource", "job"] -import importlib_metadata import inspect import typing +import importlib_metadata from loguru import logger from rialto.jobs.decorators.job_base import JobBase @@ -47,7 +47,7 @@ def _get_module(stack: typing.List) -> typing.Any: def _get_version(module: typing.Any) -> str: try: package_name, _, _ = module.__name__.partition(".") - dist_name = importlib_metadata.packages_distributions()[package_name][0] + dist_name = importlib_metadata.packages_distributions()[package_name][0] return importlib_metadata.version(dist_name) except Exception: @@ -73,15 +73,19 @@ def _generate_rialto_job(callable: typing.Callable, module: object, class_name: return generated_class -def job(name_or_callable: typing.Union[str, typing.Callable]) -> typing.Union[typing.Callable, typing.Type]: +def job(*args, custom_name=None, disable_version=False): """ Rialto jobs decorator. Transforms a python function into a rialto transormation, which can be imported and ran by Rialto Runner. - Allows a custom name, via @job("custom_name_here") or can be just used as @job and the function's name is used. + Is mainly used as @job and the function's name is used, and the outputs get automatic. + To override this behavious, use @job(custom_name=XXX, disable_version=True). + - :param name_or_callable: str for custom job name. Otherwise, run function. - :return: One more job wrapper for run function (if custom name specified). + :param *args: list of positional arguments. Empty in case custom_name or disable_version is specified. + :param custom_name: str for custom job name. + :param disable_version: bool for disable autofilling the VERSION column in the job's outputs. + :return: One more job wrapper for run function (if custom name or version override specified). Otherwise, generates Rialto Transformation Type and returns it for in-module registration. """ stack = inspect.stack() @@ -89,13 +93,24 @@ def job(name_or_callable: typing.Union[str, typing.Callable]) -> typing.Union[ty module = _get_module(stack) version = _get_version(module) - if type(name_or_callable) is str: + # Use case where it's just raw @f. Otherwise we get [] here. + if len(args) == 1 and callable(args[0]): + f = args[0] + return _generate_rialto_job(callable=f, module=module, class_name=f.__name__, version=version) + + # Otherwise we need to return one more wrapper + def inner_wrapper(f): + # Setting default custom name, in case user only disables version + name = f.__name__ + + # User - Specified custom name + if custom_name is not None: + name = custom_name - def inner_wrapper(callable): - return _generate_rialto_job(callable, module, name_or_callable, version) + # Setting version to None causes JobBase to not fill it + if disable_version: + version = None - return inner_wrapper + return _generate_rialto_job(callable=f, module=module, class_name=name, version=version) - else: - name = name_or_callable.__name__ - return _generate_rialto_job(name_or_callable, module, name, version) + return inner_wrapper diff --git a/rialto/jobs/decorators/job_base.py b/rialto/jobs/decorators/job_base.py index c55e09c..9e3ecc8 100644 --- a/rialto/jobs/decorators/job_base.py +++ b/rialto/jobs/decorators/job_base.py @@ -96,7 +96,11 @@ def _get_timestamp_holder_result(self) -> DataFrame: def _add_job_version(self, df: DataFrame) -> DataFrame: version = self.get_job_version() - return df.withColumn("VERSION", F.lit(version)) + + if version is not None: + return df.withColumn("VERSION", F.lit(version)) + + return df def _run_main_callable(self, run_date: datetime.date) -> DataFrame: with self._setup_resolver(run_date): diff --git a/rialto/jobs/decorators/test_utils.py b/rialto/jobs/decorators/test_utils.py index bd21dba..d5cf810 100644 --- a/rialto/jobs/decorators/test_utils.py +++ b/rialto/jobs/decorators/test_utils.py @@ -20,12 +20,11 @@ from unittest.mock import patch -def _passthrough_decorator(x: typing.Callable) -> typing.Callable: - if type(x) is str: +def _passthrough_decorator(*args, **kwargs) -> typing.Callable: + if len(args) == 0: return _passthrough_decorator - else: - return x + return args[0] @contextmanager diff --git a/tests/jobs/resources.py b/tests/jobs/resources.py index 4d33fad..60fda7b 100644 --- a/tests/jobs/resources.py +++ b/tests/jobs/resources.py @@ -41,3 +41,8 @@ def f(spark): return spark.createDataFrame(df) return f + + +class CustomJobNoVersion(CustomJobNoReturnVal): + def get_job_version(self) -> str: + return None diff --git a/tests/jobs/test_decorators.py b/tests/jobs/test_decorators.py index e896cec..c6d05e6 100644 --- a/tests/jobs/test_decorators.py +++ b/tests/jobs/test_decorators.py @@ -57,6 +57,17 @@ def test_custom_name_function(): custom_callable = result_class.get_custom_callable() assert custom_callable() == "custom_job_name_return" + job_name = result_class.get_job_name() + assert job_name == "custom_job_name" + + +def test_job_disabling_version(): + result_class = _rialto_import_stub("tests.jobs.test_job.test_job", "disable_version_job_function") + assert issubclass(type(result_class), JobBase) + + job_version = result_class.get_job_version() + assert job_version is None + def test_job_dependencies_registered(spark): ConfigHolder.set_custom_config(value=123) diff --git a/tests/jobs/test_job/test_job.py b/tests/jobs/test_job/test_job.py index 12baec9..460490a 100644 --- a/tests/jobs/test_job/test_job.py +++ b/tests/jobs/test_job/test_job.py @@ -26,11 +26,16 @@ def job_function(): return "job_function_return" -@job("custom_job_name") +@job(custom_name="custom_job_name") def custom_name_job_function(): return "custom_job_name_return" +@job(disable_version=True) +def disable_version_job_function(): + return "disabled_version_job_return" + + @job def job_asking_for_all_deps(spark, run_date, config, dependencies, table_reader): assert spark is not None diff --git a/tests/jobs/test_job_base.py b/tests/jobs/test_job_base.py index 2cdc741..ab8284a 100644 --- a/tests/jobs/test_job_base.py +++ b/tests/jobs/test_job_base.py @@ -91,3 +91,13 @@ def test_return_dataframe_forwarded_with_version(spark): assert result.columns == ["FIRST", "SECOND", "VERSION"] assert result.first()["VERSION"] == "job_version" assert result.count() == 2 + + +def test_none_job_version_wont_fill_job_colun(spark): + table_reader = MagicMock() + date = datetime.date(2023, 1, 1) + + result = resources.CustomJobNoVersion().run(reader=table_reader, run_date=date, spark=spark, metadata_manager=None) + + assert type(result) is pyspark.sql.DataFrame + assert "VERSION" not in result.columns