Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix metadata processing in chat adapter #2040

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 135 additions & 23 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import re
import textwrap
from collections.abc import Mapping
from typing import Any, Dict, Literal, NamedTuple, Optional, Type
from typing import Any, Dict, Literal, NamedTuple, Optional, Type, Union, List

import pydantic
from pydantic.fields import FieldInfo
Expand Down Expand Up @@ -32,12 +32,22 @@ class FieldInfoWithName(NamedTuple):
# Built-in field indicating that a chat turn has been completed.
BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField())

# Constraints that can be applied to numeric fields.
PERMITTED_CONSTRAINTS = {"gt", "lt", "ge", "le", "multiple_of", "allow_inf_nan"}


class ChatAdapter(Adapter):
def __init__(self, callbacks: Optional[list[BaseCallback]] = None):
super().__init__(callbacks)

def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
def __call__(
self,
lm: LM,
lm_kwargs: Dict[str, Any],
signature: Type[Signature],
demos: List[Dict[str, Any]],
inputs: Dict[str, Any]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing comma at the end, also don't use capital case Dict, using primitive types like dict is preferred: https://google.github.io/styleguide/pyguide.html#221-type-annotated-code

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

) -> List[Dict[str, Any]]:
try:
return super().__call__(lm, lm_kwargs, signature, demos, inputs)
except Exception as e:
Expand All @@ -46,9 +56,15 @@ def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature]
raise e
# fallback to JSONAdapter
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs)

def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]:
messages: list[dict[str, Any]] = []

def format(
self,
signature: Type[Signature],
demos: List[Dict[str, Any]],
inputs: Dict[str, Any]
) -> List[
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Dict[str, Any]]:
messages: List[Dict[str, Any]] = []

# Extract demos where some of the output_fields are not filled in.
incomplete_demos = [
Expand Down Expand Up @@ -80,7 +96,7 @@ def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs
messages = try_expand_image_tags(messages)
return messages

def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
def parse(self, signature: Type[Signature], completion: str) -> Dict[str, Any]:
sections = [(None, [])]

for line in completion.splitlines():
Expand Down Expand Up @@ -108,7 +124,13 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]:
return fields

# TODO(PR): Looks ok?
def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]) -> dict[str, list[Any]]:
def format_finetune_data(
self,
signature: Type[Signature],
demos: List[Dict[str, Any]],
inputs: Dict[str, Any],
outputs: Dict[str, Any]
) -> Dict[str, List[Any]]:
# Get system + user messages
messages = self.format(signature, demos, inputs)

Expand All @@ -118,10 +140,15 @@ def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str,
assistant_message = self.format_turn(signature, outputs, role, incomplete)
messages.append(assistant_message)

# Wrap the messages in a dictionary with a "messages" key
# Wrap the messages in a Dictionary with a "messages" key
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to capitalize this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

return dict(messages=messages)

