@@ -127,7 +126,7 @@ def _make_header(target: str, start: datetime):
- Jobs started {str(start).split('.')[0]}, targeting {target}
+ Jobs started {str(start).split('.')[0]}
|
@@ -228,14 +227,14 @@ def _make_insights(records: List[Record]):
"""
@staticmethod
- def make_report(target: str, start: datetime, records: List[Record]) -> str:
+ def make_report(start: datetime, records: List[Record]) -> str:
"""Create html email report"""
html = [
"""
""",
HTMLMessage._head(),
HTMLMessage._body_open(),
- HTMLMessage._make_header(target, start),
+ HTMLMessage._make_header(start),
HTMLMessage._make_overview(records),
HTMLMessage._make_insights(records),
HTMLMessage._body_close(),
diff --git a/rialto/runner/transformation.py b/rialto/runner/transformation.py
index 210cb0b..5b6f2eb 100644
--- a/rialto/runner/transformation.py
+++ b/rialto/runner/transformation.py
@@ -16,12 +16,13 @@
import abc
import datetime
-from typing import Dict
from pyspark.sql import DataFrame, SparkSession
-from rialto.common import TableReader
+from rialto.common import DataReader
+from rialto.loader import PysparkFeatureLoader
from rialto.metadata import MetadataManager
+from rialto.runner.config_loader import PipelineConfig
class Transformation(metaclass=abc.ABCMeta):
@@ -30,11 +31,12 @@ class Transformation(metaclass=abc.ABCMeta):
@abc.abstractmethod
def run(
self,
- reader: TableReader,
+ reader: DataReader,
run_date: datetime.date,
spark: SparkSession = None,
+ config: PipelineConfig = None,
metadata_manager: MetadataManager = None,
- dependencies: Dict = None,
+ feature_loader: PysparkFeatureLoader = None,
) -> DataFrame:
"""
Run the transformation
@@ -42,7 +44,9 @@ def run(
:param reader: data store api object
:param run_date: date
:param spark: spark session
- :param metadata_manager: metadata api object
+ :param config: pipeline config
+ :param metadata_manager: metadata manager
+ :param feature_loader: feature loader
:return: dataframe
"""
raise NotImplementedError
diff --git a/rialto/runner/utils.py b/rialto/runner/utils.py
new file mode 100644
index 0000000..5af1723
--- /dev/null
+++ b/rialto/runner/utils.py
@@ -0,0 +1,104 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["load_module", "table_exists", "get_partitions", "init_tools", "find_dependency"]
+
+from datetime import date
+from importlib import import_module
+from typing import List, Tuple
+
+from pyspark.sql import SparkSession
+
+from rialto.common import DataReader
+from rialto.loader import PysparkFeatureLoader
+from rialto.metadata import MetadataManager
+from rialto.runner.config_loader import ModuleConfig, PipelineConfig
+from rialto.runner.table import Table
+from rialto.runner.transformation import Transformation
+
+
+def load_module(cfg: ModuleConfig) -> Transformation:
+ """
+ Load a transformation class from its module configuration
+
+ :param cfg: module configuration
+ :return: Transformation object
+ """
+ module = import_module(cfg.python_module)
+ class_obj = getattr(module, cfg.python_class)
+ return class_obj()
+
+
+def table_exists(spark: SparkSession, table: str) -> bool:
+ """
+ Check whether a table exists in the Spark catalog
+
+ :param spark: Spark session
+ :param table: full table path
+ :return: bool
+ """
+ return spark.catalog.tableExists(table)
+
+
+def get_partitions(reader: DataReader, table: Table) -> List[date]:
+ """
+ Get distinct partition values of a table
+
+ :param reader: data reader instance
+ :param table: Table object
+ :return: List of partition values
+ """
+ rows = (
+ reader.get_table(table.get_table_path(), date_column=table.partition)
+ .select(table.partition)
+ .distinct()
+ .collect()
+ )
+ return [r[table.partition] for r in rows]
+
+
+def init_tools(spark: SparkSession, pipeline: PipelineConfig) -> Tuple[MetadataManager, PysparkFeatureLoader]:
+ """
+ Initialize metadata manager and feature loader
+
+ :param spark: Spark session
+ :param pipeline: Pipeline configuration
+ :return: MetadataManager and PysparkFeatureLoader
+ """
+ if pipeline.metadata_manager is not None:
+ metadata_manager = MetadataManager(spark, pipeline.metadata_manager.metadata_schema)
+ else:
+ metadata_manager = None
+
+ if pipeline.feature_loader is not None:
+ feature_loader = PysparkFeatureLoader(
+ spark,
+ feature_schema=pipeline.feature_loader.feature_schema,
+ metadata_schema=pipeline.feature_loader.metadata_schema,
+ )
+ else:
+ feature_loader = None
+ return metadata_manager, feature_loader
+
+
+def find_dependency(config: PipelineConfig, name: str):
+ """
+ Get dependency from config
+
+ :param config: Pipeline configuration
+ :param name: Dependency name
+ :return: Dependency object, or None if no dependency with the given name exists
+ """
+ for dep in config.dependencies:
+ if dep.name == name:
+ return dep
+ return None
diff --git a/tests/common/test_yaml.py b/tests/common/test_yaml.py
new file mode 100644
index 0000000..9d63b66
--- /dev/null
+++ b/tests/common/test_yaml.py
@@ -0,0 +1,81 @@
+import os
+
+import pytest
+import yaml
+
+from rialto.common.env_yaml import EnvLoader
+
+
+def test_plain():
+ data = {"a": "string_value", "b": 2}
+ cfg = """
+ a: string_value
+ b: 2
+ """
+ assert yaml.load(cfg, EnvLoader) == data
+
+
+def test_full_sub_default():
+ data = {"a": "default_value", "b": 2}
+ cfg = """
+ a: ${EMPTY_VAR:default_value}
+ b: 2
+ """
+ assert yaml.load(cfg, EnvLoader) == data
+
+
+def test_full_sub_env():
+ os.environ["FILLED_VAR"] = "env_value"
+ data = {"a": "env_value", "b": 2}
+ cfg = """
+ a: ${FILLED_VAR:default_value}
+ b: 2
+ """
+ assert yaml.load(cfg, EnvLoader) == data
+
+
+def test_partial_sub_start():
+ data = {"a": "start_string", "b": 2}
+ cfg = """
+ a: ${START_VAR:start}_string
+ b: 2
+ """
+ assert yaml.load(cfg, EnvLoader) == data
+
+
+def test_partial_sub_end():
+ data = {"a": "string_end", "b": 2}
+ cfg = """
+ a: string_${END_VAR:end}
+ b: 2
+ """
+ assert yaml.load(cfg, EnvLoader) == data
+
+
+def test_partial_sub_mid():
+ data = {"a": "string_mid_sub", "b": 2}
+ cfg = """
+ a: string_${MID_VAR:mid}_sub
+ b: 2
+ """
+ assert yaml.load(cfg, EnvLoader) == data
+
+
+def test_partial_sub_no_default_no_value():
+ with pytest.raises(Exception) as e:
+ cfg = """
+ a: string_${MANDATORY_VAL_MISSING}_sub
+ b: 2
+ """
+ assert yaml.load(cfg, EnvLoader)
+ assert str(e.value) == "Environment variable MANDATORY_VAL_MISSING has no assigned value"
+
+
+def test_partial_sub_no_default():
+ os.environ["MANDATORY_VAL"] = "mandatory_value"
+ data = {"a": "string_mandatory_value_sub", "b": 2}
+ cfg = """
+ a: string_${MANDATORY_VAL}_sub
+ b: 2
+ """
+ assert yaml.load(cfg, EnvLoader) == data
diff --git a/tests/jobs/dependency_checks_job/complex_dependency_job.py b/tests/jobs/dependency_checks_job/complex_dependency_job.py
new file mode 100644
index 0000000..eabb70a
--- /dev/null
+++ b/tests/jobs/dependency_checks_job/complex_dependency_job.py
@@ -0,0 +1,26 @@
+import tests.jobs.dependency_checks_job.datasources_a as a
+import tests.jobs.dependency_checks_job.datasources_b as b
+from rialto.jobs import job, register_dependency_callable, register_dependency_module
+
+# module "A" has i(), j(), k()
+# module "B" has i(j), and dependency on module C
+# module "C" has j(), k()
+
+register_dependency_module(b)
+register_dependency_callable(a.j)
+
+
+@job
+def complex_dependency_job(i, j):
+ # If we register module B and only A.j, we should not see any conflicts, because:
+ # A.i is not registered, so it won't clash with B.i
+ # B has no j of its own; it only sees C.j as a registered dependency
+
+ assert i == "B.i-C.j"
+ assert j == "A.j"
+
+
+@job
+def unimported_dependency_job(k):
+ # k is in both A and C, but it's not imported here, thus won't get resolved
+ pass
diff --git a/tests/jobs/dependency_checks_job/datasources_a.py b/tests/jobs/dependency_checks_job/datasources_a.py
new file mode 100644
index 0000000..f8ff293
--- /dev/null
+++ b/tests/jobs/dependency_checks_job/datasources_a.py
@@ -0,0 +1,16 @@
+from rialto.jobs import datasource
+
+
+@datasource
+def i():
+ return "A.i"
+
+
+@datasource
+def j():
+ return "A.j"
+
+
+@datasource
+def k():
+ return "A.k"
diff --git a/tests/jobs/dependency_checks_job/datasources_b.py b/tests/jobs/dependency_checks_job/datasources_b.py
new file mode 100644
index 0000000..fce58bc
--- /dev/null
+++ b/tests/jobs/dependency_checks_job/datasources_b.py
@@ -0,0 +1,9 @@
+import tests.jobs.dependency_checks_job.datasources_c as c
+from rialto.jobs import datasource, register_dependency_module
+
+register_dependency_module(c)
+
+
+@datasource
+def i(j):
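+ # j is not defined in this module; it resolves to C.j through the module dependency registered above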
+ return f"B.i-{j}"
diff --git a/tests/jobs/dependency_checks_job/datasources_c.py b/tests/jobs/dependency_checks_job/datasources_c.py
new file mode 100644
index 0000000..5a08eb0
--- /dev/null
+++ b/tests/jobs/dependency_checks_job/datasources_c.py
@@ -0,0 +1,11 @@
+from rialto.jobs import datasource
+
+
+@datasource
+def j():
+ return "C.j"
+
+
+@datasource
+def k():
+ return "C.k"
diff --git a/tests/jobs/dependency_checks_job/dependency_checks_job.py b/tests/jobs/dependency_checks_job/dependency_checks_job.py
new file mode 100644
index 0000000..5952705
--- /dev/null
+++ b/tests/jobs/dependency_checks_job/dependency_checks_job.py
@@ -0,0 +1,32 @@
+import tests.jobs.dependency_checks_job.main_datasources as ds
+from rialto.jobs import job, register_dependency_module
+
+register_dependency_module(ds)
+
+
+@job
+def ok_dependency_job(c):
+ return c + 1
+
+
+@job
+def circular_dependency_job(circle_third):
+ return circle_third + 1
+
+
+@job
+def missing_dependency_job(a, x):
+ return x + a
+
+
+@job
+def self_dependency_job(self_dependency):
+ return self_dependency + 1
+
+
+@job
+def default_dependency_job(run_date, spark, config, table_reader):
+ assert run_date is not None
+ assert spark is not None
+ assert config is not None
+ assert table_reader is not None
diff --git a/tests/jobs/dependency_checks_job/duplicate_dependency_job.py b/tests/jobs/dependency_checks_job/duplicate_dependency_job.py
new file mode 100644
index 0000000..50f49b2
--- /dev/null
+++ b/tests/jobs/dependency_checks_job/duplicate_dependency_job.py
@@ -0,0 +1,15 @@
+import tests.jobs.dependency_checks_job.datasources_a as a
+import tests.jobs.dependency_checks_job.datasources_b as b
+from rialto.jobs import job, register_dependency_module
+
+# module "A" has i(), j(), k()
+# module "B" has i(j), and dependency on module C
+
+register_dependency_module(b)
+register_dependency_module(a)
+
+
+@job
+def duplicate_dependency_job(i):
+ # i is in both A and B
+ pass
diff --git a/tests/jobs/dependency_checks_job/main_datasources.py b/tests/jobs/dependency_checks_job/main_datasources.py
new file mode 100644
index 0000000..8ac2d94
--- /dev/null
+++ b/tests/jobs/dependency_checks_job/main_datasources.py
@@ -0,0 +1,37 @@
+from rialto.jobs import datasource
+
+
+@datasource
+def a():
+ return 1
+
+
+@datasource
+def b(a):
+ return a + 10
+
+
+@datasource
+def c(a, b):
+ # 1 + 11 = 12
+ return a + b
+
+
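+# circle_first -> circle_second -> circle_third -> circle_first is a deliberate cycle for the circular-dependency tests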
+@datasource
+def circle_first(circle_second):
+ return circle_second + 1
+
+
+@datasource
+def circle_second(circle_third):
+ return circle_third + 1
+
+
+@datasource
+def circle_third(circle_first):
+ return circle_first + 1
+
+
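+# depends on itself to exercise the self-dependency check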
+@datasource
+def self_dependency(a, b, c, self_dependency):
+ return a
diff --git a/tests/jobs/resources.py b/tests/jobs/resources.py
index 4d33fad..ddb8bf8 100644
--- a/tests/jobs/resources.py
+++ b/tests/jobs/resources.py
@@ -15,11 +15,19 @@
import pandas as pd
-from rialto.jobs.decorators.job_base import JobBase
+from rialto.jobs.job_base import JobBase
+from rialto.jobs.resolver import Resolver
def custom_callable():
- pass
+ return None
+
+
+def asserting_callable():
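+ # used by CustomJobAssertResolverSetup below to check that the default dependencies are registered in the Resolver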
+ assert Resolver.resolve("run_date")
+ assert Resolver.resolve("config")
+ assert Resolver.resolve("spark")
+ assert Resolver.resolve("table_reader")
class CustomJobNoReturnVal(JobBase):
@@ -41,3 +49,13 @@ def f(spark):
return spark.createDataFrame(df)
return f
+
+
+class CustomJobNoVersion(CustomJobNoReturnVal):
+ def get_job_version(self) -> str:
+ return None
+
+
+class CustomJobAssertResolverSetup(CustomJobNoReturnVal):
+ def get_custom_callable(self):
+ return asserting_callable
diff --git a/tests/jobs/test_config_holder.py b/tests/jobs/test_config_holder.py
deleted file mode 100644
index 38fadb1..0000000
--- a/tests/jobs/test_config_holder.py
+++ /dev/null
@@ -1,100 +0,0 @@
-# Copyright 2022 ABSA Group Limited
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from datetime import date
-
-import pytest
-
-from rialto.jobs.configuration.config_holder import (
- ConfigException,
- ConfigHolder,
- FeatureStoreConfig,
-)
-
-
-def test_run_date_unset():
- with pytest.raises(ConfigException):
- ConfigHolder.get_run_date()
-
-
-def test_run_date():
- dt = date(2023, 1, 1)
-
- ConfigHolder.set_run_date(dt)
-
- assert ConfigHolder.get_run_date() == dt
-
-
-def test_feature_store_config_unset():
- with pytest.raises(ConfigException):
- ConfigHolder.get_feature_store_config()
-
-
-def test_feature_store_config():
- ConfigHolder.set_feature_store_config("store_schema", "metadata_schema")
-
- fsc = ConfigHolder.get_feature_store_config()
-
- assert type(fsc) is FeatureStoreConfig
- assert fsc.feature_store_schema == "store_schema"
- assert fsc.feature_metadata_schema == "metadata_schema"
-
-
-def test_config_unset():
- config = ConfigHolder.get_config()
-
- assert type(config) is type({})
- assert len(config.items()) == 0
-
-
-def test_config_dict_copied_not_ref():
- """Test that config holder config can't be set from outside"""
- config = ConfigHolder.get_config()
-
- config["test"] = 123
-
- assert "test" not in ConfigHolder.get_config()
-
-
-def test_config():
- ConfigHolder.set_custom_config(hello=123)
- ConfigHolder.set_custom_config(world="test")
-
- config = ConfigHolder.get_config()
-
- assert config["hello"] == 123
- assert config["world"] == "test"
-
-
-def test_config_from_dict():
- ConfigHolder.set_custom_config(**{"dict_item_1": 123, "dict_item_2": 456})
-
- config = ConfigHolder.get_config()
-
- assert config["dict_item_1"] == 123
- assert config["dict_item_2"] == 456
-
-
-def test_dependencies_unset():
- deps = ConfigHolder.get_dependency_config()
- assert len(deps.keys()) == 0
-
-
-def test_dependencies():
- ConfigHolder.set_dependency_config({"hello": 123})
-
- deps = ConfigHolder.get_dependency_config()
-
- assert deps["hello"] == 123
diff --git a/tests/jobs/test_decorators.py b/tests/jobs/test_decorators.py
index e896cec..d1931c3 100644
--- a/tests/jobs/test_decorators.py
+++ b/tests/jobs/test_decorators.py
@@ -14,16 +14,18 @@
from importlib import import_module
-from rialto.jobs.configuration.config_holder import ConfigHolder
-from rialto.jobs.decorators.job_base import JobBase
-from rialto.jobs.decorators.resolver import Resolver
+from rialto.jobs.job_base import JobBase
+from rialto.jobs.module_register import ModuleRegister
def test_dataset_decorator():
_ = import_module("tests.jobs.test_job.test_job")
- test_dataset = Resolver.resolve("dataset")
+ assert ModuleRegister.find_callable("dataset", "tests.jobs.test_job.test_job") is not None
- assert test_dataset == "dataset_return"
+
+def test_config_decorator():
+ _ = import_module("tests.jobs.test_job.test_job")
+ assert ModuleRegister.find_callable("custom_config", "tests.jobs.test_job.test_job") is not None
def _rialto_import_stub(module_name, class_name):
@@ -57,9 +59,19 @@ def test_custom_name_function():
custom_callable = result_class.get_custom_callable()
assert custom_callable() == "custom_job_name_return"
+ job_name = result_class.get_job_name()
+ assert job_name == "custom_job_name"
+
+
+def test_job_disabling_version():
+ result_class = _rialto_import_stub("tests.jobs.test_job.test_job", "disable_version_job_function")
+ assert issubclass(type(result_class), JobBase)
+
+ job_version = result_class.get_job_version()
+ assert job_version is None
+
def test_job_dependencies_registered(spark):
- ConfigHolder.set_custom_config(value=123)
job_class = _rialto_import_stub("tests.jobs.test_job.test_job", "job_asking_for_all_deps")
# asserts part of the run
- job_class.run(spark=spark, run_date=456, reader=789, metadata_manager=None, dependencies=1011)
+ job_class.run(spark=spark, run_date=456, reader=789, config=123, metadata_manager=654, feature_loader=321)
diff --git a/tests/jobs/test_job/test_job.py b/tests/jobs/test_job/test_job.py
index 12baec9..4e47364 100644
--- a/tests/jobs/test_job/test_job.py
+++ b/tests/jobs/test_job/test_job.py
@@ -11,9 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from rialto.jobs import config_parser, datasource, job
-from rialto.jobs.decorators import datasource, job
+@config_parser
+def custom_config():
+ return "config_return"
@datasource
@@ -26,15 +29,21 @@ def job_function():
return "job_function_return"
-@job("custom_job_name")
+@job(custom_name="custom_job_name")
def custom_name_job_function():
return "custom_job_name_return"
+@job(disable_version=True)
+def disable_version_job_function():
+ return "disabled_version_job_return"
+
+
@job
-def job_asking_for_all_deps(spark, run_date, config, dependencies, table_reader):
+def job_asking_for_all_deps(spark, run_date, config, table_reader, metadata_manager, feature_loader):
assert spark is not None
assert run_date == 456
- assert config["value"] == 123
+ assert config == 123
assert table_reader == 789
- assert dependencies == 1011
+ assert metadata_manager == 654
+ assert feature_loader == 321
diff --git a/tests/jobs/test_job_base.py b/tests/jobs/test_job_base.py
index 2cdc741..1514957 100644
--- a/tests/jobs/test_job_base.py
+++ b/tests/jobs/test_job_base.py
@@ -14,44 +14,21 @@
import datetime
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock
import pyspark.sql.types
import tests.jobs.resources as resources
-from rialto.jobs.configuration.config_holder import ConfigHolder, FeatureStoreConfig
-from rialto.jobs.decorators.resolver import Resolver
+from rialto.jobs.resolver import Resolver
from rialto.loader import PysparkFeatureLoader
-def test_setup_except_feature_loader(spark):
+def test_setup(spark):
table_reader = MagicMock()
+ config = MagicMock()
date = datetime.date(2023, 1, 1)
- ConfigHolder.set_custom_config(hello=1, world=2)
-
- resources.CustomJobNoReturnVal().run(
- reader=table_reader, run_date=date, spark=spark, metadata_manager=None, dependencies={1: 1}
- )
-
- assert Resolver.resolve("run_date") == date
- assert Resolver.resolve("config") == ConfigHolder.get_config()
- assert Resolver.resolve("dependencies") == ConfigHolder.get_dependency_config()
- assert Resolver.resolve("spark") == spark
- assert Resolver.resolve("table_reader") == table_reader
-
-
-@patch(
- "rialto.jobs.configuration.config_holder.ConfigHolder.get_feature_store_config",
- return_value=FeatureStoreConfig(feature_store_schema="schema", feature_metadata_schema="metadata_schema"),
-)
-def test_setup_feature_loader(spark):
- table_reader = MagicMock()
- date = datetime.date(2023, 1, 1)
-
- resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, metadata_manager=None)
-
- assert type(Resolver.resolve("feature_loader")) == PysparkFeatureLoader
+ resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, config=config)
def test_custom_callable_called(spark, mocker):
@@ -60,7 +37,7 @@ def test_custom_callable_called(spark, mocker):
table_reader = MagicMock()
date = datetime.date(2023, 1, 1)
- resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, metadata_manager=None)
+ resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, config=None)
spy_cc.assert_called_once()
@@ -69,9 +46,7 @@ def test_no_return_vaue_adds_version_timestamp_dataframe(spark):
table_reader = MagicMock()
date = datetime.date(2023, 1, 1)
- result = resources.CustomJobNoReturnVal().run(
- reader=table_reader, run_date=date, spark=spark, metadata_manager=None
- )
+ result = resources.CustomJobNoReturnVal().run(reader=table_reader, run_date=date, spark=spark, config=None)
assert type(result) is pyspark.sql.DataFrame
assert result.columns == ["JOB_NAME", "CREATION_TIME", "VERSION"]
@@ -83,11 +58,19 @@ def test_return_dataframe_forwarded_with_version(spark):
table_reader = MagicMock()
date = datetime.date(2023, 1, 1)
- result = resources.CustomJobReturnsDataFrame().run(
- reader=table_reader, run_date=date, spark=spark, metadata_manager=None
- )
+ result = resources.CustomJobReturnsDataFrame().run(reader=table_reader, run_date=date, spark=spark, config=None)
assert type(result) is pyspark.sql.DataFrame
assert result.columns == ["FIRST", "SECOND", "VERSION"]
assert result.first()["VERSION"] == "job_version"
assert result.count() == 2
+
+
+def test_none_job_version_wont_fill_job_column(spark):
+ table_reader = MagicMock()
+ date = datetime.date(2023, 1, 1)
+
+ result = resources.CustomJobNoVersion().run(reader=table_reader, run_date=date, spark=spark, config=None)
+
+ assert type(result) is pyspark.sql.DataFrame
+ assert "VERSION" not in result.columns
diff --git a/tests/jobs/test_resolver.py b/tests/jobs/test_resolver.py
index df56b72..443e27b 100644
--- a/tests/jobs/test_resolver.py
+++ b/tests/jobs/test_resolver.py
@@ -13,53 +13,58 @@
# limitations under the License.
import pytest
-from rialto.jobs.decorators.resolver import Resolver, ResolverException
+from rialto.jobs.resolver import Resolver, ResolverException
def test_simple_resolve_custom_name():
def f():
return 7
- Resolver.register_callable(f, "hello")
+ resolver = Resolver()
+ resolver.register_getter(f, "hello")
- assert Resolver.resolve("hello") == 7
+ assert resolver.resolve(lambda hello: hello) == 7
def test_simple_resolve_infer_f_name():
def f():
- return 7
+ return 8
- Resolver.register_callable(f)
+ resolver = Resolver()
+ resolver.register_getter(f)
- assert Resolver.resolve("f") == 7
+ assert resolver.resolve(lambda f: f) == 8
-def test_dependency_resolve():
- def f():
- return 7
-
- def g(f):
- return f + 1
+def test_resolve_non_defined():
+ resolver = Resolver()
+ with pytest.raises(ResolverException):
+ resolver.resolve(lambda x: ...)
- Resolver.register_callable(f)
- Resolver.register_callable(g)
- assert Resolver.resolve("g") == 8
+def test_resolve_multi_dependency():
+ def a(b, c):
+ return b + c
+ def b():
+ return 1
-def test_resolve_non_defined():
- with pytest.raises(ResolverException):
- Resolver.resolve("whatever")
+ def c(d):
+ return d + 10
+ def d():
+ return 100
-def test_register_resolve(mocker):
- def f():
- return 7
+ resolver = Resolver()
+ resolver.register_getter(a)
+ resolver.register_getter(b)
+ resolver.register_getter(c)
+ resolver.register_getter(d)
- mocker.patch("rialto.jobs.decorators.resolver.Resolver.register_callable", return_value="f")
- mocker.patch("rialto.jobs.decorators.resolver.Resolver.resolve")
+ assert resolver.resolve(a) == 111
- Resolver.register_resolve(f)
- Resolver.register_callable.assert_called_once_with(f)
- Resolver.resolve.assert_called_once_with("f")
+def test_register_objects():
+ resolver = Resolver()
+ resolver.register_object(7, "seven")
+ assert resolver.resolve(lambda seven: seven) == 7
diff --git a/tests/jobs/test_test_utils.py b/tests/jobs/test_test_utils.py
index a6b31b2..dcf41ab 100644
--- a/tests/jobs/test_test_utils.py
+++ b/tests/jobs/test_test_utils.py
@@ -11,23 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import pytest
import rialto.jobs.decorators as decorators
+import tests.jobs.dependency_checks_job.complex_dependency_job as complex_dependency_job
+import tests.jobs.dependency_checks_job.dependency_checks_job as dependency_checks_job
+import tests.jobs.dependency_checks_job.duplicate_dependency_job as duplicate_dependency_job
import tests.jobs.test_job.test_job as test_job
-from rialto.jobs.decorators.resolver import Resolver
-from rialto.jobs.decorators.test_utils import disable_job_decorators
+from rialto.jobs.test_utils import disable_job_decorators, resolver_resolves
def test_raw_dataset_patch(mocker):
- spy_rc = mocker.spy(Resolver, "register_callable")
spy_dec = mocker.spy(decorators, "datasource")
with disable_job_decorators(test_job):
assert test_job.dataset() == "dataset_return"
- spy_dec.assert_not_called()
- spy_rc.assert_not_called()
+ spy_dec.assert_not_called()
def test_job_function_patch(mocker):
@@ -36,7 +36,7 @@ def test_job_function_patch(mocker):
with disable_job_decorators(test_job):
assert test_job.job_function() == "job_function_return"
- spy_dec.assert_not_called()
+ spy_dec.assert_not_called()
def test_custom_name_job_function_patch(mocker):
@@ -45,4 +45,64 @@ def test_custom_name_job_function_patch(mocker):
with disable_job_decorators(test_job):
assert test_job.custom_name_job_function() == "custom_job_name_return"
- spy_dec.assert_not_called()
+ spy_dec.assert_not_called()
+
+
+def test_resolver_resolves_ok_job(spark):
+ assert resolver_resolves(spark, dependency_checks_job.ok_dependency_job)
+
+
+def test_resolver_resolves_default_dependency(spark):
+ assert resolver_resolves(spark, dependency_checks_job.default_dependency_job)
+
+
+def test_resolver_fails_circular_dependency(spark):
+ with pytest.raises(Exception) as exc_info:
+ assert resolver_resolves(spark, dependency_checks_job.circular_dependency_job)
+
+ assert exc_info is not None
+ assert str(exc_info.value) == "Circular Dependence in circle_third!"
+
+
+def test_resolver_fails_missing_dependency(spark):
+ with pytest.raises(Exception) as exc_info:
+ assert resolver_resolves(spark, dependency_checks_job.missing_dependency_job)
+
+ assert exc_info is not None
+ assert str(exc_info.value) == "x declaration not found!"
+
+
+def test_resolver_fails_self_dependency(spark):
+ with pytest.raises(Exception) as exc_info:
+ assert resolver_resolves(spark, dependency_checks_job.self_dependency_job)
+
+ assert exc_info is not None
+ assert str(exc_info.value) == "Circular Dependence in self_dependency!"
+
+
+def test_complex_dependencies_resolves_correctly(spark):
+ assert resolver_resolves(spark, complex_dependency_job.complex_dependency_job)
+
+
+def test_complex_dependencies_fails_on_unimported(spark):
+ with pytest.raises(Exception) as exc_info:
+ assert resolver_resolves(spark, complex_dependency_job.unimported_dependency_job)
+
+ assert exc_info is not None
+ assert str(exc_info.value) == "k declaration not found!"
+
+
+def test_duplicate_dependency_fails_on_duplicate(spark):
+ with pytest.raises(Exception) as exc_info:
+ assert resolver_resolves(spark, duplicate_dependency_job.duplicate_dependency_job)
+
+ assert exc_info is not None
+ assert str(exc_info.value) == "Multiple functions with the same name i found !"
diff --git a/tests/loader/pyspark/dummy_loaders.py b/tests/loader/pyspark/dummy_loaders.py
deleted file mode 100644
index a2b0cb8..0000000
--- a/tests/loader/pyspark/dummy_loaders.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright 2022 ABSA Group Limited
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from datetime import date
-
-from rialto.loader.data_loader import DataLoader
-
-
-class DummyDataLoader(DataLoader):
- def __init__(self):
- super().__init__()
-
- def read_group(self, group: str, information_date: date):
- return None
diff --git a/tests/loader/pyspark/test_from_cfg.py b/tests/loader/pyspark/test_from_cfg.py
index 3ad653e..dd2049f 100644
--- a/tests/loader/pyspark/test_from_cfg.py
+++ b/tests/loader/pyspark/test_from_cfg.py
@@ -21,7 +21,6 @@
from rialto.loader.config_loader import get_feature_config
from rialto.loader.pyspark_feature_loader import PysparkFeatureLoader
from tests.loader.pyspark.dataframe_builder import dataframe_builder as dfb
-from tests.loader.pyspark.dummy_loaders import DummyDataLoader
@pytest.fixture(scope="session")
@@ -45,7 +44,7 @@ def spark(request):
@pytest.fixture(scope="session")
def loader(spark):
- return PysparkFeatureLoader(spark, DummyDataLoader(), MagicMock())
+ return PysparkFeatureLoader(spark, MagicMock(), MagicMock())
VALID_LIST = [(["a"], ["a"]), (["a"], ["a", "b", "c"]), (["c", "a"], ["a", "b", "c"])]
@@ -90,7 +89,7 @@ def __call__(self, *args, **kwargs):
metadata = MagicMock()
monkeypatch.setattr(metadata, "get_group", GroupMd())
- loader = PysparkFeatureLoader(spark, DummyDataLoader(), "")
+ loader = PysparkFeatureLoader(spark, "", "")
loader.metadata = metadata
base = dfb(spark, data=r.base_frame_data, columns=r.base_frame_columns)
@@ -105,7 +104,7 @@ def __call__(self, *args, **kwargs):
def test_get_group_metadata(spark, mocker):
mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_group", return_value=7)
- loader = PysparkFeatureLoader(spark, DummyDataLoader(), "")
+ loader = PysparkFeatureLoader(spark, "", "")
ret_val = loader.get_group_metadata("group_name")
assert ret_val == 7
@@ -115,7 +114,7 @@ def test_get_group_metadata(spark, mocker):
def test_get_feature_metadata(spark, mocker):
mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_feature", return_value=8)
- loader = PysparkFeatureLoader(spark, DummyDataLoader(), "")
+ loader = PysparkFeatureLoader(spark, "", "")
ret_val = loader.get_feature_metadata("group_name", "feature")
assert ret_val == 8
@@ -129,7 +128,7 @@ def test_get_metadata_from_cfg(spark, mocker):
)
mocker.patch("rialto.loader.pyspark_feature_loader.MetadataManager.get_group", side_effect=lambda g: {"B": 10}[g])
- loader = PysparkFeatureLoader(spark, DummyDataLoader(), "")
+ loader = PysparkFeatureLoader(spark, "", "")
metadata = loader.get_metadata_from_cfg("tests/loader/pyspark/example_cfg.yaml")
assert metadata["B_F1"] == 1
diff --git a/tests/runner/conftest.py b/tests/runner/conftest.py
index 44f0c09..4e527be 100644
--- a/tests/runner/conftest.py
+++ b/tests/runner/conftest.py
@@ -39,6 +39,4 @@ def spark(request):
@pytest.fixture(scope="function")
def basic_runner(spark):
- return Runner(
- spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", run_date="2023-03-31"
- )
+ return Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31")
diff --git a/tests/runner/overrider.yaml b/tests/runner/overrider.yaml
new file mode 100644
index 0000000..3029730
--- /dev/null
+++ b/tests/runner/overrider.yaml
@@ -0,0 +1,86 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
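+# Configuration used by tests/runner/test_overrides.py; the pipeline names and mail recipients below are targeted by the override keys in those tests.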
+runner:
+ watched_period_units: "months"
+ watched_period_value: 2
+ mail:
+ sender: test@testing.org
+ smtp: server.test
+ to:
+ - developer@testing.org
+ - developer2@testing.org
+ subject: test report
+pipelines:
+ - name: SimpleGroup
+ module:
+ python_module: tests.runner.transformations
+ python_class: SimpleGroup
+ schedule:
+ frequency: weekly
+ day: 7
+ info_date_shift:
+ - value: 3
+ units: days
+ - value: 2
+ units: weeks
+ dependencies:
+ - table: source.schema.dep1
+ interval:
+ units: "days"
+ value: 1
+ date_col: "DATE"
+ - table: source.schema.dep2
+ interval:
+ units: "months"
+ value: 3
+ date_col: "DATE"
+ target:
+ target_schema: catalog.schema
+ target_partition_column: "INFORMATION_DATE"
+ feature_loader:
+ config_path: path/to/config.yaml
+ feature_schema: catalog.feature_tables
+ metadata_schema: catalog.metadata
+ metadata_manager:
+ metadata_schema: catalog.metadata
+ - name: OtherGroup
+ module:
+ python_module: tests.runner.transformations
+ python_class: SimpleGroup
+ schedule:
+ frequency: weekly
+ day: 7
+ info_date_shift:
+ - value: 3
+ units: days
+ dependencies:
+ - table: source.schema.dep1
+ name: source1
+ interval:
+ units: "days"
+ value: 1
+ date_col: "DATE"
+ - table: source.schema.dep2
+ name: source2
+ interval:
+ units: "months"
+ value: 3
+ date_col: "batch"
+ target:
+ target_schema: catalog.schema
+ target_partition_column: "INFORMATION_DATE"
+ extras:
+ some_value: 3
+ some_other_value: cat
diff --git a/tests/runner/test_date_manager.py b/tests/runner/test_date_manager.py
index 9088e0c..73b61b8 100644
--- a/tests/runner/test_date_manager.py
+++ b/tests/runner/test_date_manager.py
@@ -144,7 +144,7 @@ def test_run_dates_invalid():
[(7, "2023-02-26"), (3, "2023-03-02"), (-5, "2023-03-10"), (0, "2023-03-05")],
)
def test_to_info_date(shift, res):
- cfg = ScheduleConfig(frequency="daily", info_date_shift=IntervalConfig(units="days", value=shift))
+ cfg = ScheduleConfig(frequency="daily", info_date_shift=[IntervalConfig(units="days", value=shift)])
base = DateManager.str_to_date("2023-03-05")
info = DateManager.to_info_date(base, cfg)
assert DateManager.str_to_date(res) == info
@@ -155,7 +155,7 @@ def test_to_info_date(shift, res):
[("days", "2023-03-02"), ("weeks", "2023-02-12"), ("months", "2022-12-05"), ("years", "2020-03-05")],
)
def test_info_date_shift_units(unit, result):
- cfg = ScheduleConfig(frequency="daily", info_date_shift=IntervalConfig(units=unit, value=3))
+ cfg = ScheduleConfig(frequency="daily", info_date_shift=[IntervalConfig(units=unit, value=3)])
base = DateManager.str_to_date("2023-03-05")
info = DateManager.to_info_date(base, cfg)
assert DateManager.str_to_date(result) == info
diff --git a/tests/runner/test_overrides.py b/tests/runner/test_overrides.py
new file mode 100644
index 0000000..17fcdbe
--- /dev/null
+++ b/tests/runner/test_overrides.py
@@ -0,0 +1,137 @@
+# Copyright 2022 ABSA Group Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pytest
+
+from rialto.runner import Runner
+
+
+def test_overrides_simple(spark):
+ runner = Runner(
+ spark,
+ config_path="tests/runner/overrider.yaml",
+ run_date="2023-03-31",
+ overrides={"runner.mail.to": ["x@b.c", "y@b.c", "z@b.c"]},
+ )
+ assert runner.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"]
+
+
+def test_overrides_array_index(spark):
+ runner = Runner(
+ spark,
+ config_path="tests/runner/overrider.yaml",
+ run_date="2023-03-31",
+ overrides={"runner.mail.to[1]": "a@b.c"},
+ )
+ assert runner.config.runner.mail.to == ["developer@testing.org", "a@b.c"]
+
+
+def test_overrides_array_append(spark):
+ runner = Runner(
+ spark,
+ config_path="tests/runner/overrider.yaml",
+ run_date="2023-03-31",
+ overrides={"runner.mail.to[-1]": "test"},
+ )
+ assert runner.config.runner.mail.to == ["developer@testing.org", "developer2@testing.org", "test"]
+
+
+def test_overrides_array_lookup(spark):
+ runner = Runner(
+ spark,
+ config_path="tests/runner/overrider.yaml",
+ run_date="2023-03-31",
+ overrides={"pipelines[name=SimpleGroup].target.target_schema": "new_schema"},
+ )
+ assert runner.config.pipelines[0].target.target_schema == "new_schema"
+
+
+def test_overrides_combined(spark):
+ runner = Runner(
+ spark,
+ config_path="tests/runner/overrider.yaml",
+ run_date="2023-03-31",
+ overrides={
+ "runner.mail.to": ["x@b.c", "y@b.c", "z@b.c"],
+ "pipelines[name=SimpleGroup].target.target_schema": "new_schema",
+ "pipelines[name=SimpleGroup].schedule.info_date_shift[0].value": 1,
+ },
+ )
+ assert runner.config.runner.mail.to == ["x@b.c", "y@b.c", "z@b.c"]
+ assert runner.config.pipelines[0].target.target_schema == "new_schema"
+ assert runner.config.pipelines[0].schedule.info_date_shift[0].value == 1
+
+
+def test_index_out_of_range(spark):
+ with pytest.raises(IndexError) as error:
+ Runner(
+ spark,
+ config_path="tests/runner/overrider.yaml",
+ run_date="2023-03-31",
+ overrides={"runner.mail.to[8]": "test"},
+ )
+ assert error.value.args[0] == "Index 8 out of bounds for key to[8]"
+
+
+def test_invalid_index_key(spark):
+ with pytest.raises(ValueError) as error:
+ Runner(
+ spark,
+ config_path="tests/runner/overrider.yaml",
+ run_date="2023-03-31",
+ overrides={"runner.mail.test[8]": "test"},
+ )
+ assert error.value.args[0] == "Invalid key test"
+
+
+def test_invalid_key(spark):
+ with pytest.raises(ValueError) as error:
+ Runner(
+ spark,
+ config_path="tests/runner/overrider.yaml",
+ run_date="2023-03-31",
+ overrides={"runner.mail.test.param": "test"},
+ )
+ assert error.value.args[0] == "Invalid key test"
+
+
+def test_replace_section(spark):
+ runner = Runner(
+ spark,
+ config_path="tests/runner/overrider.yaml",
+ run_date="2023-03-31",
+ overrides={
+ "pipelines[name=SimpleGroup].feature_loader": {
+ "config_path": "features_cfg.yaml",
+ "feature_schema": "catalog.features",
+ "metadata_schema": "catalog.metadata",
+ }
+ },
+ )
+ assert runner.config.pipelines[0].feature_loader.feature_schema == "catalog.features"
+
+
+def test_add_section(spark):
+ runner = Runner(
+ spark,
+ config_path="tests/runner/overrider.yaml",
+ run_date="2023-03-31",
+ overrides={
+ "pipelines[name=OtherGroup].feature_loader": {
+ "config_path": "features_cfg.yaml",
+ "feature_schema": "catalog.features",
+ "metadata_schema": "catalog.metadata",
+ }
+ },
+ )
+ assert runner.config.pipelines[1].feature_loader.feature_schema == "catalog.features"
diff --git a/tests/runner/test_runner.py b/tests/runner/test_runner.py
index 0459411..e23eee8 100644
--- a/tests/runner/test_runner.py
+++ b/tests/runner/test_runner.py
@@ -11,15 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from collections import namedtuple
from datetime import datetime
from typing import Optional
import pytest
from pyspark.sql import DataFrame
+import rialto.runner.utils as utils
from rialto.common.table_reader import DataReader
-from rialto.jobs.configuration.config_holder import ConfigHolder
from rialto.runner.runner import DateManager, Runner
from rialto.runner.table import Table
from tests.runner.runner_resources import (
@@ -38,8 +37,8 @@ def __init__(self, spark):
def get_table(
self,
table: str,
- info_date_from: Optional[datetime.date] = None,
- info_date_to: Optional[datetime.date] = None,
+ date_from: Optional[datetime.date] = None,
+ date_to: Optional[datetime.date] = None,
date_column: str = None,
uppercase_columns: bool = False,
) -> DataFrame:
@@ -53,114 +52,79 @@ def get_table(
def get_latest(
self,
table: str,
- until: Optional[datetime.date] = None,
+ date_until: Optional[datetime.date] = None,
date_column: str = None,
uppercase_columns: bool = False,
) -> DataFrame:
pass
-def test_table_exists(spark, mocker, basic_runner):
+def test_table_exists(spark, mocker):
mock = mocker.patch("pyspark.sql.Catalog.tableExists", return_value=True)
- basic_runner._table_exists("abc")
+ utils.table_exists(spark, "abc")
mock.assert_called_once_with("abc")
-def test_infer_column(spark, mocker, basic_runner):
- column = namedtuple("catalog", ["name", "isPartition"])
- catalog = [column("a", True), column("b", False), column("c", False)]
-
- mock = mocker.patch("pyspark.sql.Catalog.listColumns", return_value=catalog)
- partition = basic_runner._delta_partition("aaa")
- assert partition == "a"
- mock.assert_called_once_with("aaa")
-
-
def test_load_module(spark, basic_runner):
- module = basic_runner._load_module(basic_runner.config.pipelines[0].module)
+ module = utils.load_module(basic_runner.config.pipelines[0].module)
assert isinstance(module, SimpleGroup)
def test_generate(spark, mocker, basic_runner):
run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run")
group = SimpleGroup()
- basic_runner._generate(group, DateManager.str_to_date("2023-01-31"))
+ config = basic_runner.config.pipelines[0]
+ basic_runner._execute(group, DateManager.str_to_date("2023-01-31"), config)
+
run.assert_called_once_with(
reader=basic_runner.reader,
run_date=DateManager.str_to_date("2023-01-31"),
spark=spark,
- metadata_manager=basic_runner.metadata,
- dependencies=None,
+ config=config,
+ metadata_manager=None,
+ feature_loader=None,
)
def test_generate_w_dep(spark, mocker, basic_runner):
run = mocker.patch("tests.runner.transformations.simple_group.SimpleGroup.run")
group = SimpleGroup()
- basic_runner._generate(group, DateManager.str_to_date("2023-01-31"), basic_runner.config.pipelines[2].dependencies)
+ basic_runner._execute(group, DateManager.str_to_date("2023-01-31"), basic_runner.config.pipelines[2])
run.assert_called_once_with(
reader=basic_runner.reader,
run_date=DateManager.str_to_date("2023-01-31"),
spark=spark,
- metadata_manager=basic_runner.metadata,
- dependencies={
- "source1": basic_runner.config.pipelines[2].dependencies[0],
- "source2": basic_runner.config.pipelines[2].dependencies[1],
- },
+ config=basic_runner.config.pipelines[2],
+ metadata_manager=None,
+ feature_loader=None,
)
def test_init_dates(spark):
- runner = Runner(
- spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", run_date="2023-03-31"
- )
+ runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31")
assert runner.date_from == DateManager.str_to_date("2023-01-31")
assert runner.date_until == DateManager.str_to_date("2023-03-31")
runner = Runner(
spark,
config_path="tests/runner/transformations/config.yaml",
- feature_metadata_schema="",
- date_from="2023-03-01",
- date_until="2023-03-31",
+ run_date="2023-03-31",
+ overrides={"runner.watched_period_units": "weeks", "runner.watched_period_value": 2},
)
- assert runner.date_from == DateManager.str_to_date("2023-03-01")
+ assert runner.date_from == DateManager.str_to_date("2023-03-17")
assert runner.date_until == DateManager.str_to_date("2023-03-31")
runner = Runner(
spark,
config_path="tests/runner/transformations/config2.yaml",
- feature_metadata_schema="",
run_date="2023-03-31",
)
assert runner.date_from == DateManager.str_to_date("2023-02-24")
assert runner.date_until == DateManager.str_to_date("2023-03-31")
-def test_possible_run_dates(spark):
- runner = Runner(
- spark,
- config_path="tests/runner/transformations/config.yaml",
- feature_metadata_schema="",
- date_from="2023-03-01",
- date_until="2023-03-31",
- )
-
- dates = runner.get_possible_run_dates(runner.config.pipelines[0].schedule)
- expected = ["2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"]
- assert dates == [DateManager.str_to_date(d) for d in expected]
-
-
-def test_info_dates(spark, basic_runner):
- run = ["2023-02-05", "2023-02-12", "2023-02-19", "2023-02-26", "2023-03-05"]
- run = [DateManager.str_to_date(d) for d in run]
- info = basic_runner.get_info_dates(basic_runner.config.pipelines[0].schedule, run)
- expected = ["2023-02-02", "2023-02-09", "2023-02-16", "2023-02-23", "2023-03-02"]
- assert info == [DateManager.str_to_date(d) for d in expected]
-
-
def test_completion(spark, mocker, basic_runner):
- mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
+ mocker.patch("rialto.runner.utils.table_exists", return_value=True)
basic_runner.reader = MockReader(spark)
@@ -173,11 +137,9 @@ def test_completion(spark, mocker, basic_runner):
def test_completion_rerun(spark, mocker, basic_runner):
- mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
+ mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True)
- runner = Runner(
- spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", run_date="2023-03-31"
- )
+ runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", run_date="2023-03-31")
runner.reader = MockReader(spark)
dates = ["2023-02-26", "2023-03-05", "2023-03-12", "2023-03-19", "2023-03-26"]
@@ -189,14 +151,12 @@ def test_completion_rerun(spark, mocker, basic_runner):
def test_check_dates_have_partition(spark, mocker):
- mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
+ mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True)
runner = Runner(
spark,
config_path="tests/runner/transformations/config.yaml",
- feature_metadata_schema="",
- date_from="2023-03-01",
- date_until="2023-03-31",
+ run_date="2023-03-31",
)
runner.reader = MockReader(spark)
dates = ["2023-03-04", "2023-03-05", "2023-03-06"]
@@ -207,14 +167,12 @@ def test_check_dates_have_partition(spark, mocker):
def test_check_dates_have_partition_no_table(spark, mocker):
- mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=False)
+ mocker.patch("rialto.runner.runner.utils.table_exists", return_value=False)
runner = Runner(
spark,
config_path="tests/runner/transformations/config.yaml",
- feature_metadata_schema="",
- date_from="2023-03-01",
- date_until="2023-03-31",
+ run_date="2023-03-31",
)
dates = ["2023-03-04", "2023-03-05", "2023-03-06"]
dates = [DateManager.str_to_date(d) for d in dates]
@@ -228,14 +186,12 @@ def test_check_dates_have_partition_no_table(spark, mocker):
[("2023-02-26", False), ("2023-03-05", True)],
)
def test_check_dependencies(spark, mocker, r_date, expected):
- mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
+ mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True)
runner = Runner(
spark,
config_path="tests/runner/transformations/config.yaml",
- feature_metadata_schema="",
- date_from="2023-03-01",
- date_until="2023-03-31",
+ run_date="2023-03-31",
)
runner.reader = MockReader(spark)
res = runner.check_dependencies(runner.config.pipelines[0], DateManager.str_to_date(r_date))
@@ -243,14 +199,12 @@ def test_check_dependencies(spark, mocker, r_date, expected):
def test_check_no_dependencies(spark, mocker):
- mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
+ mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True)
runner = Runner(
spark,
config_path="tests/runner/transformations/config.yaml",
- feature_metadata_schema="",
- date_from="2023-03-01",
- date_until="2023-03-31",
+ run_date="2023-03-31",
)
runner.reader = MockReader(spark)
res = runner.check_dependencies(runner.config.pipelines[1], DateManager.str_to_date("2023-03-05"))
@@ -258,14 +212,13 @@ def test_check_no_dependencies(spark, mocker):
def test_select_dates(spark, mocker):
- mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
+ mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True)
runner = Runner(
spark,
config_path="tests/runner/transformations/config.yaml",
- feature_metadata_schema="",
- date_from="2023-03-01",
- date_until="2023-03-31",
+ run_date="2023-03-31",
+ overrides={"runner.watched_period_units": "months", "runner.watched_period_value": 1},
)
runner.reader = MockReader(spark)
@@ -281,14 +234,13 @@ def test_select_dates(spark, mocker):
def test_select_dates_all_done(spark, mocker):
- mocker.patch("rialto.runner.runner.Runner._table_exists", return_value=True)
+ mocker.patch("rialto.runner.runner.utils.table_exists", return_value=True)
runner = Runner(
spark,
config_path="tests/runner/transformations/config.yaml",
- feature_metadata_schema="",
- date_from="2023-03-02",
- date_until="2023-03-02",
+ run_date="2023-03-02",
+ overrides={"runner.watched_period_units": "months", "runner.watched_period_value": 0},
)
runner.reader = MockReader(spark)
@@ -307,9 +259,7 @@ def test_op_selected(spark, mocker):
mocker.patch("rialto.runner.tracker.Tracker.report")
run = mocker.patch("rialto.runner.runner.Runner._run_pipeline")
- runner = Runner(
- spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", op="SimpleGroup"
- )
+ runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", op="SimpleGroup")
runner()
run.called_once()
@@ -319,42 +269,8 @@ def test_op_bad(spark, mocker):
mocker.patch("rialto.runner.tracker.Tracker.report")
mocker.patch("rialto.runner.runner.Runner._run_pipeline")
- runner = Runner(
- spark, config_path="tests/runner/transformations/config.yaml", feature_metadata_schema="", op="BadOp"
- )
+ runner = Runner(spark, config_path="tests/runner/transformations/config.yaml", op="BadOp")
with pytest.raises(ValueError) as exception:
runner()
assert str(exception.value) == "Unknown operation selected: BadOp"
-
-
-def test_custom_config(spark, mocker):
- cc_spy = mocker.spy(ConfigHolder, "set_custom_config")
- custom_config = {"cc": 42}
-
- _ = Runner(spark, config_path="tests/runner/transformations/config.yaml", custom_job_config=custom_config)
-
- cc_spy.assert_called_once_with(cc=42)
-
-
-def test_feature_store_config(spark, mocker):
- fs_spy = mocker.spy(ConfigHolder, "set_feature_store_config")
-
- _ = Runner(
- spark,
- config_path="tests/runner/transformations/config.yaml",
- feature_store_schema="schema",
- feature_metadata_schema="metadata",
- )
-
- fs_spy.assert_called_once_with("schema", "metadata")
-
-
-def test_no_configs(spark, mocker):
- cc_spy = mocker.spy(ConfigHolder, "set_custom_config")
- fs_spy = mocker.spy(ConfigHolder, "set_feature_store_config")
-
- _ = Runner(spark, config_path="tests/runner/transformations/config.yaml")
-
- cc_spy.assert_not_called()
- fs_spy.assert_not_called()
diff --git a/tests/runner/transformations/config.yaml b/tests/runner/transformations/config.yaml
index 2bfeaf1..3b72107 100644
--- a/tests/runner/transformations/config.yaml
+++ b/tests/runner/transformations/config.yaml
@@ -12,12 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-general:
- target_schema: catalog.schema
- target_partition_column: "INFORMATION_DATE"
+runner:
watched_period_units: "months"
watched_period_value: 2
- job: "run" # run/check
mail:
sender: test@testing.org
smtp: server.test
@@ -34,8 +31,8 @@ pipelines:
frequency: weekly
day: 7
info_date_shift:
- value: 3
- units: days
+ - value: 3
+ units: days
dependencies:
- table: source.schema.dep1
interval:
@@ -47,6 +44,9 @@ pipelines:
units: "months"
value: 3
date_col: "DATE"
+ target:
+ target_schema: catalog.schema
+ target_partition_column: "INFORMATION_DATE"
- name: GroupNoDeps
module:
python_module: tests.runner.transformations
@@ -55,8 +55,8 @@ pipelines:
frequency: weekly
day: 7
info_date_shift:
- value: 3
- units: days
+ - value: 3
+ units: days
- name: NamedDeps
module:
python_module: tests.runner.transformations
@@ -65,8 +65,8 @@ pipelines:
frequency: weekly
day: 7
info_date_shift:
- value: 3
- units: days
+ - value: 3
+ units: days
dependencies:
- table: source.schema.dep1
name: source1
@@ -80,3 +80,6 @@ pipelines:
units: "months"
value: 3
date_col: "batch"
+ target:
+ target_schema: catalog.schema
+ target_partition_column: "INFORMATION_DATE"
diff --git a/tests/runner/transformations/config2.yaml b/tests/runner/transformations/config2.yaml
index a91894b..f7b9604 100644
--- a/tests/runner/transformations/config2.yaml
+++ b/tests/runner/transformations/config2.yaml
@@ -12,12 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-general:
- target_schema: catalog.schema
- target_partition_column: "INFORMATION_DATE"
+runner:
watched_period_units: "weeks"
watched_period_value: 5
- job: "run" # run/check
mail:
sender: test@testing.org
smtp: server.test
@@ -43,3 +40,6 @@ pipelines:
units: "months"
value: 1
date_col: "DATE"
+ target:
+ target_schema: catalog.schema
+ target_partition_column: "INFORMATION_DATE"
diff --git a/tests/runner/transformations/simple_group.py b/tests/runner/transformations/simple_group.py
index fcda5c7..ec2311c 100644
--- a/tests/runner/transformations/simple_group.py
+++ b/tests/runner/transformations/simple_group.py
@@ -18,6 +18,7 @@
from pyspark.sql.types import StructType
from rialto.common import TableReader
+from rialto.loader import PysparkFeatureLoader
from rialto.metadata import MetadataManager
from rialto.runner import Transformation
@@ -28,7 +29,8 @@ def run(
reader: TableReader,
run_date: datetime.date,
spark: SparkSession = None,
- metadata_manager: MetadataManager = None,
- dependencies: Dict = None,
+ config: Dict = None,
+ metadata: MetadataManager = None,
+ feature_loader: PysparkFeatureLoader = None,
) -> DataFrame:
return spark.createDataFrame([], StructType([]))