Skip to content

Commit

Permalink
refactor: added run_config and context as optional arguments of nodeg…
Browse files Browse the repository at this point in the history
…rouper
  • Loading branch information
Lasica committed Nov 20, 2023
1 parent 857a0f9 commit 650514c
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 21 deletions.
21 changes: 11 additions & 10 deletions kedro_vertexai/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,21 +111,19 @@


# the only place to put it to avoid circular dependencies
def dynamic_load_class(
load_class, args: Optional[list] = None, kwargs: Optional[dict] = None
):
def dynamic_load_class(load_class, *args, **kwargs):
if args is None:
args = []
if kwargs is None:
kwargs = {}
try:
module_name, class_name = load_class.rsplit(".", 1)
logger.info(f"Initializing {class_name}")
cls = getattr(import_module(module_name), class_name)
return cls(*args, **kwargs)
class_load = getattr(import_module(module_name), class_name)
return class_load(*args, **kwargs)
except: # noqa: E722
logger.error(
f"Could not dynamically load class {load_class} with it init params, "
f"Could not dynamically load class {load_class} with its init params, "
f"make sure it's valid and accessible from the current Python interpreter",
exc_info=True,
)
Expand All @@ -137,12 +135,15 @@ class GroupingConfig(BaseModel):

@validator("cls")
def class_valid(cls, v, values, **kwargs):
c = dynamic_load_class(v)
if c is None:
raise ValueError(f"Could not validate grouping class {v} with its params.")
try:
if "params" in values:
c(**values["params"])
c = dynamic_load_class(v, None, None, **values["params"])
else:
c = dynamic_load_class(v, None, None)
if c is None:
raise ValueError(
f"Could not validate grouping class {v} with its params."
)
except: # noqa: E722
raise ValueError(f"Invalid parameters for grouping class {v}.")
return v
Expand Down
4 changes: 1 addition & 3 deletions kedro_vertexai/dynamic_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ def build(
config: PluginConfig,
provider_config: DynamicConfigProviderConfig,
) -> "DynamicConfigProvider":
return dynamic_load_class(
provider_config.cls, args=[config], kwargs=provider_config.params
)
return dynamic_load_class(provider_config.cls, config, **provider_config.params)

def __init__(self, config: PluginConfig, **kwargs):
self.config = config
Expand Down
5 changes: 4 additions & 1 deletion kedro_vertexai/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ def __init__(self, config, project_name, context, run_name: str):
self.run_config: RunConfig = config.run_config
self.catalog = context.config_loader.get("catalog*")
self.grouping: NodeGrouper = dynamic_load_class(
self.run_config.grouping.cls, kwargs=self.run_config.grouping.params
self.run_config.grouping.cls,
context,
self.run_config,
**self.run_config.grouping.params,
)

def get_pipeline_name(self):
Expand Down
12 changes: 11 additions & 1 deletion kedro_vertexai/grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
from dataclasses import dataclass, field
from typing import Dict, Set

from kedro.framework.context import KedroContext
from kedro.pipeline.node import Node
from toposort import CircularDependencyError, toposort

from kedro_vertexai.config import RunConfig

TagsDict = Dict[str, Set[str]]
PipelineDependenciesDict = Dict[Node, Set[Node]]
GroupDependenciesDict = Dict[str, Set[str]]
Expand Down Expand Up @@ -40,6 +43,10 @@ class NodeGrouper(ABC):
For each node it tells which set of nodes are parents of them, based on nodes outputs
"""

def __init__(self, kedro_context: KedroContext, run_config: RunConfig):
self.context = kedro_context
self.run_config = run_config

def group(self, node_dependencies: PipelineDependenciesDict) -> Grouping:
raise NotImplementedError

Expand Down Expand Up @@ -69,7 +76,10 @@ class TagNodeGrouper(NodeGrouper):
"""Grouping class that uses special tag prefix convention to aggregate
nodes together. Only one such tag is allowed per node."""

def __init__(self, tag_prefix="group:") -> None:
def __init__(
self, kedro_context: KedroContext, run_config: RunConfig, tag_prefix="group:"
) -> None:
super().__init__(kedro_context, run_config)
self.tag_prefix = tag_prefix

def group(self, node_dependencies: PipelineDependenciesDict) -> Grouping:
Expand Down
6 changes: 3 additions & 3 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_grouping_config(self):
assert (
cfg.run_config.grouping.cls == "kedro_vertexai.grouping.IdentityNodeGrouper"
)
c_obj = dynamic_load_class(cfg.run_config.grouping.cls)
c_obj = dynamic_load_class(cfg.run_config.grouping.cls, None, None)
assert isinstance(c_obj, IdentityNodeGrouper)

cfg_tag_group = """
Expand All @@ -71,7 +71,7 @@ def test_grouping_config(self):
cfg = PluginConfig.parse_obj(yaml.safe_load(cfg_tag_group))
assert cfg.run_config.grouping is not None
c_obj = dynamic_load_class(
cfg.run_config.grouping.cls, kwargs=cfg.run_config.grouping.params
cfg.run_config.grouping.cls, None, None, **cfg.run_config.grouping.params
)
assert isinstance(c_obj, TagNodeGrouper)
assert c_obj.tag_prefix == "group:"
Expand All @@ -91,7 +91,7 @@ def test_grouping_config_error(self, log_error):
"""
cfg = PluginConfig.parse_obj(yaml.safe_load(cfg_tag_group))
c = dynamic_load_class(
cfg.run_config.grouping.cls, kwargs=cfg.run_config.grouping.params
cfg.run_config.grouping.cls, **cfg.run_config.grouping.params
)
assert c is None
log_error.assert_called_once()
Expand Down
6 changes: 3 additions & 3 deletions tests/test_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def create_pipeline_deps(self):
def test_identity_grouping(self):
# given
deps = self.create_pipeline_deps()
grouper = IdentityNodeGrouper()
grouper = IdentityNodeGrouper(None, None)
# when
group = grouper.group(deps)
for name in self.node_names:
Expand All @@ -91,7 +91,7 @@ def test_legal_tag_groups(self):
deps = self.create_pipeline_deps()
for prefix in self.legal_groups:
with self.subTest(msg=f"test_{prefix}", group_prefix=prefix):
grouper = TagNodeGrouper(prefix + ":")
grouper = TagNodeGrouper(None, None, prefix + ":")
# when
group = grouper.group(deps)
# assert
Expand All @@ -116,7 +116,7 @@ def test_illegal_tag_groups(self):
deps = self.create_pipeline_deps()
for prefix in self.illegal_groups:
with self.subTest(msg=f"test_{prefix}", group_prefix=prefix):
grouper = TagNodeGrouper(prefix + ":")
grouper = TagNodeGrouper(None, None, prefix + ":")
# when
with self.assertRaises(GroupingException):
grouper.group(deps)

0 comments on commit 650514c

Please sign in to comment.