Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for pipeline deserialization callbacks #7518

Merged
merged 7 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion haystack/core/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@
import inspect
import sys
from collections.abc import Callable
from contextlib import contextmanager
from contextvars import ContextVar
from copy import deepcopy
from types import new_class
from typing import Any, Dict, Optional, Protocol, runtime_checkable
Expand All @@ -84,6 +86,28 @@
logger = logging.getLogger(__name__)


# Holds the currently active pre-init callback, or None when no hook is set.
# The callback is called with the component class (Type) and the init
# parameters as a keyword-argument mapping (Dict[str, Any]).
_COMPONENT_PRE_INIT_CALLBACK: ContextVar[Optional[Callable]] = ContextVar("component_pre_init_callback", default=None)


@contextmanager
def _hook_component_init(callback: Callable):
    """
    Context manager that installs `callback` as the component pre-init
    hook for the duration of the `with` block.

    The hook is meant to be invoked before a component's constructor is
    called; it receives the component class and the init parameters (as
    keyword arguments) and may modify the init parameters in place.

    The previously stored value is restored when the block exits,
    including when the block raises.

    :param callback:
        Callback function to invoke.
    """
    previous = _COMPONENT_PRE_INIT_CALLBACK.set(callback)
    try:
        yield
    finally:
        # Reset via the token so that nested hooks unwind to the right value.
        _COMPONENT_PRE_INIT_CALLBACK.reset(previous)


@runtime_checkable
class Component(Protocol):
"""
Expand Down Expand Up @@ -123,13 +147,39 @@ def run(self, **kwargs):


class ComponentMeta(type):
@staticmethod
def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
    """
    Maps positional constructor arguments to the names of the matching
    `__init__` parameters.

    Arguments are paired, in order, with the parameters of
    `cls_type.__init__` (any parameter named `self` is excluded).

    :param cls_type:
        The class whose `__init__` signature is inspected.
    :param args:
        The positional arguments that were passed to the constructor.
    :returns:
        A dictionary mapping parameter names to the positional argument values.
    :raises ComponentError:
        If a positional argument lands on a variadic positional parameter (`*args`).
    :raises TypeError:
        If more positional arguments are given than `__init__` accepts positionally.
    """
    init_signature = inspect.signature(cls_type.__init__)
    init_params = [info for name, info in init_signature.parameters.items() if name != "self"]
    positional_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY)

    out: Dict[str, Any] = {}
    for position, arg in enumerate(args):
        info = init_params[position] if position < len(init_params) else None

        if info is not None and info.kind == inspect.Parameter.VAR_POSITIONAL:
            raise ComponentError(
                "Pre-init hooks do not support components with variadic positional args in their init method"
            )

        # Running out of parameters, or reaching a keyword-only / **kwargs
        # parameter, means the caller passed too many positional arguments.
        # Fail loudly instead of silently dropping the extras, as a plain
        # constructor call would.
        if info is None or info.kind not in positional_kinds:
            accepted = sum(1 for param in init_params if param.kind in positional_kinds)
            raise TypeError(
                f"{getattr(cls_type, '__name__', cls_type)}.__init__() accepts at most {accepted} "
                f"positional argument(s) but {len(args)} were given"
            )

        out[info.name] = arg
    return out

def __call__(cls, *args, **kwargs):
"""
This method is called when clients instantiate a Component and
runs before __new__ and __init__.
"""
# This will call __new__ then __init__, giving us back the Component instance
instance = super().__call__(*args, **kwargs)
pre_init_hook = _COMPONENT_PRE_INIT_CALLBACK.get()
if pre_init_hook is None:
instance = super().__call__(*args, **kwargs)
else:
named_positional_args = ComponentMeta.positional_to_kwargs(cls, args)
assert (
set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
), "positional and keyword arguments overlap"
kwargs.update(named_positional_args)
pre_init_hook(cls, kwargs)
instance = super().__call__(**kwargs)

# Before returning, we have the chance to modify the newly created
# Component instance, so we take the chance and set up the I/O sockets
Expand Down
32 changes: 25 additions & 7 deletions haystack/core/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
PipelineUnmarshalError,
PipelineValidationError,
)
from haystack.core.serialization import component_from_dict, component_to_dict
from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
from haystack.core.type_utils import _type_name, _types_are_compatible
from haystack.marshal import Marshaller, YamlMarshaller
from haystack.telemetry import pipeline_running
Expand Down Expand Up @@ -130,12 +130,16 @@ def to_dict(self) -> Dict[str, Any]:
}

@classmethod
def from_dict(cls: Type[T], data: Dict[str, Any], **kwargs) -> T:
def from_dict(
cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs
) -> T:
"""
Deserializes the pipeline from a dictionary.

