Module Register
vvancak committed Sep 11, 2024
1 parent 0defd47 commit f9e5e72
Showing 16 changed files with 242 additions and 114 deletions.
24 changes: 22 additions & 2 deletions rialto/common/utils.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__all__ = ["load_yaml"]
+__all__ = ["load_yaml", "cast_decimals_to_floats", "get_caller_module"]
 
+import inspect
 import os
-from typing import Any
+from typing import Any, List
 
 import pyspark.sql.functions as F
 import yaml
@@ -51,3 +52,22 @@ def cast_decimals_to_floats(df: DataFrame) -> DataFrame:
         df = df.withColumn(c, F.col(c).cast(FloatType()))
 
     return df
+
+
+def get_caller_module() -> Any:
+    """
+    Get the module containing the function which is calling your function.
+
+    Inspects the call stack, where:
+    0th entry is this function,
+    1st entry is the function which needs to know who called it,
+    2nd entry is the calling function.
+
+    Therefore, we return the module which contains the function at the 2nd place on the stack.
+
+    :return: Python module containing the calling function.
+    """
+
+    stack = inspect.stack()
+    last_stack = stack[2]
+    return inspect.getmodule(last_stack[0])
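
A minimal sketch of the stack-depth convention in action (the two modules, my_library.py and user_code.py, are hypothetical):

# my_library.py (a hypothetical library module)
from rialto.common.utils import get_caller_module

def who_is_calling():
    # Inside get_caller_module, stack[0] is get_caller_module itself,
    # stack[1] is who_is_calling, and stack[2] is our caller.
    return get_caller_module().__name__

# user_code.py (a hypothetical caller)
import my_library

print(my_library.who_is_calling())  # "user_code" when imported as a module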
2 changes: 1 addition & 1 deletion rialto/jobs/__init__.py
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from rialto.jobs.decorators import config_parser, datasource, job
+from rialto.jobs.decorators import config_parser, datasource, job, register_module
22 changes: 10 additions & 12 deletions rialto/jobs/decorators.py
@@ -20,8 +20,14 @@
 import importlib_metadata
 from loguru import logger
 
+from rialto.common.utils import get_caller_module
 from rialto.jobs.job_base import JobBase
-from rialto.jobs.resolver import Resolver
+from rialto.jobs.module_register import ModuleRegister
 
 
+def register_module(module):
+    caller_module = get_caller_module()
+    ModuleRegister.register_dependency(caller_module, module)
+
+
 def config_parser(cf_getter: typing.Callable) -> typing.Callable:
@@ -34,7 +40,7 @@ def config_parser(cf_getter: typing.Callable) -> typing.Callable:
     :param cf_getter: dataset reader function
     :return: raw function, unchanged
     """
-    Resolver.register_callable(cf_getter)
+    ModuleRegister.register_callable(cf_getter)
     return cf_getter


@@ -48,16 +54,10 @@ def datasource(ds_getter: typing.Callable) -> typing.Callable:
     :param ds_getter: dataset reader function
     :return: raw reader function, unchanged
     """
-    Resolver.register_callable(ds_getter)
+    ModuleRegister.register_callable(ds_getter)
     return ds_getter
 
 
-def _get_module(stack: typing.List) -> typing.Any:
-    last_stack = stack[1]
-    mod = inspect.getmodule(last_stack[0])
-    return mod
-
-
 def _get_version(module: typing.Any) -> str:
     try:
         package_name, _, _ = module.__name__.partition(".")
@@ -102,9 +102,7 @@ def job(*args, custom_name=None, disable_version=False):
     :return: One more job wrapper for run function (if custom name or version override specified).
        Otherwise, generates Rialto Transformation Type and returns it for in-module registration.
     """
-    stack = inspect.stack()
-
-    module = _get_module(stack)
+    module = get_caller_module()
     version = _get_version(module)
 
     # Use case where it's just raw @f. Otherwise, we get [] here.
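
Taken together, the decorators now let a job consume datasources defined in another module. A hedged sketch of the intended usage (module and table names are illustrative; the pattern mirrors the test fixtures added below):

# datasources.py (hypothetical shared module)
from rialto.jobs import datasource

@datasource
def customers(spark):
    return spark.read.table("raw.customers")  # illustrative table name

# my_job.py (job module that consumes the shared datasource)
import datasources
from rialto.jobs import job, register_module

register_module(datasources)  # make datasources.py's callables resolvable here

@job
def my_job(customers):
    return customers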
80 changes: 32 additions & 48 deletions rialto/jobs/job_base.py
@@ -24,6 +24,7 @@
 from pyspark.sql import DataFrame, SparkSession
 
 from rialto.common import TableReader
+from rialto.jobs.module_register import ModuleRegister
 from rialto.jobs.resolver import Resolver
 from rialto.loader import PysparkFeatureLoader
 from rialto.metadata import MetadataManager
@@ -50,54 +51,40 @@ def get_job_name(self) -> str:
         pass
 
     @contextmanager
-    def _setup_resolver(self, run_date: datetime.date) -> None:
-        Resolver.register_callable(lambda: run_date, "run_date")
-
-        Resolver.register_callable(self._get_spark, "spark")
-        Resolver.register_callable(self._get_table_reader, "table_reader")
-        Resolver.register_callable(self._get_config, "config")
-
-        if self._get_feature_loader() is not None:
-            Resolver.register_callable(self._get_feature_loader, "feature_loader")
-        if self._get_metadata_manager() is not None:
-            Resolver.register_callable(self._get_metadata_manager, "metadata_manager")
-
-        try:
-            yield
-        finally:
-            Resolver.cache_clear()
-
-    def _setup(
+    def _setup_resolver(
         self,
         spark: SparkSession,
+        run_date: datetime.date,
         table_reader: TableReader,
         config: PipelineConfig = None,
         metadata_manager: MetadataManager = None,
         feature_loader: PysparkFeatureLoader = None,
     ) -> None:
-        self._spark = spark
-        self._table_rader = table_reader
-        self._config = config
-        self._metadata = metadata_manager
-        self._feature_loader = feature_loader
+        # Static, always-available dependencies
+        Resolver.register_object(spark, "spark")
+        Resolver.register_object(run_date, "run_date")
+        Resolver.register_object(config, "config")
+        Resolver.register_object(table_reader, "table_reader")
 
-    def _get_spark(self) -> SparkSession:
-        return self._spark
+        # Datasets & Configs
+        callable_module_name = self.get_custom_callable().__module__
+        for m in ModuleRegister.get_registered_callables(callable_module_name):
+            Resolver.register_callable(m)
 
-    def _get_table_reader(self) -> TableReader:
-        return self._table_rader
+        # Optionals
+        if feature_loader is not None:
+            Resolver.register_object(feature_loader, "feature_loader")
 
-    def _get_config(self) -> PipelineConfig:
-        return self._config
+        if metadata_manager is not None:
+            Resolver.register_object(metadata_manager, "metadata_manager")
 
-    def _get_feature_loader(self) -> PysparkFeatureLoader:
-        return self._feature_loader
+        try:
+            yield
 
-    def _get_metadata_manager(self) -> MetadataManager:
-        return self._metadata
+        finally:
+            Resolver.clear()
 
-    def _get_timestamp_holder_result(self) -> DataFrame:
-        spark = self._get_spark()
+    def _get_timestamp_holder_result(self, spark) -> DataFrame:
         return spark.createDataFrame(
             [(self.get_job_name(), datetime.datetime.now())], schema="JOB_NAME string, CREATION_TIME timestamp"
         )
@@ -110,17 +97,6 @@ def _add_job_version(self, df: DataFrame) -> DataFrame:
 
         return df
 
-    def _run_main_callable(self, run_date: datetime.date) -> DataFrame:
-        with self._setup_resolver(run_date):
-            custom_callable = self.get_custom_callable()
-            raw_result = Resolver.register_resolve(custom_callable)
-
-            if raw_result is None:
-                raw_result = self._get_timestamp_holder_result()
-
-            result_with_version = self._add_job_version(raw_result)
-            return result_with_version
-
     def run(
         self,
         reader: TableReader,
@@ -140,8 +116,16 @@ def run(
         :return: dataframe
         """
         try:
-            self._setup(spark, reader, config, metadata_manager, feature_loader)
-            return self._run_main_callable(run_date)
+            with self._setup_resolver(spark, run_date, reader, config, metadata_manager, feature_loader):
+                custom_callable = self.get_custom_callable()
+                raw_result = Resolver.register_resolve(custom_callable)
+
+                if raw_result is None:
+                    raw_result = self._get_timestamp_holder_result(spark)
+
+                result_with_version = self._add_job_version(raw_result)
+                return result_with_version
+
         except Exception as e:
             logger.exception(e)
             raise e
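
The net effect of the refactor is that per-run state now lives in the Resolver instead of on the job instance. Roughly, each run does the following (a simplified sketch of the flow above, not the actual class code; job_function stands in for self.get_custom_callable()):

# Always-available objects are registered up front.
Resolver.register_object(spark, "spark")
Resolver.register_object(run_date, "run_date")
Resolver.register_object(config, "config")
Resolver.register_object(reader, "table_reader")

# Datasources and config parsers registered for the job's module are pulled in,
# including those reachable through register_module() links.
for c in ModuleRegister.get_registered_callables(job_function.__module__):
    Resolver.register_callable(c)

try:
    result = Resolver.register_resolve(job_function)  # arguments injected by name
finally:
    Resolver.clear()  # nothing leaks into the next run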
49 changes: 49 additions & 0 deletions rialto/jobs/module_register.py
@@ -0,0 +1,49 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["ModuleRegister"]
+
+
+class ModuleRegister:
+    _storage = {}
+    _dependency_tree = {}
+
+    @classmethod
+    def register_callable(cls, callable):
+        callable_module = callable.__module__
+
+        module_callables = cls._storage.get(callable_module, [])
+        module_callables.append(callable)
+
+        cls._storage[callable_module] = module_callables
+
+    @classmethod
+    def register_dependency(cls, caller_module, module):
+        caller_module_name = caller_module.__name__
+        target_module_name = module.__name__
+
+        module_dep_tree = cls._dependency_tree.get(caller_module_name, [])
+        module_dep_tree.append(target_module_name)
+
+        cls._dependency_tree[caller_module_name] = module_dep_tree
+
+    @classmethod
+    def get_registered_callables(cls, module_name):
+        callables = cls._storage.get(module_name, [])
+
+        for included_module in cls._dependency_tree.get(module_name, []):
+            included_callables = cls.get_registered_callables(included_module)
+            callables.extend(included_callables)
+
+        return callables
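
Note that get_registered_callables recurses through the dependency tree, so register_module() links are honored transitively. A small, self-contained sketch of the bookkeeping (the module objects are synthesized purely for illustration):

import types

from rialto.jobs.module_register import ModuleRegister

jobs_mod = types.ModuleType("tests.jobs_mod")  # hypothetical job module
ds_mod = types.ModuleType("tests.ds_mod")      # hypothetical datasource module

def my_datasource():
    return "data"

my_datasource.__module__ = "tests.ds_mod"      # pretend it was defined in ds_mod

ModuleRegister.register_callable(my_datasource)       # stored under "tests.ds_mod"
ModuleRegister.register_dependency(jobs_mod, ds_mod)  # jobs_mod depends on ds_mod

# jobs_mod's callables include those of every module it registered.
assert ModuleRegister.get_registered_callables("tests.jobs_mod") == [my_datasource]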
25 changes: 19 additions & 6 deletions rialto/jobs/resolver.py
@@ -45,6 +45,18 @@ def _get_args_for_call(cls, function: typing.Callable) -> typing.Dict[str, typing.Any]:
 
         return result_dict
 
+    @classmethod
+    def register_object(cls, object: typing.Any, name: str) -> None:
+        """
+        Register an object under a given name for later resolution.
+
+        :param object: object to register
+        :param name: str, custom name
+        :return: None
+        """
+
+        cls.register_callable(lambda: object, name)
+
     @classmethod
     def register_callable(cls, callable: typing.Callable, name: str = None) -> str:
         """
@@ -58,6 +70,10 @@ def register_callable(cls, callable: typing.Callable, name: str = None) -> str:
         """
         if name is None:
             name = getattr(callable, "__name__", repr(callable))
+        """
+        if name in cls._storage:
+            raise ResolverException(f"Resolver already registered {name}!")
+        """
 
         cls._storage[name] = callable
         return name
@@ -97,14 +113,11 @@ def register_resolve(cls, callable: typing.Callable) -> typing.Any:
         return cls.resolve(name)
 
     @classmethod
-    def cache_clear(cls) -> None:
+    def clear(cls) -> None:
        """
-        Clear resolver cache.
-
-        The resolve method caches its results to avoid duplication of resolutions.
-        However, in case we re-register some callables, we need to clear cache
-        in order to ensure re-execution of all resolutions.
+        Clear all registered datasources and jobs.
 
         :return: None
         """
-        cls.resolve.cache_clear()
+
+        cls._storage.clear()
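
register_object is a thin convenience over register_callable: the object is wrapped in a lambda, so every dependency is resolved through the same callable path. A quick sketch, assuming Resolver.resolve(name) as used elsewhere in this file:

from rialto.jobs.resolver import Resolver

Resolver.register_object(42, "answer")              # stored as lambda: 42
Resolver.register_callable(lambda: "ok", "status")  # explicit name, same storage

assert Resolver.resolve("answer") == 42
assert Resolver.resolve("status") == "ok"

Resolver.clear()  # wipe _storage so nothing carries over between job runs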
11 changes: 10 additions & 1 deletion rialto/jobs/test_utils.py
@@ -89,6 +89,9 @@ def __setitem__(self, key, value):
         def keys(self):
             return self._storage.keys()
 
+        def clear(self):
+            self._storage.clear()
+
         def __getitem__(self, func_name):
             if func_name in self._call_stack:
                 raise ResolverException(f"Circular Dependence on {func_name}!")
@@ -102,6 +105,12 @@ def __getitem__(self, func_name):
             return fake_method
 
     with patch("rialto.jobs.resolver.Resolver._storage", SmartStorage()):
-        job().run(reader=MagicMock(), run_date=MagicMock(), spark=spark)
+        job().run(
+            reader=MagicMock(),
+            run_date=MagicMock(),
+            spark=spark,
+            metadata_manager=MagicMock(),
+            feature_loader=MagicMock(),
+        )
 
     return True
9 changes: 9 additions & 0 deletions tests/jobs/resolver_dep_checks_job/cross_dep_tests_job_a.py
@@ -0,0 +1,9 @@
+import tests.jobs.resolver_dep_checks_job.datasources as ds
+from rialto.jobs import job, register_module
+
+register_module(ds)
+
+
+@job
+def ok_dep_job(datasource_pkg, datasource_base):
+    pass
6 changes: 6 additions & 0 deletions tests/jobs/resolver_dep_checks_job/cross_dep_tests_job_b.py
@@ -0,0 +1,6 @@
+from rialto.jobs import job
+
+
+@job
+def missing_dep_job(datasource_pkg, datasource_base):
+    pass
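
Together, these two fixtures pin down the new behavior: cross_dep_tests_job_a calls register_module on the datasources module, so datasource_pkg and datasource_base resolve, while cross_dep_tests_job_b omits that call, so the same dependencies are expected to fail resolution.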
@@ -1,4 +1,7 @@
-from rialto.jobs import datasource, job
+import tests.jobs.resolver_dep_checks_job.dep_package.pkg_datasources as pkg_ds
+from rialto.jobs import datasource, register_module
+
+register_module(pkg_ds)
 
 
 @datasource
@@ -16,11 +19,6 @@ def c(a, b):
     return a + b
 
 
-@job
-def ok_dependency_job(c):
-    return c + 1
-
-
 @datasource
 def d(a, circle_1):
     return circle_1 + a
@@ -36,16 +34,6 @@ def circle_2(circle_1):
     return circle_1 + 1
 
 
-@job
-def circular_dependency_job(d):
-    return d + 1
-
-
-@job
-def missing_dependency_job(a, x):
-    return x + a
-
-
-@job
-def default_dependency_job(run_date, spark, config):
-    return 1
+@datasource
+def datasource_base():
+    return "dataset_base_return"