Skip to content

Commit

Permalink
refactor: removed run_config from Grouper class, improved grp config …
Browse files Browse the repository at this point in the history
…validation
  • Loading branch information
Lasica committed Nov 21, 2023
1 parent 650514c commit ec50f51
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 38 deletions.
44 changes: 29 additions & 15 deletions kedro_vertexai/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
from importlib import import_module
from inspect import signature
from typing import Dict, List, Optional

from pydantic import BaseModel, validator
Expand Down Expand Up @@ -110,42 +111,55 @@
logger = logging.getLogger(__name__)


# the only place to put it to avoid circular dependencies
def dynamic_load_class(load_class, *args, **kwargs):
if args is None:
args = []
if kwargs is None:
kwargs = {}
def dynamic_load_class(load_class):
try:
module_name, class_name = load_class.rsplit(".", 1)
logger.info(f"Initializing {class_name}")
class_load = getattr(import_module(module_name), class_name)
return class_load(*args, **kwargs)
return class_load
except: # noqa: E722
logger.error(
f"Could not dynamically load class {load_class} with its init params, "
f"Could not dynamically load class {load_class}, "
f"make sure it's valid and accessible from the current Python interpreter",
exc_info=True,
)


def dynamic_init_class(load_class, *args, **kwargs):
if load_class is None:
return None
if args is None:
args = []
if kwargs is None:
kwargs = {}
try:
loaded_class = dynamic_load_class(load_class)
return loaded_class(*args, **kwargs)
except: # noqa: E722
logger.error(
f"Could not dynamically init class {load_class} with its init params, "
f"make sure the configured params match the ",
exc_info=True,
)


class GroupingConfig(BaseModel):
cls: str = "kedro_vertexai.grouping.IdentityNodeGrouper"
params: Optional[dict] = {}

@validator("cls")
def class_valid(cls, v, values, **kwargs):
try:
grouper_class = dynamic_load_class(v)
class_sig = signature(grouper_class)
if "params" in values:
c = dynamic_load_class(v, None, None, **values["params"])
class_sig.bind(None, **values["params"])
else:
c = dynamic_load_class(v, None, None)
if c is None:
raise ValueError(
f"Could not validate grouping class {v} with its params."
)
class_sig.bind(None)
except: # noqa: E722
raise ValueError(f"Invalid parameters for grouping class {v}.")
raise ValueError(
f"Invalid parameters for grouping class {v}, validation failed."
)
return v

# @computed_field
Expand Down
4 changes: 2 additions & 2 deletions kedro_vertexai/dynamic_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from kedro_vertexai.config import (
DynamicConfigProviderConfig,
PluginConfig,
dynamic_load_class,
dynamic_init_class,
)

logger = logging.getLogger(__name__)
Expand All @@ -20,7 +20,7 @@ def build(
config: PluginConfig,
provider_config: DynamicConfigProviderConfig,
) -> "DynamicConfigProvider":
return dynamic_load_class(provider_config.cls, config, **provider_config.params)
return dynamic_init_class(provider_config.cls, config, **provider_config.params)

