Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/signature replace and pydantic 2.10 #1855

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions dspy/signatures/signature.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
import inspect
import logging
import re
import types
import typing
Expand All @@ -11,8 +12,9 @@
from pydantic.fields import FieldInfo

import dsp
from dspy.signatures.field import InputField, OutputField, new_to_old_field
from dspy.adapters.image_utils import Image
from dspy.signatures.field import InputField, OutputField, new_to_old_field


def signature_to_template(signature, adapter=None) -> dsp.Template:
"""Convert from new to legacy format."""
Expand Down Expand Up @@ -242,8 +244,8 @@ class Signature(BaseModel, metaclass=SignatureMeta):
@classmethod
@contextmanager
def replace(
cls: "Signature",
new_signature: "Signature",
cls,
new_signature: "Type[Signature]",
validate_new_signature: bool = True,
) -> typing.Generator[None, None, None]:
"""Replace the signature with an updated version.
Expand All @@ -262,16 +264,35 @@ def replace(
f"Field '{field}' is missing from the updated signature '{new_signature.__class__}.",
)

class OldSignature(cls, Signature):
class OldSignature(cls):
pass

replace_fields = ["__doc__", "model_fields", "model_extra", "model_config"]
for field in replace_fields:
setattr(cls, field, getattr(new_signature, field))
def swap_attributes(source: Type[Signature]):
unhandled = {}

for attr in ["__doc__", "__pydantic_fields__", "model_fields", "model_extra", "model_config"]:
try:
setattr(cls, attr, getattr(source, attr))
except AttributeError as exc:
if attr in ("__pydantic_fields__", "model_fields"):
version = "< 2.10" if attr == "__pydantic_fields__" else ">= 2.10"
logging.debug(f"Model attribute {attr} not replaced, expected with pydantic {version}")
unhandled[attr] = exc
else:
raise exc

# if neither of the attributes were replaced, raise an error to prevent silent failures
if set(unhandled.keys()) >= {"model_fields", "__pydantic_fields__"}:
raise ValueError("Failed to replace either model_fields or __pydantic_fields__") from (
unhandled.get("model_fields") or unhandled.get("__pydantic_fields__")
)

swap_attributes(new_signature)
cls.model_rebuild(force=True)

yield
for field in replace_fields:
setattr(cls, field, getattr(OldSignature, field))

swap_attributes(OldSignature)
cls.model_rebuild(force=True)


Expand Down Expand Up @@ -383,7 +404,7 @@ def _parse_type_node(node, names=None) -> Any:

without using structural pattern matching introduced in Python 3.10.
"""

if names is None:
names = typing.__dict__

Expand All @@ -401,7 +422,7 @@ def _parse_type_node(node, names=None) -> Any:
id_ = node.id
if id_ in names:
return names[id_]

for type_ in [int, str, float, bool, list, tuple, dict, Image]:
if type_.__name__ == id_:
return type_
Expand All @@ -420,7 +441,7 @@ def _parse_type_node(node, names=None) -> Any:
keys = [kw.arg for kw in node.keywords]
values = [kw.value.value for kw in node.keywords]
return Field(**dict(zip(keys, values)))

if isinstance(node, ast.Attribute) and node.attr == "Image":
return Image

Expand Down
5 changes: 3 additions & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
black==24.2.0
datamodel-code-generator==0.26.3
litellm[proxy]==1.51.0
pillow==10.4.0
pre-commit==3.7.0
pytest==8.3.3
pytest-env==1.1.3
pytest-mock==3.12.0
ruff==0.3.0
torch==2.2.1
transformers==4.38.2
pillow==10.4.0
litellm[proxy]==1.51.0
Loading