Skip to content

Commit

Permalink
Resolver Reworked
Browse files Browse the repository at this point in the history
  • Loading branch information
vvancak committed Sep 17, 2024
1 parent f9e5e72 commit a13553c
Show file tree
Hide file tree
Showing 20 changed files with 342 additions and 282 deletions.
6 changes: 5 additions & 1 deletion rialto/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from rialto.jobs.decorators import config_parser, datasource, job, register_module
from rialto.jobs.decorators import config_parser, datasource, job
from rialto.jobs.module_register import (
register_dependency_callable,
register_dependency_module,
)
6 changes: 0 additions & 6 deletions rialto/jobs/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

__all__ = ["datasource", "job", "config_parser"]

import inspect
import typing

import importlib_metadata
Expand All @@ -25,11 +24,6 @@
from rialto.jobs.module_register import ModuleRegister


def register_module(module):
caller_module = get_caller_module()
ModuleRegister.register_dependency(caller_module, module)


def config_parser(cf_getter: typing.Callable) -> typing.Callable:
"""
Config parser functions decorator.
Expand Down
39 changes: 15 additions & 24 deletions rialto/jobs/job_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@
import abc
import datetime
import typing
from contextlib import contextmanager

import pyspark.sql.functions as F
from loguru import logger
from pyspark.sql import DataFrame, SparkSession

from rialto.common import TableReader
from rialto.jobs.module_register import ModuleRegister
from rialto.jobs.resolver import Resolver
from rialto.loader import PysparkFeatureLoader
from rialto.metadata import MetadataManager
Expand All @@ -50,39 +48,31 @@ def get_job_name(self) -> str:
"""Job name getter"""
pass

@contextmanager
def _setup_resolver(
def _get_resolver(
self,
spark: SparkSession,
run_date: datetime.date,
table_reader: TableReader,
config: PipelineConfig = None,
metadata_manager: MetadataManager = None,
feature_loader: PysparkFeatureLoader = None,
) -> None:
# Static Always - Available dependencies
Resolver.register_object(spark, "spark")
Resolver.register_object(run_date, "run_date")
Resolver.register_object(config, "config")
Resolver.register_object(table_reader, "table_reader")
) -> Resolver:
resolver = Resolver()

# Datasets & Configs
callable_module_name = self.get_custom_callable().__module__
for m in ModuleRegister.get_registered_callables(callable_module_name):
Resolver.register_callable(m)
# Static Always - Available dependencies
resolver.register_object(spark, "spark")
resolver.register_object(run_date, "run_date")
resolver.register_object(config, "config")
resolver.register_object(table_reader, "table_reader")

# Optionals
if feature_loader is not None:
Resolver.register_object(feature_loader, "feature_loader")
resolver.register_object(feature_loader, "feature_loader")

if metadata_manager is not None:
Resolver.register_object(metadata_manager, "metadata_manager")
resolver.register_object(metadata_manager, "metadata_manager")

try:
yield

finally:
Resolver.clear()
return resolver

def _get_timestamp_holder_result(self, spark) -> DataFrame:
return spark.createDataFrame(
Expand Down Expand Up @@ -116,9 +106,10 @@ def run(
:return: dataframe
"""
try:
with self._setup_resolver(spark, run_date, reader, config, metadata_manager, feature_loader):
custom_callable = self.get_custom_callable()
raw_result = Resolver.register_resolve(custom_callable)
resolver = self._get_resolver(spark, run_date, reader, config, metadata_manager, feature_loader)

custom_callable = self.get_custom_callable()
raw_result = resolver.resolve(custom_callable)

if raw_result is None:
raw_result = self._get_timestamp_holder_result(spark)
Expand Down
85 changes: 69 additions & 16 deletions rialto/jobs/module_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,91 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["ModuleRegister"]
__all__ = ["ModuleRegister", "register_dependency_module", "register_dependency_callable"]

from rialto.common.utils import get_caller_module


class ModuleRegister:
"""
Module register. Class which is used by @datasource and @config_parser decorators to register callables / getters.
Resolver, when searching for a getter for f() defined in module M, uses find_callable("f", "M").
"""

_storage = {}
_dependency_tree = {}

@classmethod
def register_callable(cls, callable):
callable_module = callable.__module__
def add_callable_to_module(cls, callable, module_name):
"""
Adds a callable to the specified module's storage.
module_callables = cls._storage.get(callable_module, [])
:param callable: The callable to be added.
:param module_name: The name of the module to which the callable is added.
"""
module_callables = cls._storage.get(module_name, [])
module_callables.append(callable)

cls._storage[callable_module] = module_callables
cls._storage[module_name] = module_callables

@classmethod
def register_callable(cls, callable):
"""
Registers a callable by adding it to the module's storage.
:param callable: The callable to be registered.
"""
callable_module = callable.__module__
cls.add_callable_to_module(callable, callable_module)

@classmethod
def register_dependency(cls, caller_module, module):
caller_module_name = caller_module.__name__
target_module_name = module.__name__
"""
Registers a module as a dependency of the caller module.
module_dep_tree = cls._dependency_tree.get(caller_module_name, [])
module_dep_tree.append(target_module_name)
:param caller_module: The module that is registering the dependency.
:param module: The module to be registered as a dependency.
"""
module_dep_tree = cls._dependency_tree.get(caller_module, [])
module_dep_tree.append(module)

