Module Register rework of dependency registration (#15)
* Module Register

* Resolver Reworked

* Overlapping function names prevention

* Documentation

* consistent arguments

---------

Co-authored-by: Marek Dobransky <[email protected]>
vvancak and MDobransky authored Sep 21, 2024
1 parent b25a562 commit cf2a12c
Showing 24 changed files with 904 additions and 604 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,10 @@ All notable changes to this project will be documented in this file.
- config holder removed from jobs
- metadata_manager and feature_loader are now available arguments, depending on configuration
- added @config decorator, similar use case to @datasource, for parsing configuration
- reworked Resolver and added ModuleRegister
- datasources are no longer registered just by importing them, and thus are no longer available to all jobs
- added register_dependency_callable and register_dependency_module to register datasources as dependencies
- together, it is now possible to have two datasources with the same name but different implementations for two jobs
#### TableReader
- function signatures changed
- until -> date_until
29 changes: 27 additions & 2 deletions README.md
Expand Up @@ -372,7 +372,7 @@ With that sorted out, we can now provide a quick example of the *rialto.jobs* mo
```python
from pyspark.sql import DataFrame
from rialto.common import TableReader
from rialto.jobs.decorators import config_parser, job, datasource
from rialto.jobs import config_parser, job, datasource
from rialto.runner.config_loader import PipelineConfig
from pydantic import BaseModel

@@ -419,7 +419,6 @@ If you want to disable versioning of your job (adding package VERSION column to
def my_job(...):
...
```

These parameters can be used separately, or combined.

### Notes & Rules
@@ -435,6 +434,32 @@ This can be useful in **model training**.
Finally, remember that your jobs are still just *Rialto Transformations* internally.
This means that, at the end of the day, you should always read some data, do some operations on it, and either return a PySpark DataFrame or return nothing and let the framework return the placeholder one.


### Importing / Registering Datasources
Datasources required by a job (or by another datasource) can be defined in a different module.
To register datasources from another module as dependencies, use the following functions:

```python3
from rialto.jobs import register_dependency_callable, register_dependency_module
import my_package.my_datasources as md
import my_package.my_datasources_big as big_md

# Register an entire dependency module
register_dependency_module(md)

# Register a single datasource from a bigger module
register_dependency_callable(big_md.sample_datasource)

@job
def my_job(my_datasource, sample_datasource: DataFrame, ...):
...
```

Each job (or datasource) can only resolve datasources that have been registered as its dependencies.

**NOTE**: While ```register_dependency_module``` only makes a module's datasources available as dependencies, ```register_dependency_callable``` actually brings the datasource into the target module, so it also becomes available for export further down the dependency chain.
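
To make the per-module registration above concrete, here is a minimal sketch of two job modules that each resolve a different implementation under the same datasource name. The modules `my_package.features_v1` / `my_package.features_v2` and their `features` datasources are hypothetical, used only for illustration:

```python3
# jobs_a.py -- registers a whole module as a dependency source
from rialto.jobs import job, register_dependency_module
import my_package.features_v1 as features_v1  # hypothetical module defining a `features` datasource

register_dependency_module(features_v1)

@job
def job_a(features):
    # `features` resolves to features_v1.features
    ...


# jobs_b.py -- registers a single datasource from another module
from rialto.jobs import job, register_dependency_callable
import my_package.features_v2 as features_v2  # hypothetical module with its own `features` datasource

register_dependency_callable(features_v2.features)

@job
def job_b(features):
    # `features` resolves to features_v2.features
    ...
```

Because each job only resolves what has been registered for its own module, the two `features` implementations do not clash.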


### Testing
One of the main advantages of the jobs module is the simplification of unit tests for your transformations. Rialto provides the following tools:

719 changes: 375 additions & 344 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,5 +1,5 @@
[tool.poetry]
name = "rialto-dev"
name = "rialto"

version = "2.0.0"

@@ -31,6 +31,7 @@ pandas = "^2.1.0"
flake8-broken-line = "^1.0.0"
loguru = "^0.7.2"
importlib-metadata = "^7.2.1"
numpy = "<2.0.0"

[tool.poetry.dev-dependencies]
pyspark = "^3.4.1"
24 changes: 22 additions & 2 deletions rialto/common/utils.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = ["load_yaml"]
__all__ = ["load_yaml", "cast_decimals_to_floats", "get_caller_module"]

import inspect
import os
from typing import Any
from typing import Any, List

import pyspark.sql.functions as F
import yaml
@@ -51,3 +52,22 @@ def cast_decimals_to_floats(df: DataFrame) -> DataFrame:
df = df.withColumn(c, F.col(c).cast(FloatType()))

return df


def get_caller_module() -> Any:
"""
Get the module containing the function which is calling your function.
Inspects the call stack, where:
0th entry is this function
1st entry is the function which needs to know who called it
2nd entry is the calling function
Therefore, we'll return a module which contains the function at the 2nd place on the stack.
:return: Python Module containing the calling function.
"""

stack = inspect.stack()
last_stack = stack[2]
return inspect.getmodule(last_stack[0])
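
For context, a short sketch of how a helper like this behaves. The `register_something` function and the `my_jobs` module below are made up purely for illustration:

```python3
# registry.py -- hypothetical helper that wants to know which module called it
from rialto.common.utils import get_caller_module


def register_something() -> None:
    # Inside get_caller_module: stack[0] is get_caller_module itself,
    # stack[1] is register_something, stack[2] is whoever called it,
    # so the caller's module is returned.
    caller_module = get_caller_module()
    print(f"registered from {caller_module.__name__}")


# my_jobs.py -- calling register_something() from this module would print
# "registered from my_jobs"
```
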
4 changes: 4 additions & 0 deletions rialto/jobs/__init__.py
@@ -13,3 +13,7 @@
# limitations under the License.

from rialto.jobs.decorators import config_parser, datasource, job
from rialto.jobs.module_register import (
register_dependency_callable,
register_dependency_module,
)
18 changes: 5 additions & 13 deletions rialto/jobs/decorators.py
@@ -14,14 +14,14 @@

__all__ = ["datasource", "job", "config_parser"]

import inspect
import typing

import importlib_metadata
from loguru import logger

from rialto.common.utils import get_caller_module
from rialto.jobs.job_base import JobBase
from rialto.jobs.resolver import Resolver
from rialto.jobs.module_register import ModuleRegister


def config_parser(cf_getter: typing.Callable) -> typing.Callable:
@@ -34,7 +34,7 @@ def config_parser(cf_getter: typing.Callable) -> typing.Callable:
:param cf_getter: dataset reader function
:return: raw function, unchanged
"""
Resolver.register_callable(cf_getter)
ModuleRegister.register_callable(cf_getter)
return cf_getter


@@ -48,16 +48,10 @@ def datasource(ds_getter: typing.Callable) -> typing.Callable:
:param ds_getter: dataset reader function
:return: raw reader function, unchanged
"""
Resolver.register_callable(ds_getter)
ModuleRegister.register_callable(ds_getter)
return ds_getter


def _get_module(stack: typing.List) -> typing.Any:
last_stack = stack[1]
mod = inspect.getmodule(last_stack[0])
return mod


def _get_version(module: typing.Any) -> str:
try:
package_name, _, _ = module.__name__.partition(".")
@@ -102,9 +96,7 @@ def job(*args, custom_name=None, disable_version=False):
:return: One more job wrapper for run function (if custom name or version override specified).
Otherwise, generates Rialto Transformation Type and returns it for in-module registration.
"""
stack = inspect.stack()

module = _get_module(stack)
module = get_caller_module()
version = _get_version(module)

# Use case where it's just raw @f. Otherwise, we get [] here.
79 changes: 27 additions & 52 deletions rialto/jobs/job_base.py
@@ -17,7 +17,6 @@
import abc
import datetime
import typing
from contextlib import contextmanager

import pyspark.sql.functions as F
from loguru import logger
@@ -49,55 +48,33 @@ def get_job_name(self) -> str:
"""Job name getter"""
pass

@contextmanager
def _setup_resolver(self, run_date: datetime.date) -> None:
Resolver.register_callable(lambda: run_date, "run_date")

Resolver.register_callable(self._get_spark, "spark")
Resolver.register_callable(self._get_table_reader, "table_reader")
Resolver.register_callable(self._get_config, "config")

if self._get_feature_loader() is not None:
Resolver.register_callable(self._get_feature_loader, "feature_loader")
if self._get_metadata_manager() is not None:
Resolver.register_callable(self._get_metadata_manager, "metadata_manager")

try:
yield
finally:
Resolver.cache_clear()

def _setup(
def _get_resolver(
self,
spark: SparkSession,
run_date: datetime.date,
table_reader: TableReader,
config: PipelineConfig = None,
metadata_manager: MetadataManager = None,
feature_loader: PysparkFeatureLoader = None,
) -> None:
self._spark = spark
self._table_rader = table_reader
self._config = config
self._metadata = metadata_manager
self._feature_loader = feature_loader
) -> Resolver:
resolver = Resolver()

def _get_spark(self) -> SparkSession:
return self._spark
# Static Always - Available dependencies
resolver.register_object(spark, "spark")
resolver.register_object(run_date, "run_date")
resolver.register_object(config, "config")
resolver.register_object(table_reader, "table_reader")

def _get_table_reader(self) -> TableReader:
return self._table_rader
# Optionals
if feature_loader is not None:
resolver.register_object(feature_loader, "feature_loader")

def _get_config(self) -> PipelineConfig:
return self._config
if metadata_manager is not None:
resolver.register_object(metadata_manager, "metadata_manager")

def _get_feature_loader(self) -> PysparkFeatureLoader:
return self._feature_loader
return resolver

def _get_metadata_manager(self) -> MetadataManager:
return self._metadata

def _get_timestamp_holder_result(self) -> DataFrame:
spark = self._get_spark()
def _get_timestamp_holder_result(self, spark) -> DataFrame:
return spark.createDataFrame(
[(self.get_job_name(), datetime.datetime.now())], schema="JOB_NAME string, CREATION_TIME timestamp"
)
@@ -110,17 +87,6 @@ def _add_job_version(self, df: DataFrame) -> DataFrame:

return df

def _run_main_callable(self, run_date: datetime.date) -> DataFrame:
with self._setup_resolver(run_date):
custom_callable = self.get_custom_callable()
raw_result = Resolver.register_resolve(custom_callable)

if raw_result is None:
raw_result = self._get_timestamp_holder_result()

result_with_version = self._add_job_version(raw_result)
return result_with_version

def run(
self,
reader: TableReader,
@@ -140,8 +106,17 @@ def run(
:return: dataframe
"""
try:
self._setup(spark, reader, config, metadata_manager, feature_loader)
return self._run_main_callable(run_date)
resolver = self._get_resolver(spark, run_date, reader, config, metadata_manager, feature_loader)

custom_callable = self.get_custom_callable()
raw_result = resolver.resolve(custom_callable)

if raw_result is None:
raw_result = self._get_timestamp_holder_result(spark)

result_with_version = self._add_job_version(raw_result)
return result_with_version

except Exception as e:
logger.exception(e)
raise e
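
Taken together, the hunk above replaces the old global `Resolver.register_callable` / `register_resolve` flow with a per-run `Resolver` instance. A rough sketch of that flow, using only the calls visible in this diff (the `my_transformation` function is hypothetical, and the `Resolver` import path is assumed to be `rialto.jobs.resolver`):

```python3
import datetime

from pyspark.sql import SparkSession
from rialto.jobs.resolver import Resolver  # assumed import path


def my_transformation(spark, run_date):
    # dependencies are injected by parameter name
    return spark.createDataFrame([(run_date,)], schema="RUN_DATE date")


spark = SparkSession.builder.getOrCreate()

resolver = Resolver()
resolver.register_object(spark, "spark")
resolver.register_object(datetime.date(2024, 9, 21), "run_date")

result = resolver.resolve(my_transformation)  # my_transformation receives the registered spark and run_date
```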
