From a49b06cbb31a71ed4c0ed266016fe3b8de7b509e Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Fri, 23 Aug 2024 19:02:09 +0200 Subject: [PATCH 01/12] v2 changes --- README.md | 25 ++-- poetry.lock | 31 +++-- pyproject.toml | 5 +- rialto/common/table_reader.py | 70 ++++------ rialto/common/utils.py | 5 +- rialto/jobs/__init__.py | 2 + rialto/jobs/configuration/config_holder.py | 130 ------------------- rialto/jobs/decorators/decorators.py | 4 +- rialto/jobs/decorators/job_base.py | 46 ++++--- rialto/jobs/decorators/resolver.py | 2 +- rialto/loader/data_loader.py | 2 +- rialto/runner/config_loader.py | 46 ++++--- rialto/runner/runner.py | 108 ++++++--------- rialto/runner/tracker.py | 13 +- rialto/runner/transformation.py | 11 +- tests/jobs/test_config_holder.py | 100 -------------- tests/jobs/test_decorators.py | 4 +- tests/jobs/test_job/dependency_tests_job.py | 4 +- tests/jobs/test_job/test_job.py | 7 +- tests/jobs/test_job_base.py | 36 ++--- tests/runner/conftest.py | 4 +- tests/runner/test_runner.py | 83 +++--------- tests/runner/transformations/config.yaml | 11 +- tests/runner/transformations/config2.yaml | 8 +- tests/runner/transformations/simple_group.py | 6 +- 25 files changed, 229 insertions(+), 534 deletions(-) delete mode 100644 rialto/jobs/configuration/config_holder.py delete mode 100644 tests/jobs/test_config_holder.py diff --git a/README.md b/README.md index 4f52d50..14f4dae 100644 --- a/README.md +++ b/README.md @@ -312,17 +312,18 @@ from pyspark.sql import DataFrame from rialto.common import TableReader from rialto.jobs.decorators import job, datasource + @datasource def my_datasource(run_date: datetime.date, table_reader: TableReader) -> DataFrame: - return table_reader.get_latest("my_catalog.my_schema.my_table", until=run_date) + return table_reader.get_latest("my_catalog.my_schema.my_table", date_until=run_date) @job def my_job(my_datasource: DataFrame) -> DataFrame: - return my_datasource.withColumn("HelloWorld", F.lit(1)) + return my_datasource.withColumn("HelloWorld", F.lit(1)) ``` -This piece of code -1. creates a rialto transformation called *my_job*, which is then callable by the rialto runner. +This piece of code +1. creates a rialto transformation called *my_job*, which is then callable by the rialto runner. 2. It sources the *my_datasource* and then runs *my_job* on top of that datasource. 3. Rialto adds VERSION (of your package) and INFORMATION_DATE (as per config) columns automatically. 4. The rialto runner stores the final to a catalog, to a table according to the job's name. @@ -383,20 +384,20 @@ import my_package.test_job_module as tjm # Datasource Testing def test_datasource_a(): ... mocks here ... - + with disable_job_decorators(tjm): datasource_a_output = tjm.datasource_a(... mocks ...) - + ... asserts ... - + # Job Testing def test_my_job(): datasource_a_mock = ... ... other mocks... - + with disable_job_decorators(tjm): job_output = tjm.my_job(datasource_a_mock, ... mocks ...) - + ... asserts ... 
``` @@ -563,6 +564,7 @@ reader = TableReader(spark=spark_instance) ``` usage of _get_table_: + ```python # get whole table df = reader.get_table(table="catalog.schema.table", date_column="information_date") @@ -573,10 +575,11 @@ from datetime import datetime start = datetime.strptime("2020-01-01", "%Y-%m-%d").date() end = datetime.strptime("2024-01-01", "%Y-%m-%d").date() -df = reader.get_table(table="catalog.schema.table", info_date_from=start, info_date_to=end) +df = reader.get_table(table="catalog.schema.table", date_from=start, date_to=end) ``` usage of _get_latest_: + ```python # most recent partition df = reader.get_latest(table="catalog.schema.table", date_column="information_date") @@ -584,7 +587,7 @@ df = reader.get_latest(table="catalog.schema.table", date_column="information_da # most recent partition until until = datetime.strptime("2020-01-01", "%Y-%m-%d").date() -df = reader.get_latest(table="catalog.schema.table", until=until, date_column="information_date") +df = reader.get_latest(table="catalog.schema.table", date_until=until, date_column="information_date") ``` For full information on parameters and their optionality see technical documentation. diff --git a/poetry.lock b/poetry.lock index 0cb768b..66ca41b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -343,6 +343,20 @@ files = [ {file = "docutils-0.20.1.tar.gz", hash = "sha256:f08a4e276c3a1583a86dce3e34aba3fe04d02bba2dd51ed16106244e8a923e3b"}, ] +[[package]] +name = "env-yaml" +version = "0.0.3" +description = "Provides a yaml loader which substitutes environment variables and supports defaults" +optional = false +python-versions = "*" +files = [ + {file = "env-yaml-0.0.3.tar.gz", hash = "sha256:b6b55b18c28fb623793137a8e55bd666d6483af7fd0162a41a62325ce662fda6"}, + {file = "env_yaml-0.0.3-py3-none-any.whl", hash = "sha256:f56723c8997bea1240bf634b9e29832714dd9745a42cbc2649f1238a6a576244"}, +] + +[package.dependencies] +pyyaml = ">=6.0" + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -751,9 +765,9 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -906,8 +920,8 @@ files = [ annotated-types = ">=0.4.0" pydantic-core = "2.20.1" typing-extensions = [ - {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, {version = ">=4.6.1", markers = "python_version < \"3.13\""}, + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, ] [package.extras] @@ -1170,7 +1184,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = 
"sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -1178,16 +1191,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -1204,7 +1209,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = 
"sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -1212,7 +1216,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -1544,4 +1547,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "243b1919c3e881039c2cd7b4e786f455b15a78872278050e7850e6a21c706c8e" +content-hash = "6e87c6539147b57b03fb983b28d15396c2eccfe95661805eda7d9f77602d1f58" diff --git a/pyproject.toml b/pyproject.toml index 8255885..5812612 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] -name = "rialto" +name = "rialto-dev" -version = "1.3.2" +version = "2.0.0" packages = [ { include = "rialto" }, @@ -31,6 +31,7 @@ pandas = "^2.1.0" flake8-broken-line = "^1.0.0" loguru = "^0.7.2" importlib-metadata = "^7.2.1" +env_yaml = "^0.0.3" [tool.poetry.dev-dependencies] pyspark = "^3.4.1" diff --git a/rialto/common/table_reader.py b/rialto/common/table_reader.py index 1aef614..d3926f2 100644 --- a/rialto/common/table_reader.py +++ b/rialto/common/table_reader.py @@ -21,8 +21,6 @@ import pyspark.sql.functions as F from pyspark.sql import DataFrame, SparkSession -from rialto.common.utils import get_date_col_property, get_delta_partition - class DataReader(metaclass=abc.ABCMeta): """ @@ -36,16 +34,15 @@ class DataReader(metaclass=abc.ABCMeta): def get_latest( self, table: str, - until: Optional[datetime.date] = None, - date_column: str = None, + date_column: str, + date_until: Optional[datetime.date] = None, uppercase_columns: bool = False, ) -> DataFrame: """ Get latest available date partition of the table until specified date :param table: input table path - :param until: Optional until date (inclusive) - :param date_column: column to filter dates on, takes highest priority + :param date_until: Optional until date (inclusive) :param uppercase_columns: Option to refactor all column names to uppercase :return: Dataframe """ @@ -55,18 +52,17 @@ def get_latest( def get_table( self, table: str, - info_date_from: Optional[datetime.date] = None, - info_date_to: Optional[datetime.date] = None, - date_column: str = None, + date_column: str, + date_from: Optional[datetime.date] = None, + date_to: Optional[datetime.date] = None, uppercase_columns: bool = False, ) -> DataFrame: """ Get a whole table or a slice by selected dates :param table: input table path - :param info_date_from: Optional 
date from (inclusive) - :param info_date_to: Optional date to (inclusive) - :param date_column: column to filter dates on, takes highest priority + :param date_from: Optional date from (inclusive) + :param date_to: Optional date to (inclusive) :param uppercase_columns: Option to refactor all column names to uppercase :return: Dataframe """ @@ -76,17 +72,13 @@ def get_table( class TableReader(DataReader): """An implementation of data reader for databricks tables""" - def __init__(self, spark: SparkSession, date_property: str = "rialto_date_column", infer_partition: bool = False): + def __init__(self, spark: SparkSession): """ Init :param spark: - :param date_property: Databricks table property specifying date column, take priority over inference - :param infer_partition: infer date column as tables partition from delta metadata """ self.spark = spark - self.date_property = date_property - self.infer_partition = infer_partition super().__init__() def _uppercase_column_names(self, df: DataFrame) -> DataFrame: @@ -106,41 +98,26 @@ def _get_latest_available_date(self, df: DataFrame, date_col: str, until: Option df = df.select(F.max(date_col)).alias("latest") return df.head()[0] - def _get_date_col(self, table: str, date_column: str): - """ - Get tables date column - - column specified at get_table/get_latest takes priority, if inference is enabled it - takes 2nd place, last resort is table property - """ - if date_column: - return date_column - elif self.infer_partition: - return get_delta_partition(self.spark, table) - else: - return get_date_col_property(self.spark, table, self.date_property) - def get_latest( self, table: str, - until: Optional[datetime.date] = None, - date_column: str = None, + date_column: str, + date_until: Optional[datetime.date] = None, uppercase_columns: bool = False, ) -> DataFrame: """ Get latest available date partition of the table until specified date :param table: input table path - :param until: Optional until date (inclusive) + :param date_until: Optional until date (inclusive) :param date_column: column to filter dates on, takes highest priority :param uppercase_columns: Option to refactor all column names to uppercase :return: Dataframe """ - date_col = self._get_date_col(table, date_column) df = self.spark.read.table(table) - selected_date = self._get_latest_available_date(df, date_col, until) - df = df.filter(F.col(date_col) == selected_date) + selected_date = self._get_latest_available_date(df, date_column, date_until) + df = df.filter(F.col(date_column) == selected_date) if uppercase_columns: df = self._uppercase_column_names(df) @@ -149,28 +126,27 @@ def get_latest( def get_table( self, table: str, - info_date_from: Optional[datetime.date] = None, - info_date_to: Optional[datetime.date] = None, - date_column: str = None, + date_column: str, + date_from: Optional[datetime.date] = None, + date_to: Optional[datetime.date] = None, uppercase_columns: bool = False, ) -> DataFrame: """ Get a whole table or a slice by selected dates :param table: input table path - :param info_date_from: Optional date from (inclusive) - :param info_date_to: Optional date to (inclusive) + :param date_from: Optional date from (inclusive) + :param date_to: Optional date to (inclusive) :param date_column: column to filter dates on, takes highest priority :param uppercase_columns: Option to refactor all column names to uppercase :return: Dataframe """ - date_col = self._get_date_col(table, date_column) df = self.spark.read.table(table) - if info_date_from: - df = 
df.filter(F.col(date_col) >= info_date_from) - if info_date_to: - df = df.filter(F.col(date_col) <= info_date_to) + if date_from: + df = df.filter(F.col(date_column) >= date_from) + if date_to: + df = df.filter(F.col(date_column) <= date_to) if uppercase_columns: df = self._uppercase_column_names(df) return df diff --git a/rialto/common/utils.py b/rialto/common/utils.py index c5527a8..6c2952c 100644 --- a/rialto/common/utils.py +++ b/rialto/common/utils.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["load_yaml", "get_date_col_property", "get_delta_partition"] +__all__ = ["load_yaml"] import os from typing import Any import pyspark.sql.functions as F import yaml +from env_yaml import EnvLoader from pyspark.sql import DataFrame from pyspark.sql.types import FloatType @@ -34,7 +35,7 @@ def load_yaml(path: str) -> Any: raise FileNotFoundError(f"Can't find {path}.") with open(path, "r") as stream: - return yaml.safe_load(stream) + return yaml.load(stream, EnvLoader) def get_date_col_property(spark, table: str, property: str) -> str: diff --git a/rialto/jobs/__init__.py b/rialto/jobs/__init__.py index 79c3773..90183bd 100644 --- a/rialto/jobs/__init__.py +++ b/rialto/jobs/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from rialto.jobs.decorators import datasource, job diff --git a/rialto/jobs/configuration/config_holder.py b/rialto/jobs/configuration/config_holder.py deleted file mode 100644 index 161c61a..0000000 --- a/rialto/jobs/configuration/config_holder.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright 2022 ABSA Group Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -__all__ = ["ConfigException", "FeatureStoreConfig", "ConfigHolder"] - -import datetime -import typing - -from pydantic import BaseModel - - -class ConfigException(Exception): - """Wrong Configuration Exception""" - - pass - - -class FeatureStoreConfig(BaseModel): - """Configuration of Feature Store Paths""" - - feature_store_schema: str = None - feature_metadata_schema: str = None - - -class ConfigHolder: - """ - Main Rialto Jobs config holder. - - Configured via job_runner and then called from job_base / job decorators. 
- """ - - _config = {} - _dependencies = {} - _run_date = None - _feature_store_config: FeatureStoreConfig = None - - @classmethod - def set_run_date(cls, run_date: datetime.date) -> None: - """ - Inicialize run Date - - :param run_date: datetime.date, run date - :return: None - """ - cls._run_date = run_date - - @classmethod - def get_run_date(cls) -> datetime.date: - """ - Run date - - :return: datetime.date, Run date - """ - if cls._run_date is None: - raise ConfigException("Run Date not Set !") - return cls._run_date - - @classmethod - def set_feature_store_config(cls, feature_store_schema: str, feature_metadata_schema: str) -> None: - """ - Inicialize feature store config - - :param feature_store_schema: str, schema name - :param feature_metadata_schema: str, metadata schema name - :return: None - """ - cls._feature_store_config = FeatureStoreConfig( - feature_store_schema=feature_store_schema, feature_metadata_schema=feature_metadata_schema - ) - - @classmethod - def get_feature_store_config(cls) -> FeatureStoreConfig: - """ - Feature Store Config - - :return: FeatureStoreConfig - """ - if cls._feature_store_config is None: - raise ConfigException("Feature Store Config not Set !") - - return cls._feature_store_config - - @classmethod - def get_config(cls) -> typing.Dict: - """ - Get config dictionary - - :return: dictionary of key-value pairs - """ - return cls._config.copy() - - @classmethod - def set_custom_config(cls, **kwargs) -> None: - """ - Set custom key-value pairs for custom config - - :param kwargs: key-value pairs to setup - :return: None - """ - cls._config.update(kwargs) - - @classmethod - def get_dependency_config(cls) -> typing.Dict: - """ - Get rialto job dependency config - - :return: dictionary with dependency config - """ - return cls._dependencies - - @classmethod - def set_dependency_config(cls, dependencies: typing.Dict) -> None: - """ - Get rialto job dependency config - - :param dependencies: dictionary with the config - :return: None - """ - cls._dependencies = dependencies diff --git a/rialto/jobs/decorators/decorators.py b/rialto/jobs/decorators/decorators.py index f900726..2949ad6 100644 --- a/rialto/jobs/decorators/decorators.py +++ b/rialto/jobs/decorators/decorators.py @@ -77,14 +77,14 @@ def job(*args, custom_name=None, disable_version=False): """ Rialto jobs decorator. - Transforms a python function into a rialto transormation, which can be imported and ran by Rialto Runner. + Transforms a python function into a rialto transformation, which can be imported and ran by Rialto Runner. Is mainly used as @job and the function's name is used, and the outputs get automatic. To override this behavior, use @job(custom_name=XXX, disable_version=True). :param *args: list of positional arguments. Empty in case custom_name or disable_version is specified. :param custom_name: str for custom job name. - :param disable_version: bool for disabling autofilling the VERSION column in the job's outputs. + :param disable_version: bool for disabling automobiling the VERSION column in the job's outputs. :return: One more job wrapper for run function (if custom name or version override specified). Otherwise, generates Rialto Transformation Type and returns it for in-module registration. 
""" diff --git a/rialto/jobs/decorators/job_base.py b/rialto/jobs/decorators/job_base.py index 9e3ecc8..d91537f 100644 --- a/rialto/jobs/decorators/job_base.py +++ b/rialto/jobs/decorators/job_base.py @@ -24,11 +24,11 @@ from pyspark.sql import DataFrame, SparkSession from rialto.common import TableReader -from rialto.jobs.configuration.config_holder import ConfigHolder from rialto.jobs.decorators.resolver import Resolver -from rialto.loader import DatabricksLoader, PysparkFeatureLoader +from rialto.loader import PysparkFeatureLoader from rialto.metadata import MetadataManager from rialto.runner import Transformation +from rialto.runner.config_loader import PipelineConfig class JobBase(Transformation): @@ -53,12 +53,14 @@ def get_job_name(self) -> str: def _setup_resolver(self, run_date: datetime.date) -> None: Resolver.register_callable(lambda: run_date, "run_date") - Resolver.register_callable(ConfigHolder.get_config, "config") - Resolver.register_callable(ConfigHolder.get_dependency_config, "dependencies") - Resolver.register_callable(self._get_spark, "spark") Resolver.register_callable(self._get_table_reader, "table_reader") - Resolver.register_callable(self._get_feature_loader, "feature_loader") + Resolver.register_callable(self._get_config, "config") + + if self._get_feature_loader() is not None: + Resolver.register_callable(self._get_feature_loader, "feature_loader") + if self._get_metadata_manager() is not None: + Resolver.register_callable(self._get_metadata_manager, "metadata_manager") try: yield @@ -66,13 +68,18 @@ def _setup_resolver(self, run_date: datetime.date) -> None: Resolver.cache_clear() def _setup( - self, spark: SparkSession, run_date: datetime.date, table_reader: TableReader, dependencies: typing.Dict = None + self, + spark: SparkSession, + table_reader: TableReader, + config: PipelineConfig = None, + metadata_manager: MetadataManager = None, + feature_loader: PysparkFeatureLoader = None, ) -> None: self._spark = spark self._table_rader = table_reader - - ConfigHolder.set_dependency_config(dependencies) - ConfigHolder.set_run_date(run_date) + self._config = config + self._metadata = metadata_manager + self._feature_loader = feature_loader def _get_spark(self) -> SparkSession: return self._spark @@ -80,13 +87,14 @@ def _get_spark(self) -> SparkSession: def _get_table_reader(self) -> TableReader: return self._table_rader - def _get_feature_loader(self) -> PysparkFeatureLoader: - config = ConfigHolder.get_feature_store_config() + def _get_config(self) -> PipelineConfig: + return self._config - databricks_loader = DatabricksLoader(self._spark, config.feature_store_schema) - feature_loader = PysparkFeatureLoader(self._spark, databricks_loader, config.feature_metadata_schema) + def _get_feature_loader(self) -> PysparkFeatureLoader: + return self._feature_loader - return feature_loader + def _get_metadata_manager(self) -> MetadataManager: + return self._metadata def _get_timestamp_holder_result(self) -> DataFrame: spark = self._get_spark() @@ -118,8 +126,9 @@ def run( reader: TableReader, run_date: datetime.date, spark: SparkSession = None, + config: PipelineConfig = None, metadata_manager: MetadataManager = None, - dependencies: typing.Dict = None, + feature_loader: PysparkFeatureLoader = None, ) -> DataFrame: """ Rialto transformation run @@ -127,12 +136,11 @@ def run( :param reader: data store api object :param info_date: date :param spark: spark session - :param metadata_manager: metadata api object - :param dependencies: rialto job dependencies + :param config: 
pipeline config :return: dataframe """ try: - self._setup(spark, run_date, reader, dependencies) + self._setup(spark, reader, config, metadata_manager, feature_loader) return self._run_main_callable(run_date) except Exception as e: logger.exception(e) diff --git a/rialto/jobs/decorators/resolver.py b/rialto/jobs/decorators/resolver.py index 9f90e5a..f13f0eb 100644 --- a/rialto/jobs/decorators/resolver.py +++ b/rialto/jobs/decorators/resolver.py @@ -30,7 +30,7 @@ class Resolver: Resolver handles dependency management between datasets and jobs. We register different callables, which can depend on other callables. - Calling resolve() we attempts to resolve these dependencies. + Calling resolve() we attempt to resolve these dependencies. """ _storage = {} diff --git a/rialto/loader/data_loader.py b/rialto/loader/data_loader.py index 930c2b0..dc13572 100644 --- a/rialto/loader/data_loader.py +++ b/rialto/loader/data_loader.py @@ -41,5 +41,5 @@ def read_group(self, group: str, information_date: date) -> DataFrame: :return: dataframe """ return self.reader.get_latest( - f"{self.schema}.{group}", until=information_date, date_column=self.date_col, uppercase_columns=True + f"{self.schema}.{group}", date_until=information_date, date_column=self.date_col, uppercase_columns=True ) diff --git a/rialto/runner/config_loader.py b/rialto/runner/config_loader.py index af6640b..d5b0150 100644 --- a/rialto/runner/config_loader.py +++ b/rialto/runner/config_loader.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["get_pipelines_config", "transform_dependencies"] +__all__ = [ + "get_pipelines_config", +] from typing import Dict, List, Optional, Union @@ -35,7 +37,7 @@ class ScheduleConfig(BaseModel): class DependencyConfig(BaseModel): table: str name: Optional[str] = None - date_col: Optional[str] = None + date_col: str interval: IntervalConfig @@ -52,37 +54,43 @@ class MailConfig(BaseModel): sent_empty: Optional[bool] = False -class GeneralConfig(BaseModel): - target_schema: str - target_partition_column: str - source_date_column_property: Optional[str] = None +class RunnerConfig(BaseModel): watched_period_units: str watched_period_value: int - job: str mail: MailConfig +class TargetConfig(BaseModel): + target_schema: str + target_partition_column: str + + +class MetadataManagerConfig(BaseModel): + metadata_schema: str + + +class FeatureLoaderConfig(BaseModel): + config_path: str + feature_schema: str + metadata_schema: str + + class PipelineConfig(BaseModel): name: str - module: Optional[ModuleConfig] = None + module: ModuleConfig schedule: ScheduleConfig - dependencies: List[DependencyConfig] = [] + dependencies: Optional[List[DependencyConfig]] = [] + target: Optional[TargetConfig] = None + metadata_manager: Optional[MetadataManagerConfig] = None + feature_loader: Optional[FeatureLoaderConfig] = None + extras: Optional[Dict] = {} class PipelinesConfig(BaseModel): - general: GeneralConfig + runner: RunnerConfig pipelines: list[PipelineConfig] def get_pipelines_config(path) -> PipelinesConfig: """Load and parse yaml config""" return PipelinesConfig(**load_yaml(path)) - - -def transform_dependencies(dependencies: List[DependencyConfig]) -> Dict: - """Transform dependency config list into a dictionary""" - res = {} - for dep in dependencies: - if dep.name: - res[dep.name] = dep - return res diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index 343d2fe..e3efe01 100644 --- a/rialto/runner/runner.py +++ 
b/rialto/runner/runner.py @@ -24,16 +24,13 @@ from pyspark.sql import DataFrame, SparkSession from rialto.common import TableReader -from rialto.common.utils import get_date_col_property, get_delta_partition -from rialto.jobs.configuration.config_holder import ConfigHolder +from rialto.loader import DatabricksLoader, PysparkFeatureLoader from rialto.metadata import MetadataManager from rialto.runner.config_loader import ( - DependencyConfig, ModuleConfig, PipelineConfig, ScheduleConfig, get_pipelines_config, - transform_dependencies, ) from rialto.runner.date_manager import DateManager from rialto.runner.table import Table @@ -48,35 +45,23 @@ def __init__( self, spark: SparkSession, config_path: str, - feature_metadata_schema: str = None, run_date: str = None, date_from: str = None, date_until: str = None, - feature_store_schema: str = None, - custom_job_config: dict = None, rerun: bool = False, op: str = None, + skip_dependencies: bool = False, ): self.spark = spark self.config = get_pipelines_config(config_path) - self.reader = TableReader( - spark, date_property=self.config.general.source_date_column_property, infer_partition=False - ) - if feature_metadata_schema: - self.metadata = MetadataManager(spark, feature_metadata_schema) - else: - self.metadata = None + self.reader = TableReader(spark) + self.date_from = date_from self.date_until = date_until self.rerun = rerun + self.skip_dependencies = skip_dependencies self.op = op - self.tracker = Tracker(self.config.general.target_schema) - - if (feature_store_schema is not None) and (feature_metadata_schema is not None): - ConfigHolder.set_feature_store_config(feature_store_schema, feature_metadata_schema) - - if custom_job_config is not None: - ConfigHolder.set_custom_config(**custom_job_config) + self.tracker = Tracker() if run_date: run_date = DateManager.str_to_date(run_date) @@ -90,8 +75,8 @@ def __init__( if not self.date_from: self.date_from = DateManager.date_subtract( run_date=run_date, - units=self.config.general.watched_period_units, - value=self.config.general.watched_period_value, + units=self.config.runner.watched_period_units, + value=self.config.runner.watched_period_value, ) if not self.date_until: self.date_until = run_date @@ -110,24 +95,36 @@ def _load_module(self, cfg: ModuleConfig) -> Transformation: class_obj = getattr(module, cfg.python_class) return class_obj() - def _generate( - self, instance: Transformation, run_date: date, dependencies: List[DependencyConfig] = None - ) -> DataFrame: + def _generate(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame: """ Run feature group :param instance: Instance of Transformation :param run_date: date to run for + :param pipeline: pipeline configuration :return: Dataframe """ - if dependencies is not None: - dependencies = transform_dependencies(dependencies) + if pipeline.metadata_manager is not None: + metadata_manager = MetadataManager(self.spark, pipeline.metadata_manager.metadata_schema) + else: + metadata_manager = None + + if pipeline.feature_loader is not None: + feature_loader = PysparkFeatureLoader( + self.spark, + DatabricksLoader(self.spark, schema=pipeline.feature_loader.feature_schema), + metadata_schema=pipeline.feature_loader.metadata_schema, + ) + else: + feature_loader = None + df = instance.run( - reader=self.reader, - run_date=run_date, spark=self.spark, - metadata_manager=self.metadata, - dependencies=dependencies, + run_date=run_date, + config=pipeline, + reader=self.reader, + metadata_manager=metadata_manager, + 
feature_loader=feature_loader, ) logger.info(f"Generated {df.count()} records") @@ -155,15 +152,6 @@ def _write(self, df: DataFrame, info_date: date, table: Table) -> None: df.write.partitionBy(table.partition).mode("overwrite").saveAsTable(table.get_table_path()) logger.info(f"Results writen to {table.get_table_path()}") - try: - get_date_col_property(self.spark, table.get_table_path(), "rialto_date_column") - except RuntimeError: - sql_query = ( - f"ALTER TABLE {table.get_table_path()} SET TBLPROPERTIES ('rialto_date_column' = '{table.partition}')" - ) - self.spark.sql(sql_query) - logger.info(f"Set table property rialto_date_column to {table.partition}") - def _delta_partition(self, table: str) -> str: """ Select first partition column, should be only one @@ -226,18 +214,9 @@ def check_dependencies(self, pipeline: PipelineConfig, run_date: date) -> bool: possible_dep_dates = DateManager.all_dates(dep_from, run_date) - # date column options prioritization (manual column, table property, inferred from delta) - if dependency.date_col: - date_col = dependency.date_col - elif self.config.general.source_date_column_property: - date_col = get_date_col_property( - self.spark, dependency.table, self.config.general.source_date_column_property - ) - else: - date_col = get_delta_partition(self.spark, dependency.table) - logger.debug(f"Date column for {dependency.table} is {date_col}") + logger.debug(f"Date column for {dependency.table} is {dependency.date_col}") - source = Table(table_path=dependency.table, partition=date_col) + source = Table(table_path=dependency.table, partition=dependency.date_col) if True in self.check_dates_have_partition(source, possible_dep_dates): logger.info(f"Dependency for {dependency.table} from {dep_from} until {run_date} is fulfilled") else: @@ -318,18 +297,17 @@ def _run_one_date(self, pipeline: PipelineConfig, run_date: date, info_date: dat :param target: target Table :return: success bool """ - if self.check_dependencies(pipeline, run_date): + if self.skip_dependencies or self.check_dependencies(pipeline, run_date): logger.info(f"Running {pipeline.name} for {run_date}") - if self.config.general.job == "run": - feature_group = self._load_module(pipeline.module) - df = self._generate(feature_group, run_date, pipeline.dependencies) - records = df.count() - if records > 0: - self._write(df, info_date, target) - return records - else: - raise RuntimeError("No records generated") + feature_group = self._load_module(pipeline.module) + df = self._generate(feature_group, run_date, pipeline) + records = df.count() + if records > 0: + self._write(df, info_date, target) + return records + else: + raise RuntimeError("No records generated") return 0 def _run_pipeline(self, pipeline: PipelineConfig): @@ -340,9 +318,9 @@ def _run_pipeline(self, pipeline: PipelineConfig): :return: success bool """ target = Table( - schema_path=self.config.general.target_schema, + schema_path=pipeline.target.target_schema, class_name=pipeline.module.python_class, - partition=self.config.general.target_partition_column, + partition=pipeline.target.target_partition_column, ) logger.info(f"Loaded pipeline {pipeline.name}") @@ -413,4 +391,4 @@ def __call__(self): self._run_pipeline(pipeline) finally: print(self.tracker.records) - self.tracker.report(self.config.general.mail) + self.tracker.report(self.config.runner.mail) diff --git a/rialto/runner/tracker.py b/rialto/runner/tracker.py index de97fb0..57a24e6 100644 --- a/rialto/runner/tracker.py +++ b/rialto/runner/tracker.py @@ -41,8 +41,7 @@ class 
Record: class Tracker: """Collect information about runs and sent them out via email""" - def __init__(self, target_schema: str): - self.target_schema = target_schema + def __init__(self): self.records = [] self.last_error = None self.pipeline_start = datetime.now() @@ -55,7 +54,7 @@ def add(self, record: Record) -> None: def report(self, mail_cfg: MailConfig): """Create and send html report""" if len(self.records) or mail_cfg.sent_empty: - report = HTMLMessage.make_report(self.target_schema, self.pipeline_start, self.records) + report = HTMLMessage.make_report(self.pipeline_start, self.records) for receiver in mail_cfg.to: message = Mailer.create_message( subject=mail_cfg.subject, sender=mail_cfg.sender, receiver=receiver, body=report @@ -118,7 +117,7 @@ def _make_overview_header(): """ @staticmethod - def _make_header(target: str, start: datetime): + def _make_header(start: datetime): return f"""
@@ -127,7 +126,7 @@ def _make_header(target: str, start: datetime):
-            Jobs started {str(start).split('.')[0]}, targeting {target}
+            Jobs started {str(start).split('.')[0]}
@@ -228,14 +227,14 @@ def _make_insights(records: List[Record]): """ @staticmethod - def make_report(target: str, start: datetime, records: List[Record]) -> str: + def make_report(start: datetime, records: List[Record]) -> str: """Create html email report""" html = [ """ """, HTMLMessage._head(), HTMLMessage._body_open(), - HTMLMessage._make_header(target, start), + HTMLMessage._make_header(start), HTMLMessage._make_overview(records), HTMLMessage._make_insights(records), HTMLMessage._body_close(), diff --git a/rialto/runner/transformation.py b/rialto/runner/transformation.py index 4399ce0..7b5eaa8 100644 --- a/rialto/runner/transformation.py +++ b/rialto/runner/transformation.py @@ -16,12 +16,13 @@ import abc import datetime -from typing import Dict from pyspark.sql import DataFrame, SparkSession from rialto.common import TableReader +from rialto.loader import PysparkFeatureLoader from rialto.metadata import MetadataManager +from rialto.runner.config_loader import PipelineConfig class Transformation(metaclass=abc.ABCMeta): @@ -33,8 +34,9 @@ def run( reader: TableReader, run_date: datetime.date, spark: SparkSession = None, + config: PipelineConfig = None, metadata_manager: MetadataManager = None, - dependencies: Dict = None, + feature_loader: PysparkFeatureLoader = None, ) -> DataFrame: """ Run the transformation @@ -42,8 +44,9 @@ def run( :param reader: data store api object :param run_date: date :param spark: spark session - :param metadata_manager: metadata api object - :param dependencies: dictionary of dependencies + :param config: pipeline config + :param metadata_manager: metadata manager + :param feature_loader: feature loader :return: dataframe """ raise NotImplementedError diff --git a/tests/jobs/test_config_holder.py b/tests/jobs/test_config_holder.py deleted file mode 100644 index 38fadb1..0000000 --- a/tests/jobs/test_config_holder.py +++ /dev/null @@ -1,100 +0,0 @@ -# Copyright 2022 ABSA Group Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from datetime import date - -import pytest - -from rialto.jobs.configuration.config_holder import ( - ConfigException, - ConfigHolder, - FeatureStoreConfig, -) - - -def test_run_date_unset(): - with pytest.raises(ConfigException): - ConfigHolder.get_run_date() - - -def test_run_date(): - dt = date(2023, 1, 1) - - ConfigHolder.set_run_date(dt) - - assert ConfigHolder.get_run_date() == dt - - -def test_feature_store_config_unset(): - with pytest.raises(ConfigException): - ConfigHolder.get_feature_store_config() - - -def test_feature_store_config(): - ConfigHolder.set_feature_store_config("store_schema", "metadata_schema") - - fsc = ConfigHolder.get_feature_store_config() - - assert type(fsc) is FeatureStoreConfig - assert fsc.feature_store_schema == "store_schema" - assert fsc.feature_metadata_schema == "metadata_schema" - - -def test_config_unset(): - config = ConfigHolder.get_config() - - assert type(config) is type({}) - assert len(config.items()) == 0 - - -def test_config_dict_copied_not_ref(): - """Test that config holder config can't be set from outside""" - config = ConfigHolder.get_config() - - config["test"] = 123 - - assert "test" not in ConfigHolder.get_config() - - -def test_config(): - ConfigHolder.set_custom_config(hello=123) - ConfigHolder.set_custom_config(world="test") - - config = ConfigHolder.get_config() - - assert config["hello"] == 123 - assert config["world"] == "test" - - -def test_config_from_dict(): - ConfigHolder.set_custom_config(**{"dict_item_1": 123, "dict_item_2": 456}) - - config = ConfigHolder.get_config() - - assert config["dict_item_1"] == 123 - assert config["dict_item_2"] == 456 - - -def test_dependencies_unset(): - deps = ConfigHolder.get_dependency_config() - assert len(deps.keys()) == 0 - - -def test_dependencies(): - ConfigHolder.set_dependency_config({"hello": 123}) - - deps = ConfigHolder.get_dependency_config() - - assert deps["hello"] == 123 diff --git a/tests/jobs/test_decorators.py b/tests/jobs/test_decorators.py index c6d05e6..6496a2d 100644 --- a/tests/jobs/test_decorators.py +++ b/tests/jobs/test_decorators.py @@ -14,7 +14,6 @@ from importlib import import_module -from rialto.jobs.configuration.config_holder import ConfigHolder from rialto.jobs.decorators.job_base import JobBase from rialto.jobs.decorators.resolver import Resolver @@ -70,7 +69,6 @@ def test_job_disabling_version(): def test_job_dependencies_registered(spark): - ConfigHolder.set_custom_config(value=123) job_class = _rialto_import_stub("tests.jobs.test_job.test_job", "job_asking_for_all_deps") # asserts part of the run - job_class.run(spark=spark, run_date=456, reader=789, metadata_manager=None, dependencies=1011) + job_class.run(spark=spark, run_date=456, reader=789, config=123, metadata_manager=654, feature_loader=321) diff --git a/tests/jobs/test_job/dependency_tests_job.py b/tests/jobs/test_job/dependency_tests_job.py index 3029b33..38e10ba 100644 --- a/tests/jobs/test_job/dependency_tests_job.py +++ b/tests/jobs/test_job/dependency_tests_job.py @@ -1,4 +1,4 @@ -from rialto.jobs.decorators import job, datasource +from rialto.jobs.decorators import datasource, job @datasource @@ -47,5 +47,5 @@ def missing_dependency_job(a, x): @job -def default_dependency_job(run_date, spark, config, dependencies, table_reader, feature_loader): +def default_dependency_job(run_date, spark, config, table_reader, feature_loader): return 1 diff --git a/tests/jobs/test_job/test_job.py b/tests/jobs/test_job/test_job.py index 460490a..bc3cb69 100644 --- a/tests/jobs/test_job/test_job.py +++ 
b/tests/jobs/test_job/test_job.py @@ -37,9 +37,10 @@ def disable_version_job_function(): @job -def job_asking_for_all_deps(spark, run_date, config, dependencies, table_reader): +def job_asking_for_all_deps(spark, run_date, config, table_reader, metadata_manager, feature_loader): assert spark is not None assert run_date == 456 - assert config["value"] == 123 + assert config == 123 assert table_reader == 789 - assert dependencies == 1011 + assert metadata_manager == 654 + assert feature_loader == 321 diff --git a/tests/jobs/test_job_base.py b/tests/jobs/test_job_base.py index ab8284a..fa8f19c 100644 --- a/tests/jobs/test_job_base.py +++ b/tests/jobs/test_job_base.py @@ -14,42 +14,36 @@ import datetime -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pyspark.sql.types import tests.jobs.resources as resources -from rialto.jobs.configuration.config_holder import ConfigHolder, FeatureStoreConfig from rialto.jobs.decorators.resolver import Resolver -from rialto.loader import PysparkFeatureLoader +from rialto.loader import DatabricksLoader, PysparkFeatureLoader def test_setup_except_feature_loader(spark): table_reader = MagicMock() + config = MagicMock() date = datetime.date(2023, 1, 1) - ConfigHolder.set_custom_config(hello=1, world=2) - - resources.CustomJobNoReturnVal().run( - reader=table_reader, run_date=date, spark=spark, metadata_manager=None, dependencies={1: 1} - ) + resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, config=config) assert Resolver.resolve("run_date") == date - assert Resolver.resolve("config") == ConfigHolder.get_config() - assert Resolver.resolve("dependencies") == ConfigHolder.get_dependency_config() + assert Resolver.resolve("config") == config assert Resolver.resolve("spark") == spark assert Resolver.resolve("table_reader") == table_reader -@patch( - "rialto.jobs.configuration.config_holder.ConfigHolder.get_feature_store_config", - return_value=FeatureStoreConfig(feature_store_schema="schema", feature_metadata_schema="metadata_schema"), -) def test_setup_feature_loader(spark): table_reader = MagicMock() date = datetime.date(2023, 1, 1) + feature_loader = PysparkFeatureLoader(spark, DatabricksLoader(spark, "", ""), "") - resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, metadata_manager=None) + resources.CustomJobNoReturnVal().run( + reader=table_reader, run_date=date, spark=spark, config=None, feature_loader=feature_loader + ) assert type(Resolver.resolve("feature_loader")) == PysparkFeatureLoader @@ -60,7 +54,7 @@ def test_custom_callable_called(spark, mocker): table_reader = MagicMock() date = datetime.date(2023, 1, 1) - resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, metadata_manager=None) + resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, config=None) spy_cc.assert_called_once() @@ -69,9 +63,7 @@ def test_no_return_vaue_adds_version_timestamp_dataframe(spark): table_reader = MagicMock() date = datetime.date(2023, 1, 1) - result = resources.CustomJobNoReturnVal().run( - reader=table_reader, run_date=date, spark=spark, metadata_manager=None - ) + result = resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, config=None) assert type(result) is pyspark.sql.DataFrame assert result.columns == ["JOB_NAME", "CREATION_TIME", "VERSION"] @@ -83,9 +75,7 @@ def test_return_dataframe_forwarded_with_version(spark): table_reader = MagicMock() date = 
datetime.date(2023, 1, 1) - result = resources.CustomJobReturnsDataFrame().run( - reader=table_reader, run_date=date, spark=spark, metadata_manager=None - ) + result = resources.CustomJobReturnsDataFrame().run(reader=table_reader, run_date=date, spark=spark, config=None) assert type(result) is pyspark.sql.DataFrame assert result.columns == ["FIRST", "SECOND", "VERSION"] @@ -97,7 +87,7 @@ def test_none_job_version_wont_fill_job_colun(spark): table_reader = MagicMock() date = datetime.date(2023, 1, 1) - result = resources.CustomJobNoVersion().run(reader=table_reader, run_date=date, spark=spark, metadata_manager=None) + result = resources.CustomJobNoVersion().run(reader=table_reader, run_date=date, spark=spark, config=None) assert type(result) is pyspark.sql.DataFrame assert "VERSION" not in result.columns diff --git a/tests/runner/conftest.py b/tests/runner/conftest.py index 44f0c09..4e527be 100644 --- a/tests/runner/conftest.py +++ b/tests/runner/conftest.py @@ -39,6 +39,4 @@ def spark(request): @pytest.fixture(scope="function") def basic_runner(spark): - return Runner( - spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", run_date="2023-03-31" - ) + return Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31") diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py index 0459411..85ddf95 100644 --- a/tests/runner/test_runner.py +++ b/tests/runner/test_runner.py @@ -19,7 +19,6 @@ from pyspark.sql import DataFrame from rialto.common.table_reader import DataReader -from rialto.jobs.configuration.config_holder import ConfigHolder from rialto.runner.runner import DateManager, Runner from rialto.runner.table import Table from tests.runner.runner_resources import ( @@ -38,8 +37,8 @@ def __init__(self, spark): def get_table( self, table: str, - info_date_from: Optional[datetime.date] = None, - info_date_to: Optional[datetime.date] = None, + date_from: Optional[datetime.date] = None, + date_to: Optional[datetime.date] = None, date_column: str = None, uppercase_columns: bool = False, ) -> DataFrame: @@ -53,7 +52,7 @@ def get_table( def get_latest( self, table: str, - until: Optional[datetime.date] = None, + date_until: Optional[datetime.date] = None, date_column: str = None, uppercase_columns: bool = False, ) -> DataFrame: @@ -84,43 +83,41 @@ def test_load_module(spark, basic_runner): def test_generate(spark, mocker, basic_runner): run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run") group = SimpleGroup() - basic_runner._generate(group, DateManager.str_to_date("2023-01-31")) + config = basic_runner.config.pipelines[0] + basic_runner._generate(group, DateManager.str_to_date("2023-01-31"), config) + run.assert_called_once_with( reader=basic_runner.reader, run_date=DateManager.str_to_date("2023-01-31"), spark=spark, - metadata_manager=basic_runner.metadata, - dependencies=None, + config=config, + metadata_manager=None, + feature_loader=None, ) def test_generate_w_dep(spark, mocker, basic_runner): run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run") group = SimpleGroup() - basic_runner._generate(group, DateManager.str_to_date("2023-01-31"), basic_runner.config.pipelines[2].dependencies) + basic_runner._generate(group, DateManager.str_to_date("2023-01-31"), basic_runner.config.pipelines[2]) run.assert_called_once_with( reader=basic_runner.reader, run_date=DateManager.str_to_date("2023-01-31"), spark=spark, - metadata_manager=basic_runner.metadata, - 
dependencies={ - "source1": basic_runner.config.pipelines[2].dependencies[0], - "source2": basic_runner.config.pipelines[2].dependencies[1], - }, + config=basic_runner.config.pipelines[2], + metadata_manager=None, + feature_loader=None, ) def test_init_dates(spark): - runner = Runner( - spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", run_date="2023-03-31" - ) + runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31") assert runner.date_from == DateManager.str_to_date("2023-01-31") assert runner.date_until == DateManager.str_to_date("2023-03-31") runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", date_from="2023-03-01", date_until="2023-03-31", ) @@ -130,7 +127,6 @@ def test_init_dates(spark): runner = Runner( spark, config_path="tests/runner/transformations/config2.yaml", - feature_metadata_schema="", run_date="2023-03-31", ) assert runner.date_from == DateManager.str_to_date("2023-02-24") @@ -141,7 +137,6 @@ def test_possible_run_dates(spark): runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", date_from="2023-03-01", date_until="2023-03-31", ) @@ -175,9 +170,7 @@ def test_completion(spark, mocker, basic_runner): def test_completion_rerun(spark, mocker, basic_runner): mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) - runner = Runner( - spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", run_date="2023-03-31" - ) + runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31") runner.reader = MockReader(spark) dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"] @@ -194,7 +187,6 @@ def test_check_dates_have_partition(spark, mocker): runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", date_from="2023-03-01", date_until="2023-03-31", ) @@ -212,7 +204,6 @@ def test_check_dates_have_partition_no_table(spark, mocker): runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", date_from="2023-03-01", date_until="2023-03-31", ) @@ -233,7 +224,6 @@ def test_check_dependencies(spark, mocker, r_date, expected): runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", date_from="2023-03-01", date_until="2023-03-31", ) @@ -248,7 +238,6 @@ def test_check_no_dependencies(spark, mocker): runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", date_from="2023-03-01", date_until="2023-03-31", ) @@ -263,7 +252,6 @@ def test_select_dates(spark, mocker): runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", date_from="2023-03-01", date_until="2023-03-31", ) @@ -286,7 +274,6 @@ def test_select_dates_all_done(spark, mocker): runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - feature_metadata_schema="", date_from="2023-03-02", date_until="2023-03-02", ) @@ -307,9 +294,7 @@ def test_op_selected(spark, mocker): mocker.patch("rialto.runner.tracker.Tracker.report") run = mocker.patch("rialto.runner.runner.Runner._run_pipeline") - runner = Runner( - spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", op="SimpleGroup" - ) + runner = Runner(spark, 
config_path="tests/runner/transformations/config.yaml", op="SimpleGroup") runner() run.called_once() @@ -319,42 +304,8 @@ def test_op_bad(spark, mocker): mocker.patch("rialto.runner.tracker.Tracker.report") mocker.patch("rialto.runner.runner.Runner._run_pipeline") - runner = Runner( - spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", op="BadOp" - ) + runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", op="BadOp") with pytest.raises(ValueError) as exception: runner() assert str(exception.value) == "Unknown operation selected: BadOp" - - -def test_custom_config(spark, mocker): - cc_spy = mocker.spy(ConfigHolder, "set_custom_config") - custom_config = {"cc": 42} - - _ = Runner(spark, config_path="tests/runner/transformations/config.yaml", custom_job_config=custom_config) - - cc_spy.assert_called_once_with(cc=42) - - -def test_feature_store_config(spark, mocker): - fs_spy = mocker.spy(ConfigHolder, "set_feature_store_config") - - _ = Runner( - spark, - config_path="tests/runner/transformations/config.yaml", - feature_store_schema="schema", - feature_metadata_schema="metadata", - ) - - fs_spy.assert_called_once_with("schema", "metadata") - - -def test_no_configs(spark, mocker): - cc_spy = mocker.spy(ConfigHolder, "set_custom_config") - fs_spy = mocker.spy(ConfigHolder, "set_feature_store_config") - - _ = Runner(spark, config_path="tests/runner/transformations/config.yaml") - - cc_spy.assert_not_called() - fs_spy.assert_not_called() diff --git a/tests/runner/transformations/config.yaml b/tests/runner/transformations/config.yaml index 2bfeaf1..0ed82ce 100644 --- a/tests/runner/transformations/config.yaml +++ b/tests/runner/transformations/config.yaml @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -general: - target_schema: catalog.schema - target_partition_column: "INFORMATION_DATE" +runner: watched_period_units: "months" watched_period_value: 2 - job: "run" # run/check mail: sender: test@testing.org smtp: server.test @@ -47,6 +44,9 @@ pipelines: units: "months" value: 3 date_col: "DATE" + target: + target_schema: catalog.schema + target_partition_column: "INFORMATION_DATE" - name: GroupNoDeps module: python_module: tests.runner.transformations @@ -80,3 +80,6 @@ pipelines: units: "months" value: 3 date_col: "batch" + target: + target_schema: catalog.schema + target_partition_column: "INFORMATION_DATE" diff --git a/tests/runner/transformations/config2.yaml b/tests/runner/transformations/config2.yaml index a91894b..f7b9604 100644 --- a/tests/runner/transformations/config2.yaml +++ b/tests/runner/transformations/config2.yaml @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-general: - target_schema: catalog.schema - target_partition_column: "INFORMATION_DATE" +runner: watched_period_units: "weeks" watched_period_value: 5 - job: "run" # run/check mail: sender: test@testing.org smtp: server.test @@ -43,3 +40,6 @@ pipelines: units: "months" value: 1 date_col: "DATE" + target: + target_schema: catalog.schema + target_partition_column: "INFORMATION_DATE" diff --git a/tests/runner/transformations/simple_group.py b/tests/runner/transformations/simple_group.py index fcda5c7..ec2311c 100644 --- a/tests/runner/transformations/simple_group.py +++ b/tests/runner/transformations/simple_group.py @@ -18,6 +18,7 @@ from pyspark.sql.types import StructType from rialto.common import TableReader +from rialto.loader import PysparkFeatureLoader from rialto.metadata import MetadataManager from rialto.runner import Transformation @@ -28,7 +29,8 @@ def run( reader: TableReader, run_date: datetime.date, spark: SparkSession = None, - metadata_manager: MetadataManager = None, - dependencies: Dict = None, + config: Dict = None, + metadata: MetadataManager = None, + feature_loader: PysparkFeatureLoader = None, ) -> DataFrame: return spark.createDataFrame([], StructType([])) From 8e18da2371b0a3b5c627a127d95e2ad885c86952 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Fri, 23 Aug 2024 19:20:36 +0200 Subject: [PATCH 02/12] documented changes --- CHANGELOG.md | 21 +++++++++++++++++++++ rialto/jobs/decorators/decorators.py | 2 +- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cfd48eb..b2a4b2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,27 @@ # Change Log All notable changes to this project will be documented in this file. +## 2.0.0 - 2024-mm-dd + #### Runner + - runner config now accepts environment variables + - restructured runner config + - added metadata and feature loader sections + - target moved to pipeline + - dependency date_col is mandatory + - custom extras config is available in each pipeline and will be passed as dictionary + - general section is renamed to runner + - transformation header changed + - added argument to skip dependency checking + #### Jobs + - config holder removed from jobs + - metadata_manager and feature_loader are now available arguments, depending on configuration + #### TableReader + - function signatures changed + - until -> date_until + - info_date_from -> date_from, info_date_to -> date_to + - date_column is now mandatory + - removed TableReaders ability to infer schema from partitions or properties + ## 1.3.0 - 2024-06-07 diff --git a/rialto/jobs/decorators/decorators.py b/rialto/jobs/decorators/decorators.py index 2949ad6..94b7409 100644 --- a/rialto/jobs/decorators/decorators.py +++ b/rialto/jobs/decorators/decorators.py @@ -84,7 +84,7 @@ def job(*args, custom_name=None, disable_version=False): :param *args: list of positional arguments. Empty in case custom_name or disable_version is specified. :param custom_name: str for custom job name. - :param disable_version: bool for disabling automobiling the VERSION column in the job's outputs. + :param disable_version: bool for disabling automatically filling the VERSION column in the job's outputs. :return: One more job wrapper for run function (if custom name or version override specified). Otherwise, generates Rialto Transformation Type and returns it for in-module registration. 
""" From 1699043f7b90e6619b56a1ff1e0df648b8c21658 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Fri, 23 Aug 2024 19:31:24 +0200 Subject: [PATCH 03/12] updated readme --- README.md | 45 +++++++++++++++++---------------------------- 1 file changed, 17 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index 14f4dae..cc7b01d 100644 --- a/README.md +++ b/README.md @@ -53,31 +53,22 @@ runner() A runner by default executes all the jobs provided in the configuration file, for all the viable execution dates according to the configuration file for which the job has not yet run successfully (i.e. the date partition doesn't exist on the storage) This behavior can be modified by various parameters and switches available. -* **feature_metadata_schema** - path to schema where feature metadata are read and stored, needed for [maker](#maker) jobs and jobs that utilized feature [loader](#loader) * **run_date** - date at which the runner is triggered (defaults to day of running) * **date_from** - starting date (defaults to rundate - config watch period) * **date_until** - end date (defaults to rundate) -* **feature_store_schema** - location of features, needed for jobs utilizing feature [loader](#loader) -* **custom_job_config** - dictionary with key-value pairs that will be accessible under the "config" variable in your rialto jobs * **rerun** - rerun all jobs even if they already succeeded in the past runs * **op** - run only selected operation / pipeline - +* **skip_dependencies** - ignore dependency checks and run all jobs Transformations are not included in the runner itself, it imports them dynamically according to the configuration, therefore it's necessary to have them locally installed. -A runner created table has will have automatically created **rialto_date_column** table property set according to target partition set in the configuration. 
- ### Configuration ```yaml -general: - target_schema: catalog.schema # schema where tables will be created, must exist - target_partition_column: INFORMATION_DATE # date to partition new tables on - source_date_column_property: rialto_date_column # name of the date property on source tables +runner: watched_period_units: "months" # unit of default run period watched_period_value: 2 # value of default run period - job: "run" # run for running the pipelines, check for only checking dependencies mail: to: # a list of email addresses - name@host.domain @@ -100,7 +91,7 @@ pipelines: # a list of pipelines to run dependencies: # list of dependent tables - table: catalog.schema.table1 name: "table1" # Optional table name, used to recall dependency details in transformation - date_col: generation_date # Optional date column name, takes priority + date_col: generation_date # Mandatory date column name interval: # mandatory availability interval, subtracted from scheduled day units: "days" value: 1 @@ -109,6 +100,18 @@ pipelines: # a list of pipelines to run interval: units: "months" value: 1 + target: + target_schema: catalog.schema # schema where tables will be created, must exist + target_partition_column: INFORMATION_DATE # date to partition new tables on + metadata_manager: # optional + metadata_schema: catalog.metadata # schema where metadata is stored + feature_loader: # optional + config_path: model_features_config.yaml # path to the feature loader configuration file + feature_schema: catalog.feature_tables # schema where feature tables are stored + metadata_schema: catalog.metadata # schema where metadata is stored + extras: #optional arguments processed as dictionary + some_value: 3 + some_other_value: giraffe - name: PipelineTable1 # will be written as pipeline_table1 module: @@ -302,6 +305,7 @@ We have a set of pre-defined dependencies: * **dependencies** returns a dictionary containing the job dependencies config * **table_reader** returns *TableReader* * **feature_loader** provides *PysparkFeatureLoader* +* **metadata_manager** provides *MetadataManager* Apart from that, each **datasource** also becomes a fully usable dependency. Note, that this means that datasources can also be dependent on other datasources - just beware of any circular dependencies! @@ -575,7 +579,7 @@ from datetime import datetime start = datetime.strptime("2020-01-01", "%Y-%m-%d").date() end = datetime.strptime("2024-01-01", "%Y-%m-%d").date() -df = reader.get_table(table="catalog.schema.table", date_from=start, date_to=end) +df = reader.get_table(table="catalog.schema.table", date_from=start, date_to=end, date_column="information_date") ``` usage of _get_latest_: @@ -595,21 +599,6 @@ For full information on parameters and their optionality see technical documenta _TableReader_ needs an active spark session and an information which column is the **date column**. There are three options how to pass that information on. -In order of priority from highest: -* Explicit _date_column_ parameter in _get_table_ and _get_latest_ -```python -reader.get_latest(table="catalog.schema.table", date_column="information_date") -``` -* Inferred from delta metadata, triggered by init parameter, only works on delta tables (e.g. 
doesn't work on views) -```python -reader = TableReader(spark=spark_instance, infer_partition=True) -reader.get_latest(table="catalog.schema.table") -``` -* A custom sql property defined on the table containing the date column name, defaults to _rialto_date_column_ -```python -reader = TableReader(spark=spark_instance, date_property="rialto_date_column") -reader.get_latest(table="catalog.schema.table") -``` # 3. Contributing Contributing: From a185b20d39841477d23dceefaa1257e5deb4f351 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Mon, 26 Aug 2024 13:45:39 +0200 Subject: [PATCH 04/12] loader simplified, runner cleaned --- CHANGELOG.md | 8 +- README.md | 30 ++----- rialto/common/__init__.py | 2 +- rialto/common/utils.py | 34 +------ rialto/jobs/decorators/decorators.py | 2 +- rialto/jobs/decorators/resolver.py | 2 +- rialto/loader/__init__.py | 1 - rialto/loader/data_loader.py | 45 ---------- rialto/loader/interfaces.py | 20 +---- rialto/loader/pyspark_feature_loader.py | 43 ++++++--- rialto/runner/runner.py | 115 +++--------------------- rialto/runner/transformation.py | 4 +- rialto/runner/utils.py | 74 +++++++++++++++ tests/jobs/test_job_base.py | 4 +- tests/loader/pyspark/dummy_loaders.py | 24 ----- tests/loader/pyspark/test_from_cfg.py | 11 ++- tests/runner/test_runner.py | 38 +------- 17 files changed, 148 insertions(+), 309 deletions(-) delete mode 100644 rialto/loader/data_loader.py create mode 100644 rialto/runner/utils.py delete mode 100644 tests/loader/pyspark/dummy_loaders.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b2a4b2e..8c25f0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,12 +7,13 @@ All notable changes to this project will be documented in this file. - restructured runner config - added metadata and feature loader sections - target moved to pipeline - - dependency date_col is mandatory - - custom extras config is available in each pipeline and will be passed as dictionary + - dependency date_col is now mandatory + - custom extras config is available in each pipeline and will be passed as dictionary available under pipeline_config.extras - general section is renamed to runner - transformation header changed - added argument to skip dependency checking #### Jobs + - jobs are now the main way to create all pipelines - config holder removed from jobs - metadata_manager and feature_loader are now available arguments, depending on configuration #### TableReader @@ -21,7 +22,8 @@ All notable changes to this project will be documented in this file. - info_date_from -> date_from, info_date_to -> date_to - date_column is now mandatory - removed TableReaders ability to infer schema from partitions or properties - + #### Loader + - removed DataLoader class, now only PysparkFeatureLoader is needed with additional parameters ## 1.3.0 - 2024-06-07 diff --git a/README.md b/README.md index cc7b01d..56ccaea 100644 --- a/README.md +++ b/README.md @@ -423,19 +423,6 @@ This module is used to load features from feature store into your models and scr Two public classes are exposed form this module. **DatabricksLoader**(DataLoader), **PysparkFeatureLoader**(FeatureLoaderInterface). -### DatabricksLoader -This is a support class for feature loader and provides the data reading capability from the feature store. - -This class needs to be instantiated with an active spark session and a path to the feature store schema (in the format of "catalog_name.schema_name"). 
-Optionally a date_column information can be passed, otherwise it defaults to use INFORMATION_DATE -```python -from rialto.loader import DatabricksLoader - -data_loader = DatabricksLoader(spark= spark_instance, schema= "catalog.schema", date_column= "INFORMATION_DATE") -``` - -This class provides one method, read_group(...), which returns a whole feature group for selected date. This is mostly used inside feature loader. - ### PysparkFeatureLoader This class needs to be instantiated with an active spark session, data loader and a path to the metadata schema (in the format of "catalog_name.schema_name"). @@ -443,17 +430,16 @@ This class needs to be instantiated with an active spark session, data loader an ```python from rialto.loader import PysparkFeatureLoader -feature_loader = PysparkFeatureLoader(spark= spark_instance, data_loader= data_loader_instance, metadata_schema= "catalog.schema") +feature_loader = PysparkFeatureLoader(spark= spark_instance, feature_schema="catalog.schema", metadata_schema= "catalog.schema2", date_column="information_date") ``` #### Single feature ```python -from rialto.loader import DatabricksLoader, PysparkFeatureLoader +from rialto.loader import PysparkFeatureLoader from datetime import datetime -data_loader = DatabricksLoader(spark, "feature_catalog.feature_schema") -feature_loader = PysparkFeatureLoader(spark, data_loader, "metadata_catalog.metadata_schema") +feature_loader = PysparkFeatureLoader(spark, "feature_catalog.feature_schema", "metadata_catalog.metadata_schema") my_date = datetime.strptime("2020-01-01", "%Y-%m-%d").date() feature = feature_loader.get_feature(group_name="CustomerFeatures", feature_name="AGE", information_date=my_date) @@ -464,11 +450,10 @@ metadata = feature_loader.get_feature_metadata(group_name="CustomerFeatures", fe This method of data access is only recommended for experimentation, as the group schema can evolve over time. ```python -from rialto.loader import DatabricksLoader, PysparkFeatureLoader +from rialto.loader import PysparkFeatureLoader from datetime import datetime -data_loader = DatabricksLoader(spark, "feature_catalog.feature_schema") -feature_loader = PysparkFeatureLoader(spark, data_loader, "metadata_catalog.metadata_schema") +feature_loader = PysparkFeatureLoader(spark, "feature_catalog.feature_schema", "metadata_catalog.metadata_schema") my_date = datetime.strptime("2020-01-01", "%Y-%m-%d").date() features = feature_loader.get_group(group_name="CustomerFeatures", information_date=my_date) @@ -478,11 +463,10 @@ metadata = feature_loader.get_group_metadata(group_name="CustomerFeatures") #### Configuration ```python -from rialto.loader import DatabricksLoader, PysparkFeatureLoader +from rialto.loader import PysparkFeatureLoader from datetime import datetime -data_loader = DatabricksLoader(spark, "feature_catalog.feature_schema") -feature_loader = PysparkFeatureLoader(spark, data_loader, "metadata_catalog.metadata_schema") +feature_loader = PysparkFeatureLoader(spark, "feature_catalog.feature_schema", "metadata_catalog.metadata_schema") my_date = datetime.strptime("2020-01-01", "%Y-%m-%d").date() features = feature_loader.get_features_from_cfg(path="local/configuration/file.yaml", information_date=my_date) diff --git a/rialto/common/__init__.py b/rialto/common/__init__.py index 93e8922..1bd5055 100644 --- a/rialto/common/__init__.py +++ b/rialto/common/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from rialto.common.table_reader import TableReader +from rialto.common.table_reader import DataReader, TableReader diff --git a/rialto/common/utils.py b/rialto/common/utils.py index 6c2952c..b2e19b4 100644 --- a/rialto/common/utils.py +++ b/rialto/common/utils.py @@ -38,43 +38,11 @@ def load_yaml(path: str) -> Any: return yaml.load(stream, EnvLoader) -def get_date_col_property(spark, table: str, property: str) -> str: - """ - Retrieve a data column name from a given table property - - :param spark: spark session - :param table: path to table - :param property: name of the property - :return: data column name - """ - props = spark.sql(f"show tblproperties {table}") - date_col = props.filter(F.col("key") == property).select("value").collect() - if len(date_col): - return date_col[0].value - else: - raise RuntimeError(f"Table {table} has no property {property}.") - - -def get_delta_partition(spark, table: str) -> str: - """ - Select first partition column of the delta table - - :param table: full table name - :return: partition column name - """ - columns = spark.catalog.listColumns(table) - partition_columns = list(filter(lambda c: c.isPartition, columns)) - if len(partition_columns): - return partition_columns[0].name - else: - raise RuntimeError(f"Delta table has no partitions: {table}.") - - def cast_decimals_to_floats(df: DataFrame) -> DataFrame: """ Find all decimal types in the table and cast them to floats. Fixes errors in .toPandas() conversions. - :param df: pyspark DataFrame + :param df: input df :return: pyspark DataFrame with fixed types """ decimal_cols = [col_name for col_name, data_type in df.dtypes if "decimal" in data_type] diff --git a/rialto/jobs/decorators/decorators.py b/rialto/jobs/decorators/decorators.py index 94b7409..217b436 100644 --- a/rialto/jobs/decorators/decorators.py +++ b/rialto/jobs/decorators/decorators.py @@ -93,7 +93,7 @@ def job(*args, custom_name=None, disable_version=False): module = _get_module(stack) version = _get_version(module) - # Use case where it's just raw @f. Otherwise we get [] here. + # Use case where it's just raw @f. Otherwise, we get [] here. if len(args) == 1 and callable(args[0]): f = args[0] return _generate_rialto_job(callable=f, module=module, class_name=f.__name__, version=version) diff --git a/rialto/jobs/decorators/resolver.py b/rialto/jobs/decorators/resolver.py index f13f0eb..26856d1 100644 --- a/rialto/jobs/decorators/resolver.py +++ b/rialto/jobs/decorators/resolver.py @@ -101,7 +101,7 @@ def cache_clear(cls) -> None: """ Clear resolver cache. - The resolve mehtod caches its results to avoid duplication of resolutions. + The resolve method caches its results to avoid duplication of resolutions. However, in case we re-register some callables, we need to clear cache in order to ensure re-execution of all resolutions. diff --git a/rialto/loader/__init__.py b/rialto/loader/__init__.py index 7adc52d..7e1e936 100644 --- a/rialto/loader/__init__.py +++ b/rialto/loader/__init__.py @@ -12,5 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from rialto.loader.data_loader import DatabricksLoader from rialto.loader.pyspark_feature_loader import PysparkFeatureLoader diff --git a/rialto/loader/data_loader.py b/rialto/loader/data_loader.py deleted file mode 100644 index dc13572..0000000 --- a/rialto/loader/data_loader.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2022 ABSA Group Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -__all__ = ["DatabricksLoader"] - -from datetime import date - -from pyspark.sql import DataFrame, SparkSession - -from rialto.common.table_reader import TableReader -from rialto.loader.interfaces import DataLoader - - -class DatabricksLoader(DataLoader): - """Implementation of DataLoader using TableReader to access feature tables""" - - def __init__(self, spark: SparkSession, schema: str, date_column: str = "INFORMATION_DATE"): - super().__init__() - - self.reader = TableReader(spark) - self.schema = schema - self.date_col = date_column - - def read_group(self, group: str, information_date: date) -> DataFrame: - """ - Read a feature group by getting the latest partition by date - - :param group: group name - :param information_date: partition date - :return: dataframe - """ - return self.reader.get_latest( - f"{self.schema}.{group}", date_until=information_date, date_column=self.date_col, uppercase_columns=True - ) diff --git a/rialto/loader/interfaces.py b/rialto/loader/interfaces.py index dad08e6..9089f40 100644 --- a/rialto/loader/interfaces.py +++ b/rialto/loader/interfaces.py @@ -12,31 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["DataLoader", "FeatureLoaderInterface"] +__all__ = ["FeatureLoaderInterface"] import abc from datetime import date from typing import Dict -class DataLoader(metaclass=abc.ABCMeta): - """ - An interface to read feature groups from storage - - Requires read_group function. 
- """ - - @abc.abstractmethod - def read_group(self, group: str, information_date: date): - """ - Read one feature group - - :param group: Group name - :param information_date: date - """ - raise NotImplementedError - - class FeatureLoaderInterface(metaclass=abc.ABCMeta): """ A definition of feature loading interface diff --git a/rialto/loader/pyspark_feature_loader.py b/rialto/loader/pyspark_feature_loader.py index d0eef20..7ee78fc 100644 --- a/rialto/loader/pyspark_feature_loader.py +++ b/rialto/loader/pyspark_feature_loader.py @@ -20,9 +20,9 @@ from pyspark.sql import DataFrame, SparkSession +from rialto.common import TableReader from rialto.common.utils import cast_decimals_to_floats from rialto.loader.config_loader import FeatureConfig, GroupConfig, get_feature_config -from rialto.loader.data_loader import DataLoader from rialto.loader.interfaces import FeatureLoaderInterface from rialto.metadata.metadata_manager import ( FeatureMetadata, @@ -34,7 +34,13 @@ class PysparkFeatureLoader(FeatureLoaderInterface): """Implementation of feature loader for pyspark environment""" - def __init__(self, spark: SparkSession, data_loader: DataLoader, metadata_schema: str): + def __init__( + self, + spark: SparkSession, + feature_schema: str, + metadata_schema: str, + date_column: str = "INFORMATION_DATE", + ): """ Init @@ -44,11 +50,28 @@ def __init__(self, spark: SparkSession, data_loader: DataLoader, metadata_schema """ super().__init__() self.spark = spark - self.data_loader = data_loader + self.reader = TableReader(spark) + self.feature_schema = feature_schema + self.date_col = date_column self.metadata = MetadataManager(spark, metadata_schema) KeyMap = namedtuple("KeyMap", ["df", "key"]) + def read_group(self, group: str, information_date: date) -> DataFrame: + """ + Read a feature group by getting the latest partition by date + + :param group: group name + :param information_date: partition date + :return: dataframe + """ + return self.reader.get_latest( + f"{self.feature_schema}.{group}", + date_until=information_date, + date_column=self.date_col, + uppercase_columns=True, + ) + def get_feature(self, group_name: str, feature_name: str, information_date: date) -> DataFrame: """ Get single feature @@ -60,9 +83,7 @@ def get_feature(self, group_name: str, feature_name: str, information_date: date """ print("This function is untested, use with caution!") key = self.get_group_metadata(group_name).key - return self.data_loader.read_group(self.get_group_fs_name(group_name), information_date).select( - *key, feature_name - ) + return self.read_group(self.get_group_fs_name(group_name), information_date).select(*key, feature_name) def get_feature_metadata(self, group_name: str, feature_name: str) -> FeatureMetadata: """ @@ -83,7 +104,7 @@ def get_group(self, group_name: str, information_date: date) -> DataFrame: :return: A dataframe containing feature group key """ print("This function is untested, use with caution!") - return self.data_loader.read_group(self.get_group_fs_name(group_name), information_date) + return self.read_group(self.get_group_fs_name(group_name), information_date) def get_group_metadata(self, group_name: str) -> GroupMetadata: """ @@ -144,7 +165,7 @@ def _get_keymaps(self, config: FeatureConfig, information_date: date) -> List[Ke """ key_maps = [] for mapping in config.maps: - df = self.data_loader.read_group(self.get_group_fs_name(mapping), information_date).drop("INFORMATION_DATE") + df = self.read_group(self.get_group_fs_name(mapping), information_date).drop("INFORMATION_DATE") 
key = self.metadata.get_group(mapping).key key_maps.append(PysparkFeatureLoader.KeyMap(df, key)) return key_maps @@ -174,9 +195,7 @@ def get_features_from_cfg(self, path: str, information_date: date) -> DataFrame: """ config = get_feature_config(path) # 1 select keys from base - base = self.data_loader.read_group(self.get_group_fs_name(config.base.group), information_date).select( - config.base.keys - ) + base = self.read_group(self.get_group_fs_name(config.base.group), information_date).select(config.base.keys) # 2 join maps onto base (resolve keys) if config.maps: key_maps = self._get_keymaps(config, information_date) @@ -184,7 +203,7 @@ def get_features_from_cfg(self, path: str, information_date: date) -> DataFrame: # 3 read, select and join other tables for group_cfg in config.selection: - df = self.data_loader.read_group(self.get_group_fs_name(group_cfg.group), information_date) + df = self.read_group(self.get_group_fs_name(group_cfg.group), information_date) base = self._add_feature_group(base, df, group_cfg) # 4 fix dtypes for pandas conversion diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index e3efe01..3fc13b4 100644 --- a/rialto/runner/runner.py +++ b/rialto/runner/runner.py @@ -16,22 +16,15 @@ import datetime from datetime import date -from importlib import import_module from typing import List, Tuple import pyspark.sql.functions as F from loguru import logger from pyspark.sql import DataFrame, SparkSession +import rialto.runner.utils as utils from rialto.common import TableReader -from rialto.loader import DatabricksLoader, PysparkFeatureLoader -from rialto.metadata import MetadataManager -from rialto.runner.config_loader import ( - ModuleConfig, - PipelineConfig, - ScheduleConfig, - get_pipelines_config, -) +from rialto.runner.config_loader import PipelineConfig, get_pipelines_config from rialto.runner.date_manager import DateManager from rialto.runner.table import Table from rialto.runner.tracker import Record, Tracker @@ -84,39 +77,16 @@ def __init__( raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}") logger.info(f"Running period from {self.date_from} until {self.date_until}") - def _load_module(self, cfg: ModuleConfig) -> Transformation: + def _execute(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame: """ - Load feature group - - :param cfg: Feature configuration - :return: Transformation object - """ - module = import_module(cfg.python_module) - class_obj = getattr(module, cfg.python_class) - return class_obj() - - def _generate(self, instance: Transformation, run_date: date, pipeline: PipelineConfig) -> DataFrame: - """ - Run feature group + Run the job :param instance: Instance of Transformation :param run_date: date to run for :param pipeline: pipeline configuration :return: Dataframe """ - if pipeline.metadata_manager is not None: - metadata_manager = MetadataManager(self.spark, pipeline.metadata_manager.metadata_schema) - else: - metadata_manager = None - - if pipeline.feature_loader is not None: - feature_loader = PysparkFeatureLoader( - self.spark, - DatabricksLoader(self.spark, schema=pipeline.feature_loader.feature_schema), - metadata_schema=pipeline.feature_loader.metadata_schema, - ) - else: - feature_loader = None + metadata_manager, feature_loader = utils.init_tools(self.spark, pipeline) df = instance.run( spark=self.spark, @@ -130,15 +100,6 @@ def _generate(self, instance: Transformation, run_date: date, pipeline: Pipeline return df - def _table_exists(self, table: str) 
-> bool: - """ - Check table exists in spark catalog - - :param table: full table path - :return: bool - """ - return self.spark.catalog.tableExists(table) - def _write(self, df: DataFrame, info_date: date, table: Table) -> None: """ Write dataframe to storage @@ -152,35 +113,6 @@ def _write(self, df: DataFrame, info_date: date, table: Table) -> None: df.write.partitionBy(table.partition).mode("overwrite").saveAsTable(table.get_table_path()) logger.info(f"Results writen to {table.get_table_path()}") - def _delta_partition(self, table: str) -> str: - """ - Select first partition column, should be only one - - :param table: full table name - :return: partition column name - """ - columns = self.spark.catalog.listColumns(table) - partition_columns = list(filter(lambda c: c.isPartition, columns)) - if len(partition_columns): - return partition_columns[0].name - else: - raise RuntimeError(f"Delta table has no partitions: {table}.") - - def _get_partitions(self, table: Table) -> List[date]: - """ - Get partition values - - :param table: Table object - :return: List of partition values - """ - rows = ( - self.reader.get_table(table.get_table_path(), date_column=table.partition) - .select(table.partition) - .distinct() - .collect() - ) - return [r[table.partition] for r in rows] - def check_dates_have_partition(self, table: Table, dates: List[date]) -> List[bool]: """ For given list of dates, check if there is a matching partition for each @@ -189,8 +121,8 @@ def check_dates_have_partition(self, table: Table, dates: List[date]) -> List[bo :param dates: list of dates to check :return: list of bool """ - if self._table_exists(table.get_table_path()): - partitions = self._get_partitions(table) + if utils.table_exists(table.get_table_path()): + partitions = utils.get_partitions(self.reader, table) return [(date in partitions) for date in dates] else: logger.info(f"Table {table.get_table_path()} doesn't exist!") @@ -230,25 +162,6 @@ def check_dependencies(self, pipeline: PipelineConfig, run_date: date) -> bool: return True - def get_possible_run_dates(self, schedule: ScheduleConfig) -> List[date]: - """ - List possible run dates according to parameters and config - - :param schedule: schedule config - :return: List of dates - """ - return DateManager.run_dates(self.date_from, self.date_until, schedule) - - def get_info_dates(self, schedule: ScheduleConfig, run_dates: List[date]) -> List[date]: - """ - Transform given dates into info dates according to the config - - :param schedule: schedule config - :param run_dates: date list - :return: list of modified dates - """ - return [DateManager.to_info_date(x, schedule) for x in run_dates] - def _get_completion(self, target: Table, info_dates: List[date]) -> List[bool]: """ Check if model has run for given dates @@ -270,8 +183,8 @@ def _select_run_dates(self, pipeline: PipelineConfig, table: Table) -> Tuple[Lis :param table: table path :return: list of run dates and list of info dates """ - possible_run_dates = self.get_possible_run_dates(pipeline.schedule) - possible_info_dates = self.get_info_dates(pipeline.schedule, possible_run_dates) + possible_run_dates = DateManager.run_dates(self.date_from, self.date_until, pipeline.schedule) + possible_info_dates = [DateManager.to_info_date(x, pipeline.schedule) for x in possible_run_dates] current_state = self._get_completion(table, possible_info_dates) selection = [ @@ -300,8 +213,8 @@ def _run_one_date(self, pipeline: PipelineConfig, run_date: date, info_date: dat if self.skip_dependencies or 
self.check_dependencies(pipeline, run_date): logger.info(f"Running {pipeline.name} for {run_date}") - feature_group = self._load_module(pipeline.module) - df = self._generate(feature_group, run_date, pipeline) + feature_group = utils.load_module(pipeline.module) + df = self._execute(feature_group, run_date, pipeline) records = df.count() if records > 0: self._write(df, info_date, target) @@ -349,8 +262,8 @@ def _run_pipeline(self, pipeline: PipelineConfig): ) ) except Exception as error: - print(f"An exception occurred in pipeline {pipeline.name}") - print(error) + logger.error(f"An exception occurred in pipeline {pipeline.name}") + logger.error(error) self.tracker.add( Record( job=pipeline.name, @@ -364,7 +277,7 @@ def _run_pipeline(self, pipeline: PipelineConfig): ) ) except KeyboardInterrupt: - print(f"Pipeline {pipeline.name} interrupted") + logger.error(f"Pipeline {pipeline.name} interrupted") self.tracker.add( Record( job=pipeline.name, diff --git a/rialto/runner/transformation.py b/rialto/runner/transformation.py index 7b5eaa8..5b6f2eb 100644 --- a/rialto/runner/transformation.py +++ b/rialto/runner/transformation.py @@ -19,7 +19,7 @@ from pyspark.sql import DataFrame, SparkSession -from rialto.common import TableReader +from rialto.common import DataReader from rialto.loader import PysparkFeatureLoader from rialto.metadata import MetadataManager from rialto.runner.config_loader import PipelineConfig @@ -31,7 +31,7 @@ class Transformation(metaclass=abc.ABCMeta): @abc.abstractmethod def run( self, - reader: TableReader, + reader: DataReader, run_date: datetime.date, spark: SparkSession = None, config: PipelineConfig = None, diff --git a/rialto/runner/utils.py b/rialto/runner/utils.py new file mode 100644 index 0000000..b74ec1b --- /dev/null +++ b/rialto/runner/utils.py @@ -0,0 +1,74 @@ +from datetime import date +from importlib import import_module +from typing import List, Tuple + +from pyspark.sql import SparkSession + +from rialto.common import DataReader +from rialto.loader import PysparkFeatureLoader +from rialto.metadata import MetadataManager +from rialto.runner.config_loader import ModuleConfig, PipelineConfig +from rialto.runner.table import Table +from rialto.runner.transformation import Transformation + + +def load_module(cfg: ModuleConfig) -> Transformation: + """ + Load feature group + + :param cfg: Feature configuration + :return: Transformation object + """ + module = import_module(cfg.python_module) + class_obj = getattr(module, cfg.python_class) + return class_obj() + + +def table_exists(spark: SparkSession, table: str) -> bool: + """ + Check table exists in spark catalog + + :param table: full table path + :return: bool + """ + return spark.catalog.tableExists(table) + + +def get_partitions(reader: DataReader, table: Table) -> List[date]: + """ + Get partition values + + :param table: Table object + :return: List of partition values + """ + rows = ( + reader.get_table(table.get_table_path(), date_column=table.partition) + .select(table.partition) + .distinct() + .collect() + ) + return [r[table.partition] for r in rows] + + +def init_tools(spark: SparkSession, pipeline: PipelineConfig) -> Tuple[MetadataManager, PysparkFeatureLoader]: + """ + Initialize metadata manager and feature loader + + :param spark: Spark session + :param pipeline: Pipeline configuration + :return: MetadataManager and PysparkFeatureLoader + """ + if pipeline.metadata_manager is not None: + metadata_manager = MetadataManager(spark, pipeline.metadata_manager.metadata_schema) + else: + 
metadata_manager = None + + if pipeline.feature_loader is not None: + feature_loader = PysparkFeatureLoader( + spark, + feature_schema=pipeline.feature_loader.feature_schema, + metadata_schema=pipeline.feature_loader.metadata_schema, + ) + else: + feature_loader = None + return metadata_manager, feature_loader diff --git a/tests/jobs/test_job_base.py b/tests/jobs/test_job_base.py index fa8f19c..55fced1 100644 --- a/tests/jobs/test_job_base.py +++ b/tests/jobs/test_job_base.py @@ -20,7 +20,7 @@ import tests.jobs.resources as resources from rialto.jobs.decorators.resolver import Resolver -from rialto.loader import DatabricksLoader, PysparkFeatureLoader +from rialto.loader import PysparkFeatureLoader def test_setup_except_feature_loader(spark): @@ -39,7 +39,7 @@ def test_setup_except_feature_loader(spark): def test_setup_feature_loader(spark): table_reader = MagicMock() date = datetime.date(2023, 1, 1) - feature_loader = PysparkFeatureLoader(spark, DatabricksLoader(spark, "", ""), "") + feature_loader = PysparkFeatureLoader(spark, "", "", "") resources.CustomJobNoReturnVal().run( reader=table_reader, run_date=date, spark=spark, config=None, feature_loader=feature_loader diff --git a/tests/loader/pyspark/dummy_loaders.py b/tests/loader/pyspark/dummy_loaders.py deleted file mode 100644 index a2b0cb8..0000000 --- a/tests/loader/pyspark/dummy_loaders.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright 2022 ABSA Group Limited -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from datetime import date - -from rialto.loader.data_loader import DataLoader - - -class DummyDataLoader(DataLoader): - def __init__(self): - super().__init__() - - def read_group(self, group: str, information_date: date): - return None diff --git a/tests/loader/pyspark/test_from_cfg.py b/tests/loader/pyspark/test_from_cfg.py index 3ad653e..dd2049f 100644 --- a/tests/loader/pyspark/test_from_cfg.py +++ b/tests/loader/pyspark/test_from_cfg.py @@ -21,7 +21,6 @@ from rialto.loader.config_loader import get_feature_config from rialto.loader.pyspark_feature_loader import PysparkFeatureLoader from tests.loader.pyspark.dataframe_builder import dataframe_builder as dfb -from tests.loader.pyspark.dummy_loaders import DummyDataLoader @pytest.fixture(scope="session") @@ -45,7 +44,7 @@ def spark(request): @pytest.fixture(scope="session") def loader(spark): - return PysparkFeatureLoader(spark, DummyDataLoader(), MagicMock()) + return PysparkFeatureLoader(spark, MagicMock(), MagicMock()) VALID_LIST = [(["a"], ["a"]), (["a"], ["a", "b", "c"]), (["c", "a"], ["a", "b", "c"])] @@ -90,7 +89,7 @@ def __call__(self, *args, **kwargs): metadata = MagicMock() monkeypatch.setattr(metadata, "get_group", GroupMd()) - loader = PysparkFeatureLoader(spark, DummyDataLoader(), "") + loader = PysparkFeatureLoader(spark, "", "") loader.metadata = metadata base = dfb(spark, data=r.base_frame_data, columns=r.base_frame_columns) @@ -105,7 +104,7 @@ def __call__(self, *args, **kwargs): def test_get_group_metadata(spark, mocker): mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_group", return_value=7) - loader = PysparkFeatureLoader(spark, DummyDataLoader(), "") + loader = PysparkFeatureLoader(spark, "", "") ret_val = loader.get_group_metadata("group_name") assert ret_val == 7 @@ -115,7 +114,7 @@ def test_get_group_metadata(spark, mocker): def test_get_feature_metadata(spark, mocker): mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_feature", return_value=8) - loader = PysparkFeatureLoader(spark, DummyDataLoader(), "") + loader = PysparkFeatureLoader(spark, "", "") ret_val = loader.get_feature_metadata("group_name", "feature") assert ret_val == 8 @@ -129,7 +128,7 @@ def test_get_metadata_from_cfg(spark, mocker): ) mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_group", side_effect=lambda g: {"B": 10}[g]) - loader = PysparkFeatureLoader(spark, DummyDataLoader(), "") + loader = PysparkFeatureLoader(spark, "", "") metadata = loader.get_metadata_from_cfg("tests/loader/pyspark/example_cfg.yaml") assert metadata["B_F1"] == 1 diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py index 85ddf95..2171c7b 100644 --- a/tests/runner/test_runner.py +++ b/tests/runner/test_runner.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from collections import namedtuple from datetime import datetime from typing import Optional @@ -61,20 +60,10 @@ def get_latest( def test_table_exists(spark, mocker, basic_runner): mock = mocker.patch("pyspark.sql.Catalog.tableExists", return_value=True) - basic_runner._table_exists("abc") + basic_runner.table_exists("abc") mock.assert_called_once_with("abc") -def test_infer_column(spark, mocker, basic_runner): - column = namedtuple("catalog", ["name", "isPartition"]) - catalog = [column("a", True), column("b", False), column("c", False)] - - mock = mocker.patch("pyspark.sql.Catalog.listColumns", return_value=catalog) - partition = basic_runner._delta_partition("aaa") - assert partition == "a" - mock.assert_called_once_with("aaa") - - def test_load_module(spark, basic_runner): module = basic_runner._load_module(basic_runner.config.pipelines[0].module) assert isinstance(module, SimpleGroup) @@ -84,7 +73,7 @@ def test_generate(spark, mocker, basic_runner): run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run") group = SimpleGroup() config = basic_runner.config.pipelines[0] - basic_runner._generate(group, DateManager.str_to_date("2023-01-31"), config) + basic_runner._execute(group, DateManager.str_to_date("2023-01-31"), config) run.assert_called_once_with( reader=basic_runner.reader, @@ -99,7 +88,7 @@ def test_generate(spark, mocker, basic_runner): def test_generate_w_dep(spark, mocker, basic_runner): run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run") group = SimpleGroup() - basic_runner._generate(group, DateManager.str_to_date("2023-01-31"), basic_runner.config.pipelines[2]) + basic_runner._execute(group, DateManager.str_to_date("2023-01-31"), basic_runner.config.pipelines[2]) run.assert_called_once_with( reader=basic_runner.reader, run_date=DateManager.str_to_date("2023-01-31"), @@ -133,27 +122,6 @@ def test_init_dates(spark): assert runner.date_until == DateManager.str_to_date("2023-03-31") -def test_possible_run_dates(spark): - runner = Runner( - spark, - config_path="tests/runner/transformations/config.yaml", - date_from="2023-03-01", - date_until="2023-03-31", - ) - - dates = runner.get_possible_run_dates(runner.config.pipelines[0].schedule) - expected = ["2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"] - assert dates == [DateManager.str_to_date(d) for d in expected] - - -def test_info_dates(spark, basic_runner): - run = ["2023-02-05", "2023-02-12", "2023-02-19", "2023-02-26", "2023-03-05"] - run = [DateManager.str_to_date(d) for d in run] - info = basic_runner.get_info_dates(basic_runner.config.pipelines[0].schedule, run) - expected = ["2023-02-02", "2023-02-09", "2023-02-16", "2023-02-23", "2023-03-02"] - assert info == [DateManager.str_to_date(d) for d in expected] - - def test_completion(spark, mocker, basic_runner): mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) From 20b9cb2d1cd2e3fbede4c4b7c82b12444419ea06 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Mon, 26 Aug 2024 13:54:11 +0200 Subject: [PATCH 05/12] fixed impacted tests --- tests/runner/test_runner.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py index 2171c7b..d3209bf 100644 --- a/tests/runner/test_runner.py +++ b/tests/runner/test_runner.py @@ -17,6 +17,7 @@ import pytest from pyspark.sql import DataFrame +import rialto.runner.utils as utils from rialto.common.table_reader import DataReader from rialto.runner.runner import 
DateManager, Runner from rialto.runner.table import Table @@ -58,14 +59,14 @@ def get_latest( pass -def test_table_exists(spark, mocker, basic_runner): +def test_table_exists(spark, mocker): mock = mocker.patch("pyspark.sql.Catalog.tableExists", return_value=True) - basic_runner.table_exists("abc") + utils.table_exists(spark, "abc") mock.assert_called_once_with("abc") def test_load_module(spark, basic_runner): - module = basic_runner._load_module(basic_runner.config.pipelines[0].module) + module = utils.load_module(basic_runner.config.pipelines[0].module) assert isinstance(module, SimpleGroup) @@ -123,7 +124,7 @@ def test_init_dates(spark): def test_completion(spark, mocker, basic_runner): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) + mocker.patch("rialto.runner.utils.table_exists", return_value=True) basic_runner.reader = MockReader(spark) @@ -136,7 +137,7 @@ def test_completion(spark, mocker, basic_runner): def test_completion_rerun(spark, mocker, basic_runner): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) + mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31") runner.reader = MockReader(spark) @@ -150,7 +151,7 @@ def test_completion_rerun(spark, mocker, basic_runner): def test_check_dates_have_partition(spark, mocker): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) + mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) runner = Runner( spark, @@ -167,7 +168,7 @@ def test_check_dates_have_partition(spark, mocker): def test_check_dates_have_partition_no_table(spark, mocker): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=False) + mocker.patch("rialto.runner.runner.utils.table_exists", return_value=False) runner = Runner( spark, @@ -201,7 +202,7 @@ def test_check_dependencies(spark, mocker, r_date, expected): def test_check_no_dependencies(spark, mocker): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) + mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) runner = Runner( spark, @@ -215,7 +216,7 @@ def test_check_no_dependencies(spark, mocker): def test_select_dates(spark, mocker): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) + mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) runner = Runner( spark, @@ -237,7 +238,7 @@ def test_select_dates(spark, mocker): def test_select_dates_all_done(spark, mocker): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) + mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) runner = Runner( spark, From a0d5a02020d9e71e003a2c8f48794c2ceeee16f1 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Mon, 26 Aug 2024 13:55:20 +0200 Subject: [PATCH 06/12] fixed impacted tests --- tests/runner/test_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py index d3209bf..9f16ea0 100644 --- a/tests/runner/test_runner.py +++ b/tests/runner/test_runner.py @@ -188,7 +188,7 @@ def test_check_dates_have_partition_no_table(spark, mocker): [("2023-02-26", False), ("2023-03-05", True)], ) def test_check_dependencies(spark, mocker, r_date, expected): - mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True) + 
mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True) runner = Runner( spark, From 5401307eeb9cddee97f72d7e0b2a1b809f7025b1 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Mon, 26 Aug 2024 16:00:50 +0200 Subject: [PATCH 07/12] missing parameter --- rialto/runner/runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index 3fc13b4..9f9b276 100644 --- a/rialto/runner/runner.py +++ b/rialto/runner/runner.py @@ -121,7 +121,7 @@ def check_dates_have_partition(self, table: Table, dates: List[date]) -> List[bo :param dates: list of dates to check :return: list of bool """ - if utils.table_exists(table.get_table_path()): + if utils.table_exists(self.spark, table.get_table_path()): partitions = utils.get_partitions(self.reader, table) return [(date in partitions) for date in dates] else: From eea51cdb4182ea0496207775bce162a6a350ab24 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Tue, 27 Aug 2024 17:28:43 +0200 Subject: [PATCH 08/12] override parser messy but working --- rialto/runner/config_loader.py | 10 +++- rialto/runner/config_overrides.py | 49 ++++++++++++++++ rialto/runner/runner.py | 5 +- tests/runner/overrider.yaml | 96 +++++++++++++++++++++++++++++++ tests/runner/test_overrides.py | 41 +++++++++++++ 5 files changed, 197 insertions(+), 4 deletions(-) create mode 100644 rialto/runner/config_overrides.py create mode 100644 tests/runner/overrider.yaml create mode 100644 tests/runner/test_overrides.py diff --git a/rialto/runner/config_loader.py b/rialto/runner/config_loader.py index d5b0150..821a75f 100644 --- a/rialto/runner/config_loader.py +++ b/rialto/runner/config_loader.py @@ -21,6 +21,7 @@ from pydantic import BaseModel from rialto.common.utils import load_yaml +from rialto.runner.config_overrides import override_config class IntervalConfig(BaseModel): @@ -91,6 +92,11 @@ class PipelinesConfig(BaseModel): pipelines: list[PipelineConfig] -def get_pipelines_config(path) -> PipelinesConfig: +def get_pipelines_config(path: str, overrides: Dict) -> PipelinesConfig: """Load and parse yaml config""" - return PipelinesConfig(**load_yaml(path)) + raw_config = load_yaml(path) + if overrides: + cfg = override_config(raw_config, overrides) + return PipelinesConfig(**cfg) + else: + return PipelinesConfig(**raw_config) diff --git a/rialto/runner/config_overrides.py b/rialto/runner/config_overrides.py new file mode 100644 index 0000000..310c58e --- /dev/null +++ b/rialto/runner/config_overrides.py @@ -0,0 +1,49 @@ +from typing import Dict + +from loguru import logger + + +def _override(config, path, value) -> Dict: + key = path[0] + if "[" in key: + name = key.split("[")[0] + index = key.split("[")[1].replace("]", "") + if "=" in index: + index_key, index_value = index.split("=") + position = next(i for i, x in enumerate(config[name]) if x.get(index_key) == index_value) + if len(path) == 1: + config[name][position] = value + else: + config[name][position] = _override(config[name][position], path[1:], value) + else: + index = int(index) + if index >= 0: + if len(path) == 1: + config[name][index] = value + else: + config[name][index] = _override(config[name][index], path[1:], value) + else: + if len(path) == 1: + config[name].append(value) + else: + raise ValueError(f"Invalid index {index} for key {name} in path {path}") + else: + if len(path) == 1: + config[key] = value + else: + config[key] = _override(config[key], path[1:], value) + return config + + +def override_config(config: Dict, overrides: Dict) 
-> Dict: + """Override config with user input + + :param config: Config dictionary + :param overrides: Dictionary of overrides + :return: Overridden config + """ + for path, value in overrides.items(): + logger.info("Applying override: ", path, value) + config = _override(config, path.split("."), value) + + return config diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index 9f9b276..0178f05 100644 --- a/rialto/runner/runner.py +++ b/rialto/runner/runner.py @@ -16,7 +16,7 @@ import datetime from datetime import date -from typing import List, Tuple +from typing import Dict, List, Tuple import pyspark.sql.functions as F from loguru import logger @@ -44,9 +44,10 @@ def __init__( rerun: bool = False, op: str = None, skip_dependencies: bool = False, + overrides: Dict = None, ): self.spark = spark - self.config = get_pipelines_config(config_path) + self.config = get_pipelines_config(config_path, overrides) self.reader = TableReader(spark) self.date_from = date_from diff --git a/tests/runner/overrider.yaml b/tests/runner/overrider.yaml new file mode 100644 index 0000000..750699e --- /dev/null +++ b/tests/runner/overrider.yaml @@ -0,0 +1,96 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +runner: + watched_period_units: "months" + watched_period_value: 2 + mail: + sender: test@testing.org + smtp: server.test + to: + - developer@testing.org + - developer2@testing.org + subject: test report +pipelines: + - name: SimpleGroup + module: + python_module: tests.runner.transformations + python_class: SimpleGroup + schedule: + frequency: weekly + day: 7 + info_date_shift: + - value: 3 + units: days + - value: 2 + units: weeks + dependencies: + - table: source.schema.dep1 + interval: + units: "days" + value: 1 + date_col: "DATE" + - table: source.schema.dep2 + interval: + units: "months" + value: 3 + date_col: "DATE" + target: + target_schema: catalog.schema + target_partition_column: "INFORMATION_DATE" + loader: + config_path: path/to/config.yaml + feature_schema: catalog.feature_tables + metadata_schema: catalog.metadata + metadata_manager: + metadata_schema: catalog.metadata + - name: GroupNoDeps + module: + python_module: tests.runner.transformations + python_class: SimpleGroup + schedule: + frequency: weekly + day: 7 + info_date_shift: + value: 3 + units: days + - name: NamedDeps + module: + python_module: tests.runner.transformations + python_class: SimpleGroup + schedule: + frequency: weekly + day: 7 + info_date_shift: + value: 3 + units: days + dependencies: + - table: source.schema.dep1 + name: source1 + interval: + units: "days" + value: 1 + date_col: "DATE" + - table: source.schema.dep2 + name: source2 + interval: + units: "months" + value: 3 + date_col: "batch" + target: + target_schema: catalog.schema + target_partition_column: "INFORMATION_DATE" + extras: + some_value: 3 + some_other_value: cat diff --git a/tests/runner/test_overrides.py b/tests/runner/test_overrides.py new file mode 100644 index 0000000..0962c0e --- /dev/null +++ 
b/tests/runner/test_overrides.py @@ -0,0 +1,41 @@ +from rialto.runner import Runner + + +def test_overrides_simple(spark): + runner = Runner( + spark, + config_path="tests/runner/transformations/config.yaml", + run_date="2023-03-31", + overrides={"runner.mail.to": ["x@b.c", "y@b.c", "z@b.c"]}, + ) + assert runner.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"] + + +def test_overrides_array_index(spark): + runner = Runner( + spark, + config_path="tests/runner/transformations/config.yaml", + run_date="2023-03-31", + overrides={"runner.mail.to[1]": "a@b.c"}, + ) + assert runner.config.runner.mail.to == ["developer@testing.org", "a@b.c"] + + +def test_overrides_array_append(spark): + runner = Runner( + spark, + config_path="tests/runner/transformations/config.yaml", + run_date="2023-03-31", + overrides={"runner.mail.to[-1]": "test"}, + ) + assert runner.config.runner.mail.to == ["developer@testing.org", "developer2@testing.org", "test"] + + +def test_overrides_array_lookup(spark): + runner = Runner( + spark, + config_path="tests/runner/transformations/config.yaml", + run_date="2023-03-31", + overrides={"pipelines[name=SimpleGroup].target.target_schema": "new_schema"}, + ) + assert runner.config.pipelines[0].target.target_schema == "new_schema" From 64276263f0345578cb39aed93f39e4d41ca91a3a Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Wed, 28 Aug 2024 12:19:04 +0200 Subject: [PATCH 09/12] override tests and cleanup --- rialto/runner/config_overrides.py | 67 +++++++++++++++++++--------- tests/runner/test_overrides.py | 72 +++++++++++++++++++++++++++++-- 2 files changed, 115 insertions(+), 24 deletions(-) diff --git a/rialto/runner/config_overrides.py b/rialto/runner/config_overrides.py index 310c58e..72d59ac 100644 --- a/rialto/runner/config_overrides.py +++ b/rialto/runner/config_overrides.py @@ -1,33 +1,60 @@ -from typing import Dict +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +__all__ = ["override_config"] + +from typing import Dict, List, Tuple from loguru import logger +def _split_index_key(key: str) -> Tuple[str, str]: + name = key.split("[")[0] + index = key.split("[")[1].replace("]", "") + return name, index + + +def _find_first_match(config: List, index: str) -> int: + index_key, index_value = index.split("=") + return next(i for i, x in enumerate(config) if x.get(index_key) == index_value) + + def _override(config, path, value) -> Dict: key = path[0] if "[" in key: - name = key.split("[")[0] - index = key.split("[")[1].replace("]", "") + name, index = _split_index_key(key) + if name not in config: + raise ValueError(f"Invalid key {name}") if "=" in index: - index_key, index_value = index.split("=") - position = next(i for i, x in enumerate(config[name]) if x.get(index_key) == index_value) - if len(path) == 1: - config[name][position] = value - else: - config[name][position] = _override(config[name][position], path[1:], value) + index = _find_first_match(config[name], index) else: index = int(index) - if index >= 0: - if len(path) == 1: - config[name][index] = value - else: - config[name][index] = _override(config[name][index], path[1:], value) + if index >= 0 and index < len(config[name]): + if len(path) == 1: + config[name][index] = value else: - if len(path) == 1: - config[name].append(value) - else: - raise ValueError(f"Invalid index {index} for key {name} in path {path}") + config[name][index] = _override(config[name][index], path[1:], value) + elif index == -1: + if len(path) == 1: + config[name].append(value) + else: + raise ValueError(f"Invalid index {index} for key {name} in path {path}") + else: + raise IndexError(f"Index {index} out of bounds for key {key}") else: + if key not in config: + raise ValueError(f"Invalid key {key}") if len(path) == 1: config[key] = value else: @@ -38,8 +65,8 @@ def _override(config, path, value) -> Dict: def override_config(config: Dict, overrides: Dict) -> Dict: """Override config with user input - :param config: Config dictionary - :param overrides: Dictionary of overrides + :param config: config dictionary + :param overrides: dictionary of overrides :return: Overridden config """ for path, value in overrides.items(): diff --git a/tests/runner/test_overrides.py b/tests/runner/test_overrides.py index 0962c0e..996fa10 100644 --- a/tests/runner/test_overrides.py +++ b/tests/runner/test_overrides.py @@ -1,10 +1,25 @@ +# Copyright 2022 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest + from rialto.runner import Runner def test_overrides_simple(spark): runner = Runner( spark, - config_path="tests/runner/transformations/config.yaml", + config_path="tests/runner/overrider.yaml", run_date="2023-03-31", overrides={"runner.mail.to": ["x@b.c", "y@b.c", "z@b.c"]}, ) @@ -14,7 +29,7 @@ def test_overrides_simple(spark): def test_overrides_array_index(spark): runner = Runner( spark, - config_path="tests/runner/transformations/config.yaml", + config_path="tests/runner/overrider.yaml", run_date="2023-03-31", overrides={"runner.mail.to[1]": "a@b.c"}, ) @@ -24,7 +39,7 @@ def test_overrides_array_index(spark): def test_overrides_array_append(spark): runner = Runner( spark, - config_path="tests/runner/transformations/config.yaml", + config_path="tests/runner/overrider.yaml", run_date="2023-03-31", overrides={"runner.mail.to[-1]": "test"}, ) @@ -34,8 +49,57 @@ def test_overrides_array_append(spark): def test_overrides_array_lookup(spark): runner = Runner( spark, - config_path="tests/runner/transformations/config.yaml", + config_path="tests/runner/overrider.yaml", run_date="2023-03-31", overrides={"pipelines[name=SimpleGroup].target.target_schema": "new_schema"}, ) assert runner.config.pipelines[0].target.target_schema == "new_schema" + + +def test_overrides_combined(spark): + runner = Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={ + "runner.mail.to": ["x@b.c", "y@b.c", "z@b.c"], + "pipelines[name=SimpleGroup].target.target_schema": "new_schema", + "pipelines[name=SimpleGroup].schedule.info_date_shift[0].value": 1, + }, + ) + assert runner.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"] + assert runner.config.pipelines[0].target.target_schema == "new_schema" + assert runner.config.pipelines[0].schedule.info_date_shift[0].value == 1 + + +def test_index_out_of_range(spark): + with pytest.raises(IndexError) as error: + Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={"runner.mail.to[8]": "test"}, + ) + assert error.value.args[0] == "Index 8 out of bounds for key to[8]" + + +def test_invalid_index_key(spark): + with pytest.raises(ValueError) as error: + Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={"runner.mail.test[8]": "test"}, + ) + assert error.value.args[0] == "Invalid key test" + + +def test_invalid_key(spark): + with pytest.raises(ValueError) as error: + Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={"runner.mail.test": "test"}, + ) + assert error.value.args[0] == "Invalid key test" From 714ad584745b397f01655909dc2a419abac04da2 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Thu, 29 Aug 2024 11:36:41 +0200 Subject: [PATCH 10/12] documenting changes --- CHANGELOG.md | 2 + README.md | 57 ++++++++++++++++++++++++ rialto/runner/config_loader.py | 4 +- rialto/runner/config_overrides.py | 4 +- tests/runner/overrider.yaml | 18 ++------ tests/runner/test_overrides.py | 34 +++++++++++++- tests/runner/transformations/config.yaml | 8 ++-- 7 files changed, 104 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c25f0e..1b8b5da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,8 +10,10 @@ All notable changes to this project will be documented in this file. 
- dependency date_col is now mandatory
 - custom extras config is available in each pipeline and will be passed as dictionary available under pipeline_config.extras
 - general section is renamed to runner
+ - info_date_shift is always a list
 - transformation header changed
 - added argument to skip dependency checking
+ - added overrides parameter to allow for dynamic overriding of config values
 #### Jobs
 - jobs are now the main way to create all pipelines
 - config holder removed from jobs
diff --git a/README.md b/README.md
index 56ccaea..974244b 100644
--- a/README.md
+++ b/README.md
@@ -59,6 +59,7 @@ This behavior can be modified by various parameters and switches available.
 * **rerun** - rerun all jobs even if they already succeeded in the past runs
 * **op** - run only selected operation / pipeline
 * **skip_dependencies** - ignore dependency checks and run all jobs
+* **overrides** - dictionary of overrides for the configuration
 
 Transformations are not included in the runner itself, it imports them dynamically according to the configuration, therefore it's necessary to have them locally installed.
 
@@ -132,6 +133,62 @@ pipelines: # a list of pipelines to run
         value: 6
 ```
 
+The configuration can be dynamically overridden by providing a dictionary of overrides to the runner. All overrides must adhere to the configuration schema, with the pipeline.extras section available for custom fields.
+Here are a few examples of overrides:
+
+#### Simple override of a single value
+Specify the path to the value in the configuration file as a dot-separated string.
+
+```python
+Runner(
+    spark,
+    config_path="tests/overrider.yaml",
+    run_date="2023-03-31",
+    overrides={"runner.watched_period_value": 4},
+)
+```
+
+#### Override list element
+You can refer to list elements by their index (starting with 0).
+```python
+overrides={"runner.mail.to[1]": "a@b.c"}
+```
+
+#### Append to list
+You can append to a list by using index -1.
+```python
+overrides={"runner.mail.to[-1]": "test@test.com"}
+```
+
+#### Lookup by attribute value in a list
+You can use the following syntax to find a specific element in a list by its attribute value.
+```python
+overrides={"pipelines[name=SimpleGroup].target.target_schema": "new_schema"}
+```
+
+#### Injecting/Replacing whole sections
+You can replace a whole section of the configuration by providing a dictionary.
+If the section doesn't exist yet, it will be added to the configuration, but it has to be provided as a whole:
+e.g. if the yaml file doesn't specify a feature_loader, you can't add just feature_loader.config_path, you need to add the whole section.
+```python +overrides={"pipelines[name=SimpleGroup].feature_loader": + {"config_path": "features_cfg.yaml", + "feature_schema": "catalog.features", + "metadata_schema": "catalog.metadata"}} +``` + +#### Multiple overrides +You can provide multiple overrides at once, the order of execution is not guaranteed +```python +overrides={"runner.watch_period_value": 4, + "runner.watch_period_units": "weeks", + "pipelines[name=SimpleGroup].target.target_schema": "new_schema", + "pipelines[name=SimpleGroup].feature_loader": + {"config_path": "features_cfg.yaml", + "feature_schema": "catalog.features", + "metadata_schema": "catalog.metadata"} + } +``` ## 2.2 - maker diff --git a/rialto/runner/config_loader.py b/rialto/runner/config_loader.py index 821a75f..dea7b07 100644 --- a/rialto/runner/config_loader.py +++ b/rialto/runner/config_loader.py @@ -16,7 +16,7 @@ "get_pipelines_config", ] -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional from pydantic import BaseModel @@ -32,7 +32,7 @@ class IntervalConfig(BaseModel): class ScheduleConfig(BaseModel): frequency: str day: Optional[int] = 0 - info_date_shift: Union[Optional[IntervalConfig], List[IntervalConfig]] = IntervalConfig(units="days", value=0) + info_date_shift: List[IntervalConfig] = IntervalConfig(units="days", value=0) class DependencyConfig(BaseModel): diff --git a/rialto/runner/config_overrides.py b/rialto/runner/config_overrides.py index 72d59ac..a525525 100644 --- a/rialto/runner/config_overrides.py +++ b/rialto/runner/config_overrides.py @@ -53,11 +53,11 @@ def _override(config, path, value) -> Dict: else: raise IndexError(f"Index {index} out of bounds for key {key}") else: - if key not in config: - raise ValueError(f"Invalid key {key}") if len(path) == 1: config[key] = value else: + if key not in config: + raise ValueError(f"Invalid key {key}") config[key] = _override(config[key], path[1:], value) return config diff --git a/tests/runner/overrider.yaml b/tests/runner/overrider.yaml index 750699e..3029730 100644 --- a/tests/runner/overrider.yaml +++ b/tests/runner/overrider.yaml @@ -49,23 +49,13 @@ pipelines: target: target_schema: catalog.schema target_partition_column: "INFORMATION_DATE" - loader: + feature_loader: config_path: path/to/config.yaml feature_schema: catalog.feature_tables metadata_schema: catalog.metadata metadata_manager: metadata_schema: catalog.metadata - - name: GroupNoDeps - module: - python_module: tests.runner.transformations - python_class: SimpleGroup - schedule: - frequency: weekly - day: 7 - info_date_shift: - value: 3 - units: days - - name: NamedDeps + - name: OtherGroup module: python_module: tests.runner.transformations python_class: SimpleGroup @@ -73,8 +63,8 @@ pipelines: frequency: weekly day: 7 info_date_shift: - value: 3 - units: days + - value: 3 + units: days dependencies: - table: source.schema.dep1 name: source1 diff --git a/tests/runner/test_overrides.py b/tests/runner/test_overrides.py index 996fa10..17fcdbe 100644 --- a/tests/runner/test_overrides.py +++ b/tests/runner/test_overrides.py @@ -100,6 +100,38 @@ def test_invalid_key(spark): spark, config_path="tests/runner/overrider.yaml", run_date="2023-03-31", - overrides={"runner.mail.test": "test"}, + overrides={"runner.mail.test.param": "test"}, ) assert error.value.args[0] == "Invalid key test" + + +def test_replace_section(spark): + runner = Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={ + "pipelines[name=SimpleGroup].feature_loader": { + "config_path": 
"features_cfg.yaml", + "feature_schema": "catalog.features", + "metadata_schema": "catalog.metadata", + } + }, + ) + assert runner.config.pipelines[0].feature_loader.feature_schema == "catalog.features" + + +def test_add_section(spark): + runner = Runner( + spark, + config_path="tests/runner/overrider.yaml", + run_date="2023-03-31", + overrides={ + "pipelines[name=OtherGroup].feature_loader": { + "config_path": "features_cfg.yaml", + "feature_schema": "catalog.features", + "metadata_schema": "catalog.metadata", + } + }, + ) + assert runner.config.pipelines[1].feature_loader.feature_schema == "catalog.features" diff --git a/tests/runner/transformations/config.yaml b/tests/runner/transformations/config.yaml index 0ed82ce..743a07f 100644 --- a/tests/runner/transformations/config.yaml +++ b/tests/runner/transformations/config.yaml @@ -31,8 +31,8 @@ pipelines: frequency: weekly day: 7 info_date_shift: - value: 3 - units: days + - value: 3 + units: days dependencies: - table: source.schema.dep1 interval: @@ -65,8 +65,8 @@ pipelines: frequency: weekly day: 7 info_date_shift: - value: 3 - units: days + - value: 3 + units: days dependencies: - table: source.schema.dep1 name: source1 From 3f2ce2f6f40a381ece3ede6eab7b7773539bfeab Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Thu, 29 Aug 2024 11:48:58 +0200 Subject: [PATCH 11/12] info_date_shift should be optional --- rialto/runner/config_loader.py | 2 +- tests/runner/test_date_manager.py | 4 ++-- tests/runner/transformations/config.yaml | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/rialto/runner/config_loader.py b/rialto/runner/config_loader.py index dea7b07..c4ce193 100644 --- a/rialto/runner/config_loader.py +++ b/rialto/runner/config_loader.py @@ -32,7 +32,7 @@ class IntervalConfig(BaseModel): class ScheduleConfig(BaseModel): frequency: str day: Optional[int] = 0 - info_date_shift: List[IntervalConfig] = IntervalConfig(units="days", value=0) + info_date_shift: Optional[List[IntervalConfig]] = IntervalConfig(units="days", value=0) class DependencyConfig(BaseModel): diff --git a/tests/runner/test_date_manager.py b/tests/runner/test_date_manager.py index 9088e0c..73b61b8 100644 --- a/tests/runner/test_date_manager.py +++ b/tests/runner/test_date_manager.py @@ -144,7 +144,7 @@ def test_run_dates_invalid(): [(7, "2023-02-26"), (3, "2023-03-02"), (-5, "2023-03-10"), (0, "2023-03-05")], ) def test_to_info_date(shift, res): - cfg = ScheduleConfig(frequency="daily", info_date_shift=IntervalConfig(units="days", value=shift)) + cfg = ScheduleConfig(frequency="daily", info_date_shift=[IntervalConfig(units="days", value=shift)]) base = DateManager.str_to_date("2023-03-05") info = DateManager.to_info_date(base, cfg) assert DateManager.str_to_date(res) == info @@ -155,7 +155,7 @@ def test_to_info_date(shift, res): [("days", "2023-03-02"), ("weeks", "2023-02-12"), ("months", "2022-12-05"), ("years", "2020-03-05")], ) def test_info_date_shift_units(unit, result): - cfg = ScheduleConfig(frequency="daily", info_date_shift=IntervalConfig(units=unit, value=3)) + cfg = ScheduleConfig(frequency="daily", info_date_shift=[IntervalConfig(units=unit, value=3)]) base = DateManager.str_to_date("2023-03-05") info = DateManager.to_info_date(base, cfg) assert DateManager.str_to_date(result) == info diff --git a/tests/runner/transformations/config.yaml b/tests/runner/transformations/config.yaml index 743a07f..3b72107 100644 --- a/tests/runner/transformations/config.yaml +++ b/tests/runner/transformations/config.yaml @@ -55,8 +55,8 @@ pipelines: 
frequency: weekly day: 7 info_date_shift: - value: 3 - units: days + - value: 3 + units: days - name: NamedDeps module: python_module: tests.runner.transformations From 68a7ad73964c1c92ed356ae5b027d97c86db19f1 Mon Sep 17 00:00:00 2001 From: Marek Dobransky Date: Tue, 3 Sep 2024 10:50:15 +0200 Subject: [PATCH 12/12] added config job --- CHANGELOG.md | 2 ++ README.md | 23 +++++++++++++++++------ rialto/jobs/__init__.py | 2 +- rialto/jobs/decorators/__init__.py | 2 +- rialto/jobs/decorators/decorators.py | 16 +++++++++++++++- rialto/jobs/decorators/test_utils.py | 7 +++++-- rialto/runner/config_loader.py | 3 +-- rialto/runner/runner.py | 27 +++++++++------------------ tests/jobs/test_decorators.py | 7 +++++++ tests/jobs/test_job/test_job.py | 5 ++++- tests/runner/test_runner.py | 26 +++++++++++--------------- 11 files changed, 73 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1b8b5da..63e9791 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,10 +14,12 @@ All notable changes to this project will be documented in this file. - transformation header changed - added argument to skip dependency checking - added overrides parameter to allow for dynamic overriding of config values + - removed date_from and date_to from arguments, use overrides instead #### Jobs - jobs are now the main way to create all pipelines - config holder removed from jobs - metadata_manager and feature_loader are now available arguments, depending on configuration + - added @config decorator, similar use case to @datasource, for parsing configuration #### TableReader - function signatures changed - until -> date_until diff --git a/README.md b/README.md index 974244b..2ac915f 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ - +from pydantic import BaseModelfrom rialto.runner.config_loader import PipelineConfigfrom rialto.jobs import config # Rialto @@ -54,8 +54,6 @@ A runner by default executes all the jobs provided in the configuration file, fo This behavior can be modified by various parameters and switches available. * **run_date** - date at which the runner is triggered (defaults to day of running) -* **date_from** - starting date (defaults to rundate - config watch period) -* **date_until** - end date (defaults to rundate) * **rerun** - rerun all jobs even if they already succeeded in the past runs * **op** - run only selected operation / pipeline * **skip_dependencies** - ignore dependency checks and run all jobs @@ -131,6 +129,9 @@ pipelines: # a list of pipelines to run interval: units: "days" value: 6 + target: + target_schema: catalog.schema # schema where tables will be created, must exist + target_partition_column: INFORMATION_DATE # date to partition new tables on ``` The configuration can be dynamically overridden by providing a dictionary of overrides to the runner. All overrides must adhere to configurations schema, with pipeline.extras section available for custom schema. 
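To tie the switches listed above to the new v2 constructor (where date_from / date_until are gone and the checked window is derived from run_date and the configured watched period), here is a minimal sketch; the config path, pipeline name, and override value are placeholders.

```python
from pyspark.sql import SparkSession

from rialto.runner import Runner

spark = SparkSession.builder.getOrCreate()

# run_date plus runner.watched_period_* from the config now define the checked window;
# anything else is tweaked through `overrides` rather than extra constructor arguments.
runner = Runner(
    spark,
    config_path="pipelines.yaml",                  # placeholder config file
    run_date="2024-06-30",
    rerun=False,                                   # set True to re-run already successful jobs
    op="SimpleGroup",                              # placeholder: run only this pipeline
    skip_dependencies=False,
    overrides={"runner.watched_period_value": 2},
)
```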
@@ -371,8 +372,18 @@ With that sorted out, we can now provide a quick example of the *rialto.jobs* mo
 ```python
 from pyspark.sql import DataFrame
 from rialto.common import TableReader
-from rialto.jobs.decorators import job, datasource
+from rialto.jobs.decorators import config, job, datasource
+from rialto.runner.config_loader import PipelineConfig
+from pydantic import BaseModel
+
+
+class ConfigModel(BaseModel):
+    some_value: int
+    some_other_value: str
 
+@config
+def my_config(config: PipelineConfig):
+    return ConfigModel(**config.extras)
 
 @datasource
 def my_datasource(run_date: datetime.date, table_reader: TableReader) -> DataFrame:
@@ -380,8 +391,8 @@ def my_datasource(run_date: datetime.date, table_reader: TableReader) -> DataFra
 
 
 @job
-def my_job(my_datasource: DataFrame) -> DataFrame:
-    return my_datasource.withColumn("HelloWorld", F.lit(1))
+def my_job(my_datasource: DataFrame, my_config: ConfigModel) -> DataFrame:
+    return my_datasource.withColumn("HelloWorld", F.lit(my_config.some_value))
 ```
 This piece of code
 1. creates a rialto transformation called *my_job*, which is then callable by the rialto runner.
diff --git a/rialto/jobs/__init__.py b/rialto/jobs/__init__.py
index 90183bd..a6ee6cb 100644
--- a/rialto/jobs/__init__.py
+++ b/rialto/jobs/__init__.py
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from rialto.jobs.decorators import datasource, job
+from rialto.jobs.decorators import config, datasource, job
diff --git a/rialto/jobs/decorators/__init__.py b/rialto/jobs/decorators/__init__.py
index ba62141..6f2713a 100644
--- a/rialto/jobs/decorators/__init__.py
+++ b/rialto/jobs/decorators/__init__.py
@@ -12,4 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .decorators import datasource, job
+from .decorators import config, datasource, job
diff --git a/rialto/jobs/decorators/decorators.py b/rialto/jobs/decorators/decorators.py
index 217b436..d288b7b 100644
--- a/rialto/jobs/decorators/decorators.py
+++ b/rialto/jobs/decorators/decorators.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-__all__ = ["datasource", "job"]
+__all__ = ["datasource", "job", "config"]
 
 import inspect
 import typing
@@ -24,6 +24,20 @@
 from rialto.jobs.decorators.resolver import Resolver
 
 
+def config(ds_getter: typing.Callable) -> typing.Callable:
+    """
+    Config parser functions decorator.
+
+    Registers a config parsing function into a rialto job prerequisite.
+    You can then request the parsed config via job function arguments.
+
+    :param ds_getter: config parser function
+    :return: raw parser function, unchanged
+    """
+    Resolver.register_callable(ds_getter)
+    return ds_getter
+
+
 def datasource(ds_getter: typing.Callable) -> typing.Callable:
     """
     Dataset reader functions decorator.
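As a quick illustration of what the new decorator does, a sketch (not part of the patch): `@config` only registers the getter with the Resolver, mirroring the decorator body above and the decorator tests later in this series. The model and values here are made up.

```python
from pydantic import BaseModel

from rialto.jobs.decorators import config
from rialto.jobs.decorators.resolver import Resolver


class ExtrasModel(BaseModel):  # hypothetical model, used only for this sketch
    some_value: int


@config
def my_config():
    # a real parser would typically take the pipeline config as an argument
    return ExtrasModel(some_value=1)


# @config registered the getter under its function name; a @job that declares a
# "my_config" parameter gets the parsed object injected, and it can also be
# resolved directly by name (as tests/jobs/test_decorators.py does):
assert Resolver.resolve("my_config").some_value == 1
```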
diff --git a/rialto/jobs/decorators/test_utils.py b/rialto/jobs/decorators/test_utils.py index 5465d6e..39d76ce 100644 --- a/rialto/jobs/decorators/test_utils.py +++ b/rialto/jobs/decorators/test_utils.py @@ -17,9 +17,10 @@ import importlib import typing from contextlib import contextmanager -from unittest.mock import patch, create_autospec, MagicMock -from rialto.jobs.decorators.resolver import Resolver, ResolverException +from unittest.mock import MagicMock, create_autospec, patch + from rialto.jobs.decorators.job_base import JobBase +from rialto.jobs.decorators.resolver import Resolver, ResolverException def _passthrough_decorator(*args, **kwargs) -> typing.Callable: @@ -34,6 +35,8 @@ def _disable_job_decorators() -> None: patches = [ patch("rialto.jobs.decorators.datasource", _passthrough_decorator), patch("rialto.jobs.decorators.decorators.datasource", _passthrough_decorator), + patch("rialto.jobs.decorators.config", _passthrough_decorator), + patch("rialto.jobs.decorators.decorators.config", _passthrough_decorator), patch("rialto.jobs.decorators.job", _passthrough_decorator), patch("rialto.jobs.decorators.decorators.job", _passthrough_decorator), ] diff --git a/rialto/runner/config_loader.py b/rialto/runner/config_loader.py index c4ce193..86c142d 100644 --- a/rialto/runner/config_loader.py +++ b/rialto/runner/config_loader.py @@ -71,7 +71,6 @@ class MetadataManagerConfig(BaseModel): class FeatureLoaderConfig(BaseModel): - config_path: str feature_schema: str metadata_schema: str @@ -81,7 +80,7 @@ class PipelineConfig(BaseModel): module: ModuleConfig schedule: ScheduleConfig dependencies: Optional[List[DependencyConfig]] = [] - target: Optional[TargetConfig] = None + target: TargetConfig = None metadata_manager: Optional[MetadataManagerConfig] = None feature_loader: Optional[FeatureLoaderConfig] = None extras: Optional[Dict] = {} diff --git a/rialto/runner/runner.py b/rialto/runner/runner.py index 0178f05..ac9d6bc 100644 --- a/rialto/runner/runner.py +++ b/rialto/runner/runner.py @@ -39,8 +39,6 @@ def __init__( spark: SparkSession, config_path: str, run_date: str = None, - date_from: str = None, - date_until: str = None, rerun: bool = False, op: str = None, skip_dependencies: bool = False, @@ -49,9 +47,6 @@ def __init__( self.spark = spark self.config = get_pipelines_config(config_path, overrides) self.reader = TableReader(spark) - - self.date_from = date_from - self.date_until = date_until self.rerun = rerun self.skip_dependencies = skip_dependencies self.op = op @@ -61,19 +56,15 @@ def __init__( run_date = DateManager.str_to_date(run_date) else: run_date = date.today() - if self.date_from: - self.date_from = DateManager.str_to_date(date_from) - if self.date_until: - self.date_until = DateManager.str_to_date(date_until) - - if not self.date_from: - self.date_from = DateManager.date_subtract( - run_date=run_date, - units=self.config.runner.watched_period_units, - value=self.config.runner.watched_period_value, - ) - if not self.date_until: - self.date_until = run_date + + self.date_from = DateManager.date_subtract( + run_date=run_date, + units=self.config.runner.watched_period_units, + value=self.config.runner.watched_period_value, + ) + + self.date_until = run_date + if self.date_from > self.date_until: raise ValueError(f"Invalid date range from {self.date_from} until {self.date_until}") logger.info(f"Running period from {self.date_from} until {self.date_until}") diff --git a/tests/jobs/test_decorators.py b/tests/jobs/test_decorators.py index 6496a2d..54cb4a4 100644 --- 
a/tests/jobs/test_decorators.py +++ b/tests/jobs/test_decorators.py @@ -25,6 +25,13 @@ def test_dataset_decorator(): assert test_dataset == "dataset_return" +def test_config_decorator(): + _ = import_module("tests.jobs.test_job.test_job") + test_dataset = Resolver.resolve("custom_config") + + assert test_dataset == "config_return" + + def _rialto_import_stub(module_name, class_name): module = import_module(module_name) class_obj = getattr(module, class_name) diff --git a/tests/jobs/test_job/test_job.py b/tests/jobs/test_job/test_job.py index bc3cb69..3d648b5 100644 --- a/tests/jobs/test_job/test_job.py +++ b/tests/jobs/test_job/test_job.py @@ -11,9 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from rialto.jobs.decorators import config, datasource, job -from rialto.jobs.decorators import datasource, job +@config +def custom_config(): + return "config_return" @datasource diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py index 9f16ea0..e23eee8 100644 --- a/tests/runner/test_runner.py +++ b/tests/runner/test_runner.py @@ -108,10 +108,10 @@ def test_init_dates(spark): runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - date_from="2023-03-01", - date_until="2023-03-31", + run_date="2023-03-31", + overrides={"runner.watched_period_units": "weeks", "runner.watched_period_value": 2}, ) - assert runner.date_from == DateManager.str_to_date("2023-03-01") + assert runner.date_from == DateManager.str_to_date("2023-03-17") assert runner.date_until == DateManager.str_to_date("2023-03-31") runner = Runner( @@ -156,8 +156,7 @@ def test_check_dates_have_partition(spark, mocker): runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - date_from="2023-03-01", - date_until="2023-03-31", + run_date="2023-03-31", ) runner.reader = MockReader(spark) dates = ["2023-03-04", "2023-03-05", "2023-03-06"] @@ -173,8 +172,7 @@ def test_check_dates_have_partition_no_table(spark, mocker): runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - date_from="2023-03-01", - date_until="2023-03-31", + run_date="2023-03-31", ) dates = ["2023-03-04", "2023-03-05", "2023-03-06"] dates = [DateManager.str_to_date(d) for d in dates] @@ -193,8 +191,7 @@ def test_check_dependencies(spark, mocker, r_date, expected): runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - date_from="2023-03-01", - date_until="2023-03-31", + run_date="2023-03-31", ) runner.reader = MockReader(spark) res = runner.check_dependencies(runner.config.pipelines[0], DateManager.str_to_date(r_date)) @@ -207,8 +204,7 @@ def test_check_no_dependencies(spark, mocker): runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - date_from="2023-03-01", - date_until="2023-03-31", + run_date="2023-03-31", ) runner.reader = MockReader(spark) res = runner.check_dependencies(runner.config.pipelines[1], DateManager.str_to_date("2023-03-05")) @@ -221,8 +217,8 @@ def test_select_dates(spark, mocker): runner = Runner( spark, config_path="tests/runner/transformations/config.yaml", - date_from="2023-03-01", - date_until="2023-03-31", + run_date="2023-03-31", + overrides={"runner.watched_period_units": "months", "runner.watched_period_value": 1}, ) runner.reader = MockReader(spark) @@ -243,8 +239,8 @@ def test_select_dates_all_done(spark, mocker): runner = Runner( spark, 
config_path="tests/runner/transformations/config.yaml", - date_from="2023-03-02", - date_until="2023-03-02", + run_date="2023-03-02", + overrides={"runner.watched_period_units": "months", "runner.watched_period_value": 0}, ) runner.reader = MockReader(spark)