:param data:
Dictionary to deserialize from.
:param callbacks:
Callbacks to invoke during deserialization.
:param kwargs:
`components`: a dictionary of {name: instance} to reuse instances of components instead of creating new ones.
:returns:
Expand Down Expand Up @@ -171,7 +175,7 @@ def from_dict(cls: Type[T], data: Dict[str, Any], **kwargs) -> T:

# Create a new one
component_class = component.registry[component_data["type"]]
instance = component_from_dict(component_class, component_data)
instance = component_from_dict(component_class, component_data, name, callbacks)
pipe.add_component(name=name, instance=instance)

for connection in data.get("connections", []):
Expand Down Expand Up @@ -208,21 +212,33 @@ def dump(self, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER):
fp.write(marshaller.marshal(self.to_dict()))

@classmethod
def loads(
    cls,
    data: Union[str, bytes, bytearray],
    marshaller: Marshaller = DEFAULT_MARSHALLER,
    callbacks: Optional[DeserializationCallbacks] = None,
) -> "Pipeline":
    """
    Creates a `Pipeline` object from the string representation passed in the `data` argument.

    :param data:
        The string representation of the pipeline, can be `str`, `bytes` or `bytearray`.
    :param marshaller:
        The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
    :param callbacks:
        Callbacks to invoke during deserialization.
    :returns:
        A `Pipeline` object.
    """
    # Forward the callbacks by keyword so they reach `from_dict` regardless of
    # how its remaining parameters are ordered.
    return cls.from_dict(marshaller.unmarshal(data), callbacks=callbacks)

@classmethod
def load(cls, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER) -> "Pipeline":
def load(
cls,
fp: TextIO,
marshaller: Marshaller = DEFAULT_MARSHALLER,
callbacks: Optional[DeserializationCallbacks] = None,
) -> "Pipeline":
"""
Creates a `Pipeline` object from the string representation read from the file-like
object passed in the `fp` argument.
Expand All @@ -233,10 +249,12 @@ def load(cls, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER) -> "Pipel
A file-like object ready to be read from.
:param marshaller:
The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
:param callbacks:
Callbacks to invoke during deserialization.
:returns:
A `Pipeline` object.
"""
return cls.from_dict(marshaller.unmarshal(fp.read()))
return cls.from_dict(marshaller.unmarshal(fp.read()), callbacks)

def add_component(self, name: str, instance: Component) -> None:
"""
Expand Down
51 changes: 46 additions & 5 deletions haystack/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,33 @@
#
# SPDX-License-Identifier: Apache-2.0
import inspect
from typing import Any, Dict, Type
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Dict, Optional, Type

from haystack.core.component.component import _hook_component_init
from haystack.core.errors import DeserializationError, SerializationError


@dataclass(frozen=True)
class DeserializationCallbacks:
    """
    Bundle of optional hooks that run at specific stages
    of the pipeline deserialization process.

    :param component_pre_init:
        Hook invoked just before a component instance is initialized.
        It is given, in order: `component_name` (`str`),
        `component_class` (`Type`) and `init_params` (`Dict[str, Any]`).

        `init_params` contains all the parameters that are passed to the
        component's constructor; the hook is allowed to modify this
        dictionary.
    """

    component_pre_init: Optional[Callable] = None


def component_to_dict(obj: Any) -> Dict[str, Any]:
"""
Converts a component instance into a dictionary. If a `to_dict` method is present in the
Expand Down Expand Up @@ -59,7 +81,9 @@ def generate_qualified_class_name(cls: Type[object]) -> str:
return f"{cls.__module__}.{cls.__name__}"


