Merge pull request #20 from AbsaOSS/Release/2.0.1
Release/2.0.1
MDobransky authored Oct 2, 2024
2 parents 79d51ad + 4160fdd commit eee942d
Showing 9 changed files with 45 additions and 8 deletions.
docs/source/conf.py (2 changes: 1 addition & 1 deletion)
@@ -28,7 +28,7 @@
 project = "rialto"
 copyright = "2022, Marek Dobransky"
 author = "Marek Dobransky"
-release = "1.3.0"
+release = "2.0.1"

 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
 [tool.poetry]
 name = "rialto"

-version = "2.0.0"
+version = "2.0.1"

 packages = [
     { include = "rialto" },
rialto/jobs/module_register.py (9 changes: 9 additions & 0 deletions)
@@ -56,6 +56,15 @@ def register_callable(cls, callable):
         callable_module = callable.__module__
         cls.add_callable_to_module(callable, callable_module)

+    @classmethod
+    def remove_module(cls, module):
+        """
+        Remove a module from the storage.
+
+        :param module: The module to be removed.
+        """
+        cls._storage.pop(module.__name__, None)
+
     @classmethod
     def register_dependency(cls, module, parent_name):
         """
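The new ModuleRegister.remove_module drops everything a module registered before that module is imported again. A minimal sketch of the intended pattern, assuming a hypothetical user module (the rialto import matches the one added to test_utils.py below):

import importlib

from rialto.jobs.module_register import ModuleRegister

import my_project.jobs.my_job_module as my_job_module  # hypothetical user module

# Forget the callables registered under this module, then reload it so the
# job decorators register them again from a clean slate.
ModuleRegister.remove_module(my_job_module)
importlib.reload(my_job_module)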
rialto/jobs/test_utils.py (9 changes: 6 additions & 3 deletions)
@@ -20,6 +20,7 @@
 from unittest.mock import MagicMock, create_autospec, patch

 from rialto.jobs.job_base import JobBase
+from rialto.jobs.module_register import ModuleRegister
 from rialto.jobs.resolver import Resolver, ResolverException


@@ -59,15 +60,17 @@ def disable_job_decorators(module) -> None:
     :return: None
     """
     with _disable_job_decorators():
+        ModuleRegister.remove_module(module)
         importlib.reload(module)
         yield

+    ModuleRegister.remove_module(module)
     importlib.reload(module)


 def resolver_resolves(spark, job: JobBase) -> bool:
     """
-    Checker method for your dependency resoultion.
+    Checker method for your dependency resolution.

     If your job's dependencies are all defined and resolvable, returns true.
     Otherwise, throws an exception.
@@ -100,8 +103,8 @@ def stack_watching_resolver_resolve(self, callable):

         return result

-    with patch(f"rialto.jobs.job_base.Resolver.resolve", stack_watching_resolver_resolve):
-        with patch(f"rialto.jobs.job_base.JobBase._add_job_version", lambda _, x: x):
+    with patch("rialto.jobs.job_base.Resolver.resolve", stack_watching_resolver_resolve):
+        with patch("rialto.jobs.job_base.JobBase._add_job_version", lambda _, x: x):
             job().run(
                 reader=MagicMock(),
                 run_date=MagicMock(),
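Combined, the two helpers are meant to be used the way the new tests/jobs/test_register.py below uses them: call the undecorated function inside the context, then check dependency resolution on the restored job afterwards. A minimal usage sketch, with hypothetical module and job names and spark assumed to be a pytest fixture:

from rialto.jobs.test_utils import disable_job_decorators, resolver_resolves

import my_project.jobs.my_job_module as my_job_module  # hypothetical user module


def test_my_job(spark):
    # Inside the block the job decorators are disabled and the module is reloaded,
    # so my_job is a plain function that can be called directly.
    with disable_job_decorators(my_job_module):
        assert my_job_module.my_job("some_input") == "expected_output"

    # Outside the block the decorated job is restored; verify its dependencies resolve.
    assert resolver_resolves(spark, my_job_module.my_job)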
rialto/maker/feature_maker.py (7 changes: 5 additions & 2 deletions)
Expand Up @@ -49,7 +49,10 @@ def _set_values(self, df: DataFrame, key: typing.Union[str, typing.List[str]], m
:return: None
"""
self.data_frame = df
self.key = key
if isinstance(key, str):
self.key = [key]
else:
self.key = key
self.make_date = make_date

def _order_by_dependencies(self, feature_holders: typing.List[FeatureHolder]) -> typing.List[FeatureHolder]:
@@ -136,7 +139,7 @@ def _make_sequential(self, keep_preexisting: bool) -> DataFrame:
         )
         if not keep_preexisting:
             logger.info("Dropping non-selected columns")
-            self.data_frame = self.data_frame.select(self.key, *feature_names)
+            self.data_frame = self.data_frame.select(*self.key, *feature_names)
         return self._filter_null_keys(self.data_frame)

     def _make_aggregated(self) -> DataFrame:
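Because _set_values now normalizes key to a list, the unpacked select(*self.key, *feature_names) stays valid whether the caller passed a single key column or a composite key. A quick PySpark illustration with made-up column names:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, "OUT", 10.0)], ["CUSTOMER_KEY", "TYPE", "AMT"])

key = "CUSTOMER_KEY"       # callers may pass a single string...
if isinstance(key, str):   # ...so normalize it to a list, as _set_values now does
    key = [key]

feature_names = ["AMT"]
# Unpacking the key list keeps the call valid for ["CUSTOMER_KEY"]
# and ["CUSTOMER_KEY", "TYPE"] alike; passing the list unexpanded would not be.
df.select(*key, *feature_names).show()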
rialto/runner/runner.py (4 changes: 3 additions & 1 deletion)
Expand Up @@ -67,7 +67,7 @@ def __init__(

if self.date_from > self.date_until:
raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}")
logger.info(f"Running period from {self.date_from} until {self.date_until}")
logger.info(f"Running period set to: {self.date_from} - {self.date_until}")

def _execute(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame:
"""
@@ -285,6 +285,7 @@ def _run_pipeline(self, pipeline: PipelineConfig):

     def __call__(self):
         """Execute pipelines"""
+        logger.info("Executing pipelines")
         try:
             if self.op:
                 selected = [p for p in self.config.pipelines if p.name == self.op]
@@ -297,3 +298,4 @@ def __call__(self):
         finally:
             print(self.tracker.records)
             self.tracker.report(self.config.runner.mail)
+            logger.info("Execution finished")
tests/jobs/test_job/test_job.py (5 changes: 5 additions & 0 deletions)
@@ -29,6 +29,11 @@ def job_function():
     return "job_function_return"


+@job
+def job_with_datasource(dataset):
+    return dataset
+
+
 @job(custom_name="custom_job_name")
 def custom_name_job_function():
     return "custom_job_name_return"
tests/jobs/test_register.py (8 changes: 8 additions & 0 deletions)
@@ -0,0 +1,8 @@
+from rialto.jobs.test_utils import disable_job_decorators, resolver_resolves
+from tests.jobs.test_job import test_job
+
+
+def test_resolve_after_disable(spark):
+    with disable_job_decorators(test_job):
+        assert test_job.job_with_datasource("test") == "test"
+    assert resolver_resolves(spark, test_job.job_with_datasource)
tests/maker/test_FeatureMaker.py (7 changes: 7 additions & 0 deletions)
@@ -61,6 +61,13 @@ def test_sequential_multi_key(input_df):
     assert "TRANSACTIONS_OUTBOUND_VALUE" in df.columns


+def test_sequential_multi_key_drop(input_df):
+    df, _ = FeatureMaker.make(
+        input_df, ["CUSTOMER_KEY", "TYPE"], date.today(), sequential_outbound, keep_preexisting=False
+    )
+    assert "TRANSACTIONS_OUTBOUND_VALUE" in df.columns
+
+
 def test_sequential_keeps(input_df):
     df, _ = FeatureMaker.make(input_df, "CUSTOMER_KEY", date.today(), sequential_outbound, keep_preexisting=True)
     assert "AMT" in df.columns
