-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix metadata processing in chat adapter #2040
base: main
Are you sure you want to change the base?
Changes from all commits
436d3c8
cc2e206
0a30136
b7c03cf
4001b85
4ec5c74
4277af5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
import re | ||
import textwrap | ||
from collections.abc import Mapping | ||
from typing import Any, Dict, Literal, NamedTuple, Optional, Type | ||
from typing import Any, Dict, Literal, NamedTuple, Optional, Type, Union, List | ||
|
||
import pydantic | ||
from pydantic.fields import FieldInfo | ||
|
@@ -32,12 +32,22 @@ class FieldInfoWithName(NamedTuple): | |
# Built-in field indicating that a chat turn has been completed. | ||
BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField()) | ||
|
||
# Constraints that can be applied to numeric fields. | ||
PERMITTED_CONSTRAINTS = {"gt", "lt", "ge", "le", "multiple_of", "allow_inf_nan"} | ||
|
||
|
||
class ChatAdapter(Adapter): | ||
def __init__(self, callbacks: Optional[list[BaseCallback]] = None): | ||
super().__init__(callbacks) | ||
|
||
def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]: | ||
def __call__( | ||
self, | ||
lm: LM, | ||
lm_kwargs: Dict[str, Any], | ||
signature: Type[Signature], | ||
demos: List[Dict[str, Any]], | ||
inputs: Dict[str, Any] | ||
) -> List[Dict[str, Any]]: | ||
try: | ||
return super().__call__(lm, lm_kwargs, signature, demos, inputs) | ||
except Exception as e: | ||
|
@@ -46,9 +56,15 @@ def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature] | |
raise e | ||
# fallback to JSONAdapter | ||
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs) | ||
|
||
def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]: | ||
messages: list[dict[str, Any]] = [] | ||
|
||
def format( | ||
self, | ||
signature: Type[Signature], | ||
demos: List[Dict[str, Any]], | ||
inputs: Dict[str, Any] | ||
) -> List[ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
Dict[str, Any]]: | ||
messages: List[Dict[str, Any]] = [] | ||
|
||
# Extract demos where some of the output_fields are not filled in. | ||
incomplete_demos = [ | ||
|
@@ -80,7 +96,7 @@ def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs | |
messages = try_expand_image_tags(messages) | ||
return messages | ||
|
||
def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]: | ||
def parse(self, signature: Type[Signature], completion: str) -> Dict[str, Any]: | ||
sections = [(None, [])] | ||
|
||
for line in completion.splitlines(): | ||
|
@@ -108,7 +124,13 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]: | |
return fields | ||
|
||
# TODO(PR): Looks ok? | ||
def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]) -> dict[str, list[Any]]: | ||
def format_finetune_data( | ||
self, | ||
signature: Type[Signature], | ||
demos: List[Dict[str, Any]], | ||
inputs: Dict[str, Any], | ||
outputs: Dict[str, Any] | ||
) -> Dict[str, List[Any]]: | ||
# Get system + user messages | ||
messages = self.format(signature, demos, inputs) | ||
|
||
|
@@ -118,10 +140,15 @@ def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, | |
assistant_message = self.format_turn(signature, outputs, role, incomplete) | ||
messages.append(assistant_message) | ||
|
||
# Wrap the messages in a dictionary with a "messages" key | ||
# Wrap the messages in a Dictionary with a "messages" key | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no need to capitalize this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
return dict(messages=messages) | ||
|
||
def format_fields(self, signature: Type[Signature], values: dict[str, Any], role: str) -> str: | ||
def format_fields( | ||
self, | ||
signature: Type[Signature], | ||
values: Dict[str, Any], | ||
role: str | ||
) -> str: | ||
fields_with_values = { | ||
FieldInfoWithName(name=field_name, info=field_info): values.get( | ||
field_name, "Not supplied for this particular example." | ||
|
@@ -131,7 +158,14 @@ def format_fields(self, signature: Type[Signature], values: dict[str, Any], role | |
} | ||
return format_fields(fields_with_values) | ||
|
||
def format_turn(self, signature: Type[Signature], values: dict[str, Any], role: str, incomplete: bool = False, is_conversation_history: bool = False) -> dict[str, Any]: | ||
def format_turn( | ||
self, | ||
signature: Type[Signature], | ||
values: Dict[str, Any], | ||
role: str, | ||
incomplete: bool = False, | ||
is_conversation_history: bool = False | ||
) -> Dict[str, Any]: | ||
return format_turn(signature, values, role, incomplete, is_conversation_history) | ||
|
||
|
||
|
@@ -141,9 +175,8 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str: | |
annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values | ||
into a single string, which is is a multiline string if there are multiple fields. | ||
|
||
Args: | ||
fields_with_values: A dictionary mapping information about a field to its corresponding | ||
value. | ||
Parameters: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Revert this, Args is the standard: https://google.github.io/styleguide/pyguide.html#383-functions-and-methods There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
fields_with_values: A Dictionary mapping information about a field to its corresponding value. | ||
Returns: | ||
The joined formatted values of the fields, represented as a string | ||
""" | ||
|
@@ -155,14 +188,20 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str: | |
return "\n\n".join(output).strip() | ||
|
||
|
||
def format_turn(signature: Type[Signature], values: dict[str, Any], role: str, incomplete=False, is_conversation_history=False): | ||
def format_turn( | ||
signature: Type[Signature], | ||
values: Dict[str, Any], | ||
role: str, | ||
incomplete=False, | ||
is_conversation_history=False | ||
) -> Dict[str, str]: | ||
""" | ||
Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted | ||
so that it can instruct an LLM to generate responses conforming to the specified DSPy signature. | ||
|
||
Args: | ||
Parameters: | ||
signature: The DSPy signature to which future LLM responses should conform. | ||
values: A dictionary mapping field names (from the DSPy signature) to corresponding values | ||
values: A Dictionary mapping field names (from the DSPy signature) to corresponding values | ||
that should be included in the message. | ||
role: The role of the message, which can be either "user" or "assistant". | ||
incomplete: If True, indicates that output field values are present in the set of specified | ||
|
@@ -228,26 +267,96 @@ def type_info(v): | |
return {"role": role, "content": joined_messages} | ||
|
||
|
||
def enumerate_fields(fields: dict) -> str: | ||
def _format_constraint(name: str, value: Union[str, float]) -> str: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can move this function and other related ones into |
||
""" | ||
Formats a constraint for a numeric field. | ||
|
||
Parameters: | ||
name: The name of the constraint. | ||
value: The value of the constraint. | ||
|
||
Returns: | ||
The formatted constraint as a string. | ||
""" | ||
constraints = { | ||
'gt': f"greater than {value}", | ||
'lt': f"less than {value}", | ||
'ge': f"greater than or equal to {value}", | ||
'le': f"less than or equal to {value}", | ||
'multiple_of': f"a multiple of {value}", | ||
'allow_inf_nan': "allows infinite and NaN values" if value else "no infinite or NaN values allowed" | ||
} | ||
return constraints.get(name, f"{name}={value}") | ||
|
||
|
||
def format_metadata_summary(field: pydantic.fields.FieldInfo) -> str: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While I recommend keeping this function, you removed a critical call to this function, limiting its use. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. reverted. |
||
""" | ||
Formats a summary of the metadata for a field. | ||
|
||
Parameters: | ||
field: The field whose metadata should be summarized. | ||
|
||
Returns: | ||
A string summarizing the field's metadata. | ||
""" | ||
if not hasattr(field, 'metadata') or not field.metadata: | ||
return "" | ||
metadata_parts = [str(meta) for meta in field.metadata] | ||
if metadata_parts: | ||
return f" [Metadata: {'; '.join(metadata_parts)}]" | ||
return "" | ||
|
||
|
||
def format_metadata_constraints(field: FieldInfo) -> str: | ||
""" | ||
Formats the constraints for a field. | ||
|
||
Parameters: | ||
field: The field whose constraints should be formatted. | ||
|
||
Returns: | ||
A string containing the formatted constraints. | ||
""" | ||
if not hasattr(field, 'metadata') or not field.metadata: | ||
return "" | ||
formatted_constraints = [] | ||
for meta in field.metadata: | ||
constraint_names = [name for name in dir(meta) if not name.startswith('_')] | ||
for name in constraint_names: | ||
if hasattr(meta, name) and name in PERMITTED_CONSTRAINTS: | ||
value = getattr(meta, name) | ||
formatted_constraints.append(_format_constraint(name, value)) | ||
if not formatted_constraints: | ||
return "" | ||
elif len(formatted_constraints) == 1: | ||
return f" that is {formatted_constraints[0]}." | ||
else: | ||
*front, last = formatted_constraints | ||
return f" that is {', '.join(front)} and {last}." | ||
|
||
|
||
def enumerate_fields(fields: Dict) -> str: | ||
parts = [] | ||
for idx, (k, v) in enumerate(fields.items()): | ||
parts.append(f"{idx + 1}. `{k}`") | ||
parts[-1] += f" ({get_annotation_name(v.annotation)})" | ||
parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else "" | ||
|
||
metadata_info = format_metadata_summary(v) | ||
if metadata_info: | ||
parts[-1] += metadata_info | ||
return "\n".join(parts).strip() | ||
|
||
|
||
def move_type_to_front(d): | ||
# Move the 'type' key to the front of the dictionary, recursively, for LLM readability/adherence. | ||
def move_type_to_front(d: Union[Dict, List, Any]) -> Union[Dict, List, Any]: | ||
# Move the 'type' key to the front of the Dictionary, recursively, for LLM readability/adherence. | ||
if isinstance(d, Mapping): | ||
return {k: move_type_to_front(v) for k, v in sorted(d.items(), key=lambda item: (item[0] != "type", item[0]))} | ||
elif isinstance(d, list): | ||
return [move_type_to_front(item) for item in d] | ||
return d | ||
|
||
|
||
def prepare_schema(field_type): | ||
def prepare_schema(field_type: Type) -> Dict[str, Any]: | ||
schema = pydantic.TypeAdapter(field_type).json_schema() | ||
schema = move_type_to_front(schema) | ||
return schema | ||
|
@@ -259,7 +368,7 @@ def prepare_instructions(signature: SignatureMeta): | |
parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields)) | ||
parts.append("All interactions will be structured in the following way, with the appropriate values filled in.") | ||
|
||
def field_metadata(field_name, field_info): | ||
def field_metadata(field_name: str, field_info: FieldInfo) -> str: | ||
field_type = field_info.annotation | ||
|
||
if get_dspy_field_type(field_info) == "input" or field_type is str: | ||
|
@@ -268,6 +377,9 @@ def field_metadata(field_name, field_info): | |
desc = "must be True or False" | ||
elif field_type in (int, float): | ||
desc = f"must be a single {field_type.__name__} value" | ||
metadata_info = format_metadata_constraints(field_info) | ||
if metadata_info: | ||
desc += metadata_info | ||
elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum): | ||
desc = f"must be one of: {'; '.join(field_type.__members__)}" | ||
elif hasattr(field_type, "__origin__") and field_type.__origin__ is Literal: | ||
|
@@ -298,4 +410,4 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]): | |
objective = ("\n" + " " * 8).join([""] + instructions.splitlines()) | ||
parts.append(f"In adhering to this structure, your objective is: {objective}") | ||
|
||
return "\n\n".join(parts).strip() | ||
return "\n\n".join(parts).strip() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing comma at the end, also don't use capital case
Dict
, using primitive types likedict
is preferred: https://google.github.io/styleguide/pyguide.html#221-type-annotated-codeThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done