diff --git a/dspy/signatures/signature.py b/dspy/signatures/signature.py index 88812be0c..387b9d771 100644 --- a/dspy/signatures/signature.py +++ b/dspy/signatures/signature.py @@ -1,5 +1,6 @@ import ast import inspect +import logging import re import types import typing @@ -11,8 +12,9 @@ from pydantic.fields import FieldInfo import dsp -from dspy.signatures.field import InputField, OutputField, new_to_old_field from dspy.adapters.image_utils import Image +from dspy.signatures.field import InputField, OutputField, new_to_old_field + def signature_to_template(signature, adapter=None) -> dsp.Template: """Convert from new to legacy format.""" @@ -242,8 +244,8 @@ class Signature(BaseModel, metaclass=SignatureMeta): @classmethod @contextmanager def replace( - cls: "Signature", - new_signature: "Signature", + cls, + new_signature: "Type[Signature]", validate_new_signature: bool = True, ) -> typing.Generator[None, None, None]: """Replace the signature with an updated version. @@ -262,16 +264,35 @@ def replace( f"Field '{field}' is missing from the updated signature '{new_signature.__class__}.", ) - class OldSignature(cls, Signature): + class OldSignature(cls): pass - replace_fields = ["__doc__", "model_fields", "model_extra", "model_config"] - for field in replace_fields: - setattr(cls, field, getattr(new_signature, field)) + def swap_attributes(source: Type[Signature]): + unhandled = {} + + for attr in ["__doc__", "__pydantic_fields__", "model_fields", "model_extra", "model_config"]: + try: + setattr(cls, attr, getattr(source, attr)) + except AttributeError as exc: + if attr in ("__pydantic_fields__", "model_fields"): + version = "< 2.10" if attr == "__pydantic_fields__" else ">= 2.10" + logging.debug(f"Model attribute {attr} not replaced, expected with pydantic {version}") + unhandled[attr] = exc + else: + raise exc + + # if neither of the attributes were replaced, raise an error to prevent silent failures + if set(unhandled.keys()) >= {"model_fields", "__pydantic_fields__"}: + raise ValueError("Failed to replace either model_fields or __pydantic_fields__") from ( + unhandled.get("model_fields") or unhandled.get("__pydantic_fields__") + ) + + swap_attributes(new_signature) cls.model_rebuild(force=True) + yield - for field in replace_fields: - setattr(cls, field, getattr(OldSignature, field)) + + swap_attributes(OldSignature) cls.model_rebuild(force=True) @@ -383,7 +404,7 @@ def _parse_type_node(node, names=None) -> Any: without using structural pattern matching introduced in Python 3.10. """ - + if names is None: names = typing.__dict__ @@ -401,7 +422,7 @@ def _parse_type_node(node, names=None) -> Any: id_ = node.id if id_ in names: return names[id_] - + for type_ in [int, str, float, bool, list, tuple, dict, Image]: if type_.__name__ == id_: return type_ @@ -420,7 +441,7 @@ def _parse_type_node(node, names=None) -> Any: keys = [kw.arg for kw in node.keywords] values = [kw.value.value for kw in node.keywords] return Field(**dict(zip(keys, values))) - + if isinstance(node, ast.Attribute) and node.attr == "Image": return Image diff --git a/requirements-dev.txt b/requirements-dev.txt index 98d89e732..23984fa07 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,7 @@ black==24.2.0 +datamodel-code-generator==0.26.3 +litellm[proxy]==1.51.0 +pillow==10.4.0 pre-commit==3.7.0 pytest==8.3.3 pytest-env==1.1.3 @@ -6,5 +9,3 @@ pytest-mock==3.12.0 ruff==0.3.0 torch==2.2.1 transformers==4.38.2 -pillow==10.4.0 -litellm[proxy]==1.51.0