def format_fields(self, signature: Type[Signature], values: dict[str, Any], role: str) -> str:
def format_fields(
self,
signature: Type[Signature],
values: Dict[str, Any],
role: str
) -> str:
fields_with_values = {
FieldInfoWithName(name=field_name, info=field_info): values.get(
field_name, "Not supplied for this particular example."
Expand All @@ -131,7 +158,14 @@ def format_fields(self, signature: Type[Signature], values: dict[str, Any], role
}
return format_fields(fields_with_values)

def format_turn(self, signature: Type[Signature], values: dict[str, Any], role: str, incomplete: bool = False, is_conversation_history: bool = False) -> dict[str, Any]:
def format_turn(
self,
signature: Type[Signature],
values: Dict[str, Any],
role: str,
incomplete: bool = False,
is_conversation_history: bool = False
) -> Dict[str, Any]:
return format_turn(signature, values, role, incomplete, is_conversation_history)


Expand All @@ -141,9 +175,8 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values
into a single string, which is is a multiline string if there are multiple fields.

Args:
fields_with_values: A dictionary mapping information about a field to its corresponding
value.
Parameters:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

fields_with_values: A Dictionary mapping information about a field to its corresponding value.
Returns:
The joined formatted values of the fields, represented as a string
"""
Expand All @@ -155,14 +188,20 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str:
return "\n\n".join(output).strip()


def format_turn(signature: Type[Signature], values: dict[str, Any], role: str, incomplete=False, is_conversation_history=False):
def format_turn(
signature: Type[Signature],
values: Dict[str, Any],
role: str,
incomplete=False,
is_conversation_history=False
) -> Dict[str, str]:
"""
Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted
so that it can instruct an LLM to generate responses conforming to the specified DSPy signature.

Args:
Parameters:
signature: The DSPy signature to which future LLM responses should conform.
values: A dictionary mapping field names (from the DSPy signature) to corresponding values
values: A Dictionary mapping field names (from the DSPy signature) to corresponding values
that should be included in the message.
role: The role of the message, which can be either "user" or "assistant".
incomplete: If True, indicates that output field values are present in the set of specified
Expand Down Expand Up @@ -228,26 +267,96 @@ def type_info(v):
return {"role": role, "content": joined_messages}


def enumerate_fields(fields: dict) -> str:
def _format_constraint(name: str, value: Union[str, float]) -> str:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can move this function and other related ones into dspy/signatures/field and import in the adapter to keep adapter code simple.

"""
Formats a constraint for a numeric field.

Parameters:
name: The name of the constraint.
value: The value of the constraint.

Returns:
The formatted constraint as a string.
"""
constraints = {
'gt': f"greater than {value}",
'lt': f"less than {value}",
'ge': f"greater than or equal to {value}",
'le': f"less than or equal to {value}",
'multiple_of': f"a multiple of {value}",
'allow_inf_nan': "allows infinite and NaN values" if value else "no infinite or NaN values allowed"
}
return constraints.get(name, f"{name}={value}")


def format_metadata_summary(field: pydantic.fields.FieldInfo) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I recommend keeping this function, you removed a critical call to this function, limiting its use.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reverted.

"""
Formats a summary of the metadata for a field.

Parameters:
field: The field whose metadata should be summarized.

Returns:
A string summarizing the field's metadata.
"""
if not hasattr(field, 'metadata') or not field.metadata:
return ""
metadata_parts = [str(meta) for meta in field.metadata]
if metadata_parts:
return f" [Metadata: {'; '.join(metadata_parts)}]"
return ""


def format_metadata_constraints(field: FieldInfo) -> str:
"""
Formats the constraints for a field.

Parameters:
field: The field whose constraints should be formatted.

Returns:
A string containing the formatted constraints.
"""
if not hasattr(field, 'metadata') or not field.metadata:
return ""
formatted_constraints = []
for meta in field.metadata:
constraint_names = [name for name in dir(meta) if not name.startswith('_')]
for name in constraint_names:
if hasattr(meta, name) and name in PERMITTED_CONSTRAINTS:
value = getattr(meta, name)
formatted_constraints.append(_format_constraint(name, value))
if not formatted_constraints:
return ""
elif len(formatted_constraints) == 1:
return f" that is {formatted_constraints[0]}."
else:
*front, last = formatted_constraints
return f" that is {', '.join(front)} and {last}."


def enumerate_fields(fields: Dict) -> str:
parts = []
for idx, (k, v) in enumerate(fields.items()):
parts.append(f"{idx + 1}. `{k}`")
parts[-1] += f" ({get_annotation_name(v.annotation)})"
parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else ""

metadata_info = format_metadata_summary(v)
if metadata_info:
parts[-1] += metadata_info
return "\n".join(parts).strip()


def move_type_to_front(d):
# Move the 'type' key to the front of the dictionary, recursively, for LLM readability/adherence.
def move_type_to_front(d: Union[Dict, List, Any]) -> Union[Dict, List, Any]:
# Move the 'type' key to the front of the Dictionary, recursively, for LLM readability/adherence.
if isinstance(d, Mapping):
return {k: move_type_to_front(v) for k, v in sorted(d.items(), key=lambda item: (item[0] != "type", item[0]))}
elif isinstance(d, list):
return [move_type_to_front(item) for item in d]
return d


def prepare_schema(field_type):
def prepare_schema(field_type: Type) -> Dict[str, Any]:
schema = pydantic.TypeAdapter(field_type).json_schema()
schema = move_type_to_front(schema)
return schema
Expand All @@ -259,7 +368,7 @@ def prepare_instructions(signature: SignatureMeta):
parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields))
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.")

def field_metadata(field_name, field_info):
def field_metadata(field_name: str, field_info: FieldInfo) -> str:
field_type = field_info.annotation

if get_dspy_field_type(field_info) == "input" or field_type is str:
Expand All @@ -268,6 +377,9 @@ def field_metadata(field_name, field_info):
desc = "must be True or False"
elif field_type in (int, float):
desc = f"must be a single {field_type.__name__} value"
metadata_info = format_metadata_constraints(field_info)
if metadata_info:
desc += metadata_info
elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum):
desc = f"must be one of: {'; '.join(field_type.__members__)}"
elif hasattr(field_type, "__origin__") and field_type.__origin__ is Literal:
Expand Down Expand Up @@ -298,4 +410,4 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
objective = ("\n" + " " * 8).join([""] + instructions.splitlines())
parts.append(f"In adhering to this structure, your objective is: {objective}")

return "\n\n".join(parts).strip()
return "\n\n".join(parts).strip()
Loading