def component_from_dict(cls: Type[object], data: Dict[str, Any]) -> Any:
def component_from_dict(
cls: Type[object], data: Dict[str, Any], name: str, callbacks: Optional[DeserializationCallbacks] = None
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
) -> Any:
"""
Creates a component instance from a dictionary. If a `from_dict` method is present in the
component class, that will be used instead of the default method.
Expand All @@ -68,13 +92,30 @@ def component_from_dict(cls: Type[object], data: Dict[str, Any]) -> Any:
The class to be used for deserialization.
:param data:
The serialized data.
:param name:
The name of the component.
:param callbacks:
Callbacks to invoke during deserialization.
:returns:
The deserialized component.
"""
if hasattr(cls, "from_dict"):
return cls.from_dict(data)

return default_from_dict(cls, data)
def component_pre_init_callback(component_cls, init_params):
assert callbacks is not None
assert callbacks.component_pre_init is not None
callbacks.component_pre_init(name, component_cls, init_params)

def do_from_dict():
if hasattr(cls, "from_dict"):
return cls.from_dict(data)

return default_from_dict(cls, data)

if callbacks is None or callbacks.component_pre_init is None:
return do_from_dict()

with _hook_component_init(component_pre_init_callback):
return do_from_dict()


def default_to_dict(obj: Any, **init_parameters) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Add support for callbacks during pipeline deserialization. Currently supports a pre-init hook for components that can be used to inspect and modify the initialization parameters
before the invocation of the component's `__init__` method.
119 changes: 119 additions & 0 deletions test/core/component/test_component.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
from functools import partial
from typing import Any

import pytest

from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.component.component import _hook_component_init
from haystack.core.component.types import Variadic
from haystack.core.errors import ComponentError
from haystack.core.pipeline import Pipeline
Expand Down Expand Up @@ -271,3 +273,120 @@ def run(self, value: int):
"Component 'MockComponent' has no variadic input, but it's marked as greedy."
" This is not supported and can lead to unexpected behavior.\n" in caplog.text
)


def test_pre_init_hooking():
    # Checks that the pre-init hook sees positional and keyword arguments under
    # their parameter names, and that in-place changes reach the constructor.
    @component
    class MockComponent:
        def __init__(self, pos_arg1, pos_arg2, pos_arg3=None, *, kwarg1=1, kwarg2="string"):
            self.pos_arg1 = pos_arg1
            self.pos_arg2 = pos_arg2
            self.pos_arg3 = pos_arg3
            self.kwarg1 = kwarg1
            self.kwarg2 = kwarg2

        @component.output_types(output_value=int)
        def run(self, input_value: int):
            return {"output_value": input_value}

    def make_hook(expected_params, overrides=None):
        # Builds a hook that verifies its inputs and, when `overrides` is
        # given, writes those values into the init parameters in place.
        def hook(component_class, init_params):
            assert component_class == MockComponent
            assert init_params == expected_params
            if overrides is not None:
                init_params.update(overrides)

        return hook

    # Mixed positional and keyword arguments.
    observe_mixed = make_hook({"pos_arg1": 1, "pos_arg2": 2, "kwarg1": None})
    with _hook_component_init(observe_mixed):
        _ = MockComponent(1, 2, kwarg1=None)

    # Everything passed by keyword.
    observe_keywords = make_hook({"pos_arg1": 1, "pos_arg2": 2, "pos_arg3": 0.01})
    with _hook_component_init(observe_keywords):
        _ = MockComponent(pos_arg1=1, pos_arg2=2, pos_arg3=0.01)

    # The hook rewrites some parameters before the constructor runs.
    rewrite = make_hook(
        {"pos_arg1": 0, "pos_arg2": 1, "pos_arg3": 0.01, "kwarg1": 0},
        overrides={"pos_arg1": 2, "pos_arg2": 0, "pos_arg3": "modified", "kwarg2": "modified string"},
    )
    with _hook_component_init(rewrite):
        c = MockComponent(0, 1, pos_arg3=0.01, kwarg1=0)

    assert c.pos_arg1 == 2
    assert c.pos_arg2 == 0
    assert c.pos_arg3 == "modified"
    assert c.kwarg1 == 0
    assert c.kwarg2 == "modified string"