def __init__(self, config: PluginConfig, **kwargs):
self.config = config
Expand Down
5 changes: 2 additions & 3 deletions kedro_vertexai/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from kedro_vertexai.config import (
KedroVertexAIRunnerConfig,
RunConfig,
dynamic_load_class,
dynamic_init_class,
)
from kedro_vertexai.constants import (
KEDRO_CONFIG_JOB_NAME,
Expand Down Expand Up @@ -50,10 +50,9 @@ def __init__(self, config, project_name, context, run_name: str):
self.context: KedroContext = context
self.run_config: RunConfig = config.run_config
self.catalog = context.config_loader.get("catalog*")
self.grouping: NodeGrouper = dynamic_load_class(
self.grouping: NodeGrouper = dynamic_init_class(
self.run_config.grouping.cls,
context,
self.run_config,
**self.run_config.grouping.params,
)

Expand Down
13 changes: 4 additions & 9 deletions kedro_vertexai/grouping.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from abc import ABC
from dataclasses import dataclass, field
from typing import Dict, Set
from typing import Dict, Optional, Set

from kedro.framework.context import KedroContext
from kedro.pipeline.node import Node
from toposort import CircularDependencyError, toposort

from kedro_vertexai.config import RunConfig

TagsDict = Dict[str, Set[str]]
PipelineDependenciesDict = Dict[Node, Set[Node]]
GroupDependenciesDict = Dict[str, Set[str]]
Expand Down Expand Up @@ -43,9 +41,8 @@ class NodeGrouper(ABC):
For each node it tells which set of nodes are parents of them, based on nodes outputs
"""

def __init__(self, kedro_context: KedroContext, run_config: RunConfig):
def __init__(self, kedro_context: Optional[KedroContext]):
self.context = kedro_context
self.run_config = run_config

def group(self, node_dependencies: PipelineDependenciesDict) -> Grouping:
raise NotImplementedError
Expand Down Expand Up @@ -76,10 +73,8 @@ class TagNodeGrouper(NodeGrouper):
"""Grouping class that uses special tag prefix convention to aggregate
nodes together. Only one such tag is allowed per node."""

def __init__(
self, kedro_context: KedroContext, run_config: RunConfig, tag_prefix="group:"
) -> None:
super().__init__(kedro_context, run_config)
def __init__(self, kedro_context: KedroContext, tag_prefix="group:") -> None:
super().__init__(kedro_context)
self.tag_prefix = tag_prefix

def group(self, node_dependencies: PipelineDependenciesDict) -> Grouping:
Expand Down
12 changes: 6 additions & 6 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import yaml
from pydantic import ValidationError

from kedro_vertexai.config import PluginConfig, dynamic_load_class
from kedro_vertexai.config import PluginConfig, dynamic_init_class
from kedro_vertexai.grouping import IdentityNodeGrouper, TagNodeGrouper

CONFIG_FULL = """
Expand Down Expand Up @@ -54,7 +54,7 @@ def test_grouping_config(self):
assert (
cfg.run_config.grouping.cls == "kedro_vertexai.grouping.IdentityNodeGrouper"
)
c_obj = dynamic_load_class(cfg.run_config.grouping.cls, None, None)
c_obj = dynamic_init_class(cfg.run_config.grouping.cls, None)
assert isinstance(c_obj, IdentityNodeGrouper)

cfg_tag_group = """
Expand All @@ -70,8 +70,8 @@ def test_grouping_config(self):
"""
cfg = PluginConfig.parse_obj(yaml.safe_load(cfg_tag_group))
assert cfg.run_config.grouping is not None
c_obj = dynamic_load_class(
cfg.run_config.grouping.cls, None, None, **cfg.run_config.grouping.params
c_obj = dynamic_init_class(
cfg.run_config.grouping.cls, None, **cfg.run_config.grouping.params
)
assert isinstance(c_obj, TagNodeGrouper)
assert c_obj.tag_prefix == "group:"
Expand All @@ -90,8 +90,8 @@ def test_grouping_config_error(self, log_error):
foo: "bar:"
"""
cfg = PluginConfig.parse_obj(yaml.safe_load(cfg_tag_group))
c = dynamic_load_class(
cfg.run_config.grouping.cls, **cfg.run_config.grouping.params
c = dynamic_init_class(
cfg.run_config.grouping.cls, None, **cfg.run_config.grouping.params
)
assert c is None
log_error.assert_called_once()
Expand Down
6 changes: 3 additions & 3 deletions tests/test_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def create_pipeline_deps(self):
def test_identity_grouping(self):
# given
deps = self.create_pipeline_deps()
grouper = IdentityNodeGrouper(None, None)
grouper = IdentityNodeGrouper(None)
# when
group = grouper.group(deps)
for name in self.node_names:
Expand All @@ -91,7 +91,7 @@ def test_legal_tag_groups(self):
deps = self.create_pipeline_deps()
for prefix in self.legal_groups:
with self.subTest(msg=f"test_{prefix}", group_prefix=prefix):
grouper = TagNodeGrouper(None, None, prefix + ":")
grouper = TagNodeGrouper(None, prefix + ":")
# when
group = grouper.group(deps)
# assert
Expand All @@ -116,7 +116,7 @@ def test_illegal_tag_groups(self):
deps = self.create_pipeline_deps()
for prefix in self.illegal_groups:
with self.subTest(msg=f"test_{prefix}", group_prefix=prefix):
grouper = TagNodeGrouper(None, None, prefix + ":")
grouper = TagNodeGrouper(None, prefix + ":")
# when
with self.assertRaises(GroupingException):
grouper.group(deps)

0 comments on commit ec50f51

Please sign in to comment.