From 650514c815b9d4f688cbddd5244a24aeb5f8b2c7 Mon Sep 17 00:00:00 2001 From: Artur Dobrogowski Date: Mon, 20 Nov 2023 18:18:00 +0100 Subject: [PATCH] refactor: added run_config and context as optional arguments of nodegrouper --- kedro_vertexai/config.py | 21 +++++++++++---------- kedro_vertexai/dynamic_config.py | 4 +--- kedro_vertexai/generator.py | 5 ++++- kedro_vertexai/grouping.py | 12 +++++++++++- tests/test_config.py | 6 +++--- tests/test_grouping.py | 6 +++--- 6 files changed, 33 insertions(+), 21 deletions(-) diff --git a/kedro_vertexai/config.py b/kedro_vertexai/config.py index fd1b7fe..3f9bb5e 100644 --- a/kedro_vertexai/config.py +++ b/kedro_vertexai/config.py @@ -111,9 +111,7 @@ # the only place to put it to avoid circular dependencies -def dynamic_load_class( - load_class, args: Optional[list] = None, kwargs: Optional[dict] = None -): +def dynamic_load_class(load_class, *args, **kwargs): if args is None: args = [] if kwargs is None: @@ -121,11 +119,11 @@ def dynamic_load_class( try: module_name, class_name = load_class.rsplit(".", 1) logger.info(f"Initializing {class_name}") - cls = getattr(import_module(module_name), class_name) - return cls(*args, **kwargs) + class_load = getattr(import_module(module_name), class_name) + return class_load(*args, **kwargs) except: # noqa: E722 logger.error( - f"Could not dynamically load class {load_class} with it init params, " + f"Could not dynamically load class {load_class} with its init params, " f"make sure it's valid and accessible from the current Python interpreter", exc_info=True, ) @@ -137,12 +135,15 @@ class GroupingConfig(BaseModel): @validator("cls") def class_valid(cls, v, values, **kwargs): - c = dynamic_load_class(v) - if c is None: - raise ValueError(f"Could not validate grouping class {v} with its params.") try: if "params" in values: - c(**values["params"]) + c = dynamic_load_class(v, None, None, **values["params"]) + else: + c = dynamic_load_class(v, None, None) + if c is None: + raise ValueError( + f"Could not validate grouping class {v} with its params." + ) except: # noqa: E722 raise ValueError(f"Invalid parameters for grouping class {v}.") return v diff --git a/kedro_vertexai/dynamic_config.py b/kedro_vertexai/dynamic_config.py index 40786c3..1aada3e 100644 --- a/kedro_vertexai/dynamic_config.py +++ b/kedro_vertexai/dynamic_config.py @@ -20,9 +20,7 @@ def build( config: PluginConfig, provider_config: DynamicConfigProviderConfig, ) -> "DynamicConfigProvider": - return dynamic_load_class( - provider_config.cls, args=[config], kwargs=provider_config.params - ) + return dynamic_load_class(provider_config.cls, config, **provider_config.params) def __init__(self, config: PluginConfig, **kwargs): self.config = config diff --git a/kedro_vertexai/generator.py b/kedro_vertexai/generator.py index 4cad28c..82e3a87 100644 --- a/kedro_vertexai/generator.py +++ b/kedro_vertexai/generator.py @@ -51,7 +51,10 @@ def __init__(self, config, project_name, context, run_name: str): self.run_config: RunConfig = config.run_config self.catalog = context.config_loader.get("catalog*") self.grouping: NodeGrouper = dynamic_load_class( - self.run_config.grouping.cls, kwargs=self.run_config.grouping.params + self.run_config.grouping.cls, + context, + self.run_config, + **self.run_config.grouping.params, ) def get_pipeline_name(self): diff --git a/kedro_vertexai/grouping.py b/kedro_vertexai/grouping.py index e18b01e..8d23095 100644 --- a/kedro_vertexai/grouping.py +++ b/kedro_vertexai/grouping.py @@ -2,9 +2,12 @@ from dataclasses import dataclass, field from typing import Dict, Set +from kedro.framework.context import KedroContext from kedro.pipeline.node import Node from toposort import CircularDependencyError, toposort +from kedro_vertexai.config import RunConfig + TagsDict = Dict[str, Set[str]] PipelineDependenciesDict = Dict[Node, Set[Node]] GroupDependenciesDict = Dict[str, Set[str]] @@ -40,6 +43,10 @@ class NodeGrouper(ABC): For each node it tells which set of nodes are parents of them, based on nodes outputs """ + def __init__(self, kedro_context: KedroContext, run_config: RunConfig): + self.context = kedro_context + self.run_config = run_config + def group(self, node_dependencies: PipelineDependenciesDict) -> Grouping: raise NotImplementedError @@ -69,7 +76,10 @@ class TagNodeGrouper(NodeGrouper): """Grouping class that uses special tag prefix convention to aggregate nodes together. Only one such tag is allowed per node.""" - def __init__(self, tag_prefix="group:") -> None: + def __init__( + self, kedro_context: KedroContext, run_config: RunConfig, tag_prefix="group:" + ) -> None: + super().__init__(kedro_context, run_config) self.tag_prefix = tag_prefix def group(self, node_dependencies: PipelineDependenciesDict) -> Grouping: diff --git a/tests/test_config.py b/tests/test_config.py index 2c9b952..39b7f0e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -54,7 +54,7 @@ def test_grouping_config(self): assert ( cfg.run_config.grouping.cls == "kedro_vertexai.grouping.IdentityNodeGrouper" ) - c_obj = dynamic_load_class(cfg.run_config.grouping.cls) + c_obj = dynamic_load_class(cfg.run_config.grouping.cls, None, None) assert isinstance(c_obj, IdentityNodeGrouper) cfg_tag_group = """ @@ -71,7 +71,7 @@ def test_grouping_config(self): cfg = PluginConfig.parse_obj(yaml.safe_load(cfg_tag_group)) assert cfg.run_config.grouping is not None c_obj = dynamic_load_class( - cfg.run_config.grouping.cls, kwargs=cfg.run_config.grouping.params + cfg.run_config.grouping.cls, None, None, **cfg.run_config.grouping.params ) assert isinstance(c_obj, TagNodeGrouper) assert c_obj.tag_prefix == "group:" @@ -91,7 +91,7 @@ def test_grouping_config_error(self, log_error): """ cfg = PluginConfig.parse_obj(yaml.safe_load(cfg_tag_group)) c = dynamic_load_class( - cfg.run_config.grouping.cls, kwargs=cfg.run_config.grouping.params + cfg.run_config.grouping.cls, **cfg.run_config.grouping.params ) assert c is None log_error.assert_called_once() diff --git a/tests/test_grouping.py b/tests/test_grouping.py index ea3c2ef..2f05f60 100644 --- a/tests/test_grouping.py +++ b/tests/test_grouping.py @@ -71,7 +71,7 @@ def create_pipeline_deps(self): def test_identity_grouping(self): # given deps = self.create_pipeline_deps() - grouper = IdentityNodeGrouper() + grouper = IdentityNodeGrouper(None, None) # when group = grouper.group(deps) for name in self.node_names: @@ -91,7 +91,7 @@ def test_legal_tag_groups(self): deps = self.create_pipeline_deps() for prefix in self.legal_groups: with self.subTest(msg=f"test_{prefix}", group_prefix=prefix): - grouper = TagNodeGrouper(prefix + ":") + grouper = TagNodeGrouper(None, None, prefix + ":") # when group = grouper.group(deps) # assert @@ -116,7 +116,7 @@ def test_illegal_tag_groups(self): deps = self.create_pipeline_deps() for prefix in self.illegal_groups: with self.subTest(msg=f"test_{prefix}", group_prefix=prefix): - grouper = TagNodeGrouper(prefix + ":") + grouper = TagNodeGrouper(None, None, prefix + ":") # when with self.assertRaises(GroupingException): grouper.group(deps)