def test_pre_init_hooking_variadic_positional_args():
    # A component whose __init__ takes *args works without a hook, but is
    # rejected with ComponentError once a pre-init hook is active.
    @component
    class MockComponent:
        def __init__(self, *args, kwarg1=1, kwarg2="string"):
            self.args = args
            self.kwarg1 = kwarg1
            self.kwarg2 = kwarg2

        @component.output_types(output_value=int)
        def run(self, input_value: int):
            return {"output_value": input_value}

    expected_params = {"args": (1, 2), "kwarg1": None}

    def hook(component_class, init_params):
        assert component_class == MockComponent
        assert init_params == expected_params

    # Without a hook, variadic positional args are passed through untouched.
    unhooked = MockComponent(1, 2, 3, kwarg1=None)
    assert unhooked.args == (1, 2, 3)
    assert unhooked.kwarg1 is None
    assert unhooked.kwarg2 == "string"

    with pytest.raises(ComponentError):
        with _hook_component_init(hook):
            _ = MockComponent(1, 2, kwarg1=None)


def test_pre_init_hooking_variadic_kwargs():
    # Checks that extra keyword arguments collected by **kwargs are visible to
    # the pre-init hook and can be modified by it.
    @component
    class MockComponent:
        def __init__(self, pos_arg1, pos_arg2=None, **kwargs):
            self.pos_arg1 = pos_arg1
            self.pos_arg2 = pos_arg2
            self.kwargs = kwargs

        @component.output_types(output_value=int)
        def run(self, input_value: int):
            return {"output_value": input_value}

    def make_hook(expected_params, overrides=None):
        # Builds a hook that verifies its inputs and, when `overrides` is
        # given, writes those values into the init parameters in place.
        def hook(component_class, init_params):
            assert component_class == MockComponent
            assert init_params == expected_params
            if overrides is not None:
                init_params.update(overrides)

        return hook

    # Read-only hook: the instance is built from the unmodified parameters.
    observe = make_hook({"pos_arg1": 1, "kwarg1": None, "kwarg2": 10, "kwarg3": "string"})
    with _hook_component_init(observe):
        c = MockComponent(1, kwarg1=None, kwarg2=10, kwarg3="string")
        assert c.pos_arg1 == 1
        assert c.pos_arg2 is None
        assert c.kwargs == {"kwarg1": None, "kwarg2": 10, "kwarg3": "string"}

    # Modifying hook: named parameters and one variadic keyword are rewritten.
    rewrite = make_hook(
        {"pos_arg1": 0, "pos_arg2": 1, "kwarg1": 999, "some_kwarg": "some_value"},
        overrides={"pos_arg1": 2, "pos_arg2": 0, "some_kwarg": "modified string"},
    )
    with _hook_component_init(rewrite):
        c = MockComponent(0, 1, kwarg1=999, some_kwarg="some_value")

    assert c.pos_arg1 == 2
    assert c.pos_arg2 == 0
    assert c.kwargs == {"kwarg1": 999, "some_kwarg": "modified string"}
Loading