Skip to content

Commit

Permalink
Switchable Version Column
Browse files Browse the repository at this point in the history
  • Loading branch information
vvancak committed Jul 23, 2024
1 parent f0ae0c7 commit 78e0bb3
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 19 deletions.
41 changes: 28 additions & 13 deletions rialto/jobs/decorators/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@

__all__ = ["datasource", "job"]

import importlib_metadata
import inspect
import typing

import importlib_metadata
from loguru import logger

from rialto.jobs.decorators.job_base import JobBase
Expand Down Expand Up @@ -47,7 +47,7 @@ def _get_module(stack: typing.List) -> typing.Any:
def _get_version(module: typing.Any) -> str:
try:
package_name, _, _ = module.__name__.partition(".")
dist_name = importlib_metadata.packages_distributions()[package_name][0]
dist_name = importlib_metadata.packages_distributions()[package_name][0]
return importlib_metadata.version(dist_name)

except Exception:
Expand All @@ -73,29 +73,44 @@ def _generate_rialto_job(callable: typing.Callable, module: object, class_name:
return generated_class


def job(name_or_callable: typing.Union[str, typing.Callable]) -> typing.Union[typing.Callable, typing.Type]:
def job(*args, custom_name=None, disable_version=False):
"""
Rialto jobs decorator.
Transforms a python function into a rialto transormation, which can be imported and ran by Rialto Runner.
Allows a custom name, via @job("custom_name_here") or can be just used as @job and the function's name is used.
Is mainly used as @job and the function's name is used, and the outputs get automatic.
To override this behavious, use @job(custom_name=XXX, disable_version=True).
:param name_or_callable: str for custom job name. Otherwise, run function.
:return: One more job wrapper for run function (if custom name specified).
:param *args: list of positional arguments. Empty in case custom_name or disable_version is specified.
:param custom_name: str for custom job name.
:param disable_version: bool for disable autofilling the VERSION column in the job's outputs.
:return: One more job wrapper for run function (if custom name or version override specified).
Otherwise, generates Rialto Transformation Type and returns it for in-module registration.
"""
stack = inspect.stack()

module = _get_module(stack)
version = _get_version(module)

if type(name_or_callable) is str:
# Use case where it's just raw @f. Otherwise we get [] here.
if len(args) == 1 and callable(args[0]):
f = args[0]
return _generate_rialto_job(callable=f, module=module, class_name=f.__name__, version=version)

# Otherwise we need to return one more wrapper
def inner_wrapper(f):
# Setting default custom name, in case user only disables version
name = f.__name__

# User - Specified custom name
if custom_name is not None:
name = custom_name

def inner_wrapper(callable):
return _generate_rialto_job(callable, module, name_or_callable, version)
# Setting version to None causes JobBase to not fill it
if disable_version:
version = None

return inner_wrapper
return _generate_rialto_job(callable=f, module=module, class_name=name, version=version)

else:
name = name_or_callable.__name__
return _generate_rialto_job(name_or_callable, module, name, version)
return inner_wrapper
6 changes: 5 additions & 1 deletion rialto/jobs/decorators/job_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,11 @@ def _get_timestamp_holder_result(self) -> DataFrame:

def _add_job_version(self, df: DataFrame) -> DataFrame:
version = self.get_job_version()
return df.withColumn("VERSION", F.lit(version))

if version is not None:
return df.withColumn("VERSION", F.lit(version))

return df

def _run_main_callable(self, run_date: datetime.date) -> DataFrame:
with self._setup_resolver(run_date):
Expand Down
7 changes: 3 additions & 4 deletions rialto/jobs/decorators/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
from unittest.mock import patch


def _passthrough_decorator(x: typing.Callable) -> typing.Callable:
if type(x) is str:
def _passthrough_decorator(*args, **kwargs) -> typing.Callable:
if len(args) == 0:
return _passthrough_decorator

else:
return x
return args[0]


@contextmanager
Expand Down
5 changes: 5 additions & 0 deletions tests/jobs/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,8 @@ def f(spark):
return spark.createDataFrame(df)

return f


class CustomJobNoVersion(CustomJobNoReturnVal):
def get_job_version(self) -> str:
return None
11 changes: 11 additions & 0 deletions tests/jobs/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,17 @@ def test_custom_name_function():
custom_callable = result_class.get_custom_callable()
assert custom_callable() == "custom_job_name_return"

job_name = result_class.get_job_name()
assert job_name == "custom_job_name"


def test_job_disabling_version():
result_class = _rialto_import_stub("tests.jobs.test_job.test_job", "disable_version_job_function")
assert issubclass(type(result_class), JobBase)

job_version = result_class.get_job_version()
assert job_version is None


def test_job_dependencies_registered(spark):
ConfigHolder.set_custom_config(value=123)
Expand Down
7 changes: 6 additions & 1 deletion tests/jobs/test_job/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,16 @@ def job_function():
return "job_function_return"


@job("custom_job_name")
@job(custom_name="custom_job_name")
def custom_name_job_function():
return "custom_job_name_return"


@job(disable_version=True)
def disable_version_job_function():
return "disabled_version_job_return"


@job
def job_asking_for_all_deps(spark, run_date, config, dependencies, table_reader):
assert spark is not None
Expand Down
10 changes: 10 additions & 0 deletions tests/jobs/test_job_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,13 @@ def test_return_dataframe_forwarded_with_version(spark):
assert result.columns == ["FIRST", "SECOND", "VERSION"]
assert result.first()["VERSION"] == "job_version"
assert result.count() == 2


def test_none_job_version_wont_fill_job_colun(spark):
table_reader = MagicMock()
date = datetime.date(2023, 1, 1)

result = resources.CustomJobNoVersion().run(reader=table_reader, run_date=date, spark=spark, metadata_manager=None)

assert type(result) is pyspark.sql.DataFrame
assert "VERSION" not in result.columns

0 comments on commit 78e0bb3

Please sign in to comment.