cls._dependency_tree[caller_module_name] = module_dep_tree
cls._dependency_tree[caller_module] = module_dep_tree

@classmethod
def get_registered_callables(cls, module_name):
callables = cls._storage.get(module_name, [])
def find_callable(cls, callable_name, module_name):
"""
Finds a callable by its name in the specified module and its dependencies.
:param callable_name: The name of the callable to find.
:param module_name: The name of the module to search in.
:return: The found callable or None if not found.
"""

# Loop through this module, and its dependencies
searched_modules = [module_name] + cls._dependency_tree.get(module_name, [])
for module in searched_modules:
# Loop through all functions registered in the module
for func in cls._storage.get(module, []):
if func.__name__ == callable_name:
return func


def register_dependency_module(module):
"""
Registers a module as a dependency of the caller module.
:param module: The module to be registered as a dependency.
"""
caller_module = get_caller_module().__name__
ModuleRegister.register_dependency(caller_module, module.__name__)


for included_module in cls._dependency_tree.get(module_name, []):
included_callables = cls.get_registered_callables(included_module)
callables.extend(included_callables)
def register_dependency_callable(callable):
"""
Registers a callable as a dependency of the caller module.
Note that the function will be added to the module's list of available dependencies.
return callables
:param callable: The callable to be registered as a dependency.
"""
caller_module_name = get_caller_module().__name__
ModuleRegister.add_callable_to_module(callable, caller_module_name)
85 changes: 33 additions & 52 deletions rialto/jobs/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

import inspect
import typing
from functools import cache

from rialto.jobs.module_register import ModuleRegister


class ResolverException(Exception):
Expand All @@ -33,20 +34,10 @@ class Resolver:
Calling resolve() we attempt to resolve these dependencies.
"""

_storage = {}

@classmethod
def _get_args_for_call(cls, function: typing.Callable) -> typing.Dict[str, typing.Any]:
result_dict = {}
signature = inspect.signature(function)

for param in signature.parameters.values():
result_dict[param.name] = cls.resolve(param.name)

return result_dict
def __init__(self):
self._storage = {}

@classmethod
def register_object(cls, object: typing.Any, name: str) -> None:
def register_object(self, object: typing.Any, name: str) -> None:
"""
Register an object with a given name for later resolution.
Expand All @@ -55,10 +46,9 @@ def register_object(cls, object: typing.Any, name: str) -> None:
:return: None
"""

cls.register_callable(lambda: object, name)
self.register_getter(lambda: object, name)

@classmethod
def register_callable(cls, callable: typing.Callable, name: str = None) -> str:
def register_getter(self, callable: typing.Callable, name: str = None) -> str:
"""
Register callable with a given name for later resolution.
Expand All @@ -70,54 +60,45 @@ def register_callable(cls, callable: typing.Callable, name: str = None) -> str:
"""
if name is None:
name = getattr(callable, "__name__", repr(callable))
"""
if name in cls._storage:

if name in self._storage:
raise ResolverException(f"Resolver already registered {name}!")
"""

cls._storage[name] = callable
self._storage[name] = callable
return name

@classmethod
@cache
def resolve(cls, name: str) -> typing.Any:
def _find_getter(self, name: str, module_name) -> typing.Callable:
if name in self._storage.keys():
return self._storage[name]

callable_from_dependencies = ModuleRegister.find_callable(name, module_name)
if callable_from_dependencies is None:
raise ResolverException(f"{name} declaration not found!")

return callable_from_dependencies

def resolve(self, callable: typing.Callable) -> typing.Dict[str, typing.Any]:
"""
Search for a callable registered prior and attempt to call it with correct arguents.
Take a callable and resolve its dependencies / arguments. Arguments can be
a) objects registered via register_object
b) callables registered via register_getter
c) ModuleRegister registered callables via ModuleRegister.register_callable (+ dependencies)
Arguments are resolved recursively according to requirements; For example, if we have
a(b, c), b(d), and c(), d() registered, then we recursively call resolve() methods until we resolve
c, d -> b -> a
:param name: name of the callable to resolve
:param callable: function to resolve
:return: result of the callable
"""
if name not in cls._storage.keys():
raise ResolverException(f"{name} declaration not found!")

getter = cls._storage[name]
args = cls._get_args_for_call(getter)

return getter(**args)

@classmethod
def register_resolve(cls, callable: typing.Callable) -> typing.Any:
"""
Register and Resolve a callable.
arg_list = {}

Combination of the register() and resolve() methods for a simplified execution.
signature = inspect.signature(callable)
module_name = callable.__module__

:param callable: callable to register and immediately resolve
:return: result of the callable
"""
name = cls.register_callable(callable)
return cls.resolve(name)

@classmethod
def clear(cls) -> None:
"""
Clear all registered datasources and jobs.
for param in signature.parameters.values():
param_getter = self._find_getter(param.name, module_name)
arg_list[param.name] = self.resolve(param_getter)

:return: None
"""
cls.resolve.cache_clear()
cls._storage.clear()
return callable(**arg_list)
Loading

0 comments on commit a13553c

Please sign in to comment.