diff --git a/dspy/adapters/chat_adapter.py b/dspy/adapters/chat_adapter.py index 885270813a..d454fb4112 100644 --- a/dspy/adapters/chat_adapter.py +++ b/dspy/adapters/chat_adapter.py @@ -4,7 +4,7 @@ import re import textwrap from collections.abc import Mapping -from typing import Any, Dict, Literal, NamedTuple, Optional, Type +from typing import Any, Dict, Literal, NamedTuple, Optional, Type, Union, List import pydantic from pydantic.fields import FieldInfo @@ -32,12 +32,22 @@ class FieldInfoWithName(NamedTuple): # Built-in field indicating that a chat turn has been completed. BuiltInCompletedOutputFieldInfo = FieldInfoWithName(name="completed", info=OutputField()) +# Constraints that can be applied to numeric fields. +PERMITTED_CONSTRAINTS = {"gt", "lt", "ge", "le", "multiple_of", "allow_inf_nan"} + class ChatAdapter(Adapter): def __init__(self, callbacks: Optional[list[BaseCallback]] = None): super().__init__(callbacks) - def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]: + def __call__( + self, + lm: LM, + lm_kwargs: Dict[str, Any], + signature: Type[Signature], + demos: List[Dict[str, Any]], + inputs: Dict[str, Any] + ) -> List[Dict[str, Any]]: try: return super().__call__(lm, lm_kwargs, signature, demos, inputs) except Exception as e: @@ -46,9 +56,15 @@ def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature] raise e # fallback to JSONAdapter return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs) - - def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]: - messages: list[dict[str, Any]] = [] + + def format( + self, + signature: Type[Signature], + demos: List[Dict[str, Any]], + inputs: Dict[str, Any] + ) -> List[ + Dict[str, Any]]: + messages: List[Dict[str, Any]] = [] # Extract demos where some of the output_fields are not filled in. incomplete_demos = [ @@ -80,7 +96,7 @@ def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs messages = try_expand_image_tags(messages) return messages - def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]: + def parse(self, signature: Type[Signature], completion: str) -> Dict[str, Any]: sections = [(None, [])] for line in completion.splitlines(): @@ -108,7 +124,13 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]: return fields # TODO(PR): Looks ok? - def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]) -> dict[str, list[Any]]: + def format_finetune_data( + self, + signature: Type[Signature], + demos: List[Dict[str, Any]], + inputs: Dict[str, Any], + outputs: Dict[str, Any] + ) -> Dict[str, List[Any]]: # Get system + user messages messages = self.format(signature, demos, inputs) @@ -118,10 +140,15 @@ def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, assistant_message = self.format_turn(signature, outputs, role, incomplete) messages.append(assistant_message) - # Wrap the messages in a dictionary with a "messages" key + # Wrap the messages in a Dictionary with a "messages" key return dict(messages=messages) - def format_fields(self, signature: Type[Signature], values: dict[str, Any], role: str) -> str: + def format_fields( + self, + signature: Type[Signature], + values: Dict[str, Any], + role: str + ) -> str: fields_with_values = { FieldInfoWithName(name=field_name, info=field_info): values.get( field_name, "Not supplied for this particular example." @@ -131,7 +158,14 @@ def format_fields(self, signature: Type[Signature], values: dict[str, Any], role } return format_fields(fields_with_values) - def format_turn(self, signature: Type[Signature], values: dict[str, Any], role: str, incomplete: bool = False, is_conversation_history: bool = False) -> dict[str, Any]: + def format_turn( + self, + signature: Type[Signature], + values: Dict[str, Any], + role: str, + incomplete: bool = False, + is_conversation_history: bool = False + ) -> Dict[str, Any]: return format_turn(signature, values, role, incomplete, is_conversation_history) @@ -141,9 +175,8 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str: annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values into a single string, which is is a multiline string if there are multiple fields. - Args: - fields_with_values: A dictionary mapping information about a field to its corresponding - value. + Parameters: + fields_with_values: A Dictionary mapping information about a field to its corresponding value. Returns: The joined formatted values of the fields, represented as a string """ @@ -155,14 +188,20 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any]) -> str: return "\n\n".join(output).strip() -def format_turn(signature: Type[Signature], values: dict[str, Any], role: str, incomplete=False, is_conversation_history=False): +def format_turn( + signature: Type[Signature], + values: Dict[str, Any], + role: str, + incomplete=False, + is_conversation_history=False +) -> Dict[str, str]: """ Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted so that it can instruct an LLM to generate responses conforming to the specified DSPy signature. - Args: + Parameters: signature: The DSPy signature to which future LLM responses should conform. - values: A dictionary mapping field names (from the DSPy signature) to corresponding values + values: A Dictionary mapping field names (from the DSPy signature) to corresponding values that should be included in the message. role: The role of the message, which can be either "user" or "assistant". incomplete: If True, indicates that output field values are present in the set of specified @@ -228,18 +267,88 @@ def type_info(v): return {"role": role, "content": joined_messages} -def enumerate_fields(fields: dict) -> str: +def _format_constraint(name: str, value: Union[str, float]) -> str: + """ + Formats a constraint for a numeric field. + + Parameters: + name: The name of the constraint. + value: The value of the constraint. + + Returns: + The formatted constraint as a string. + """ + constraints = { + 'gt': f"greater than {value}", + 'lt': f"less than {value}", + 'ge': f"greater than or equal to {value}", + 'le': f"less than or equal to {value}", + 'multiple_of': f"a multiple of {value}", + 'allow_inf_nan': "allows infinite and NaN values" if value else "no infinite or NaN values allowed" + } + return constraints.get(name, f"{name}={value}") + + +def format_metadata_summary(field: pydantic.fields.FieldInfo) -> str: + """ + Formats a summary of the metadata for a field. + + Parameters: + field: The field whose metadata should be summarized. + + Returns: + A string summarizing the field's metadata. + """ + if not hasattr(field, 'metadata') or not field.metadata: + return "" + metadata_parts = [str(meta) for meta in field.metadata] + if metadata_parts: + return f" [Metadata: {'; '.join(metadata_parts)}]" + return "" + + +def format_metadata_constraints(field: FieldInfo) -> str: + """ + Formats the constraints for a field. + + Parameters: + field: The field whose constraints should be formatted. + + Returns: + A string containing the formatted constraints. + """ + if not hasattr(field, 'metadata') or not field.metadata: + return "" + formatted_constraints = [] + for meta in field.metadata: + constraint_names = [name for name in dir(meta) if not name.startswith('_')] + for name in constraint_names: + if hasattr(meta, name) and name in PERMITTED_CONSTRAINTS: + value = getattr(meta, name) + formatted_constraints.append(_format_constraint(name, value)) + if not formatted_constraints: + return "" + elif len(formatted_constraints) == 1: + return f" that is {formatted_constraints[0]}." + else: + *front, last = formatted_constraints + return f" that is {', '.join(front)} and {last}." + + +def enumerate_fields(fields: Dict) -> str: parts = [] for idx, (k, v) in enumerate(fields.items()): parts.append(f"{idx + 1}. `{k}`") parts[-1] += f" ({get_annotation_name(v.annotation)})" parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else "" - + metadata_info = format_metadata_summary(v) + if metadata_info: + parts[-1] += metadata_info return "\n".join(parts).strip() -def move_type_to_front(d): - # Move the 'type' key to the front of the dictionary, recursively, for LLM readability/adherence. +def move_type_to_front(d: Union[Dict, List, Any]) -> Union[Dict, List, Any]: + # Move the 'type' key to the front of the Dictionary, recursively, for LLM readability/adherence. if isinstance(d, Mapping): return {k: move_type_to_front(v) for k, v in sorted(d.items(), key=lambda item: (item[0] != "type", item[0]))} elif isinstance(d, list): @@ -247,7 +356,7 @@ def move_type_to_front(d): return d -def prepare_schema(field_type): +def prepare_schema(field_type: Type) -> Dict[str, Any]: schema = pydantic.TypeAdapter(field_type).json_schema() schema = move_type_to_front(schema) return schema @@ -259,7 +368,7 @@ def prepare_instructions(signature: SignatureMeta): parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields)) parts.append("All interactions will be structured in the following way, with the appropriate values filled in.") - def field_metadata(field_name, field_info): + def field_metadata(field_name: str, field_info: FieldInfo) -> str: field_type = field_info.annotation if get_dspy_field_type(field_info) == "input" or field_type is str: @@ -268,6 +377,9 @@ def field_metadata(field_name, field_info): desc = "must be True or False" elif field_type in (int, float): desc = f"must be a single {field_type.__name__} value" + metadata_info = format_metadata_constraints(field_info) + if metadata_info: + desc += metadata_info elif inspect.isclass(field_type) and issubclass(field_type, enum.Enum): desc = f"must be one of: {'; '.join(field_type.__members__)}" elif hasattr(field_type, "__origin__") and field_type.__origin__ is Literal: @@ -298,4 +410,4 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]): objective = ("\n" + " " * 8).join([""] + instructions.splitlines()) parts.append(f"In adhering to this structure, your objective is: {objective}") - return "\n\n".join(parts).strip() + return "\n\n".join(parts).strip() \ No newline at end of file diff --git a/dspy/adapters/json_adapter.py b/dspy/adapters/json_adapter.py index 7567e7876c..c493533662 100644 --- a/dspy/adapters/json_adapter.py +++ b/dspy/adapters/json_adapter.py @@ -4,7 +4,7 @@ import logging import textwrap from copy import deepcopy -from typing import Any, Dict, KeysView, Literal, NamedTuple, Type +from typing import Any, Dict, KeysView, Literal, NamedTuple, Type, Union, List import json_repair import litellm @@ -27,11 +27,22 @@ class FieldInfoWithName(NamedTuple): name: str info: FieldInfo + +# Constraints that can be applied to numeric fields. +PERMITTED_CONSTRAINTS = {"gt", "lt", "ge", "le", "multiple_of", "allow_inf_nan"} + class JSONAdapter(Adapter): def __init__(self): pass - def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]: + def __call__( + self, + lm: LM, + lm_kwargs: Dict[str, Any], + signature: Type[Signature], + demos: List[Dict[str, Any]], + inputs: Dict[str, Any] + ) -> List[Dict[str, Any]]: inputs = self.format(signature, demos, inputs) inputs = dict(prompt=inputs) if isinstance(inputs, str) else dict(messages=inputs) @@ -66,7 +77,12 @@ def __call__(self, lm: LM, lm_kwargs: dict[str, Any], signature: Type[Signature] return values - def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any]) -> list[dict[str, Any]]: + def format( + self, + signature: Type[Signature], + demos: List[Dict[str, Any]], + inputs: Dict[str, Any] + ) -> List[Dict[str, Any]]: messages = [] # Extract demos where some of the output_fields are not filled in. @@ -94,7 +110,7 @@ def format(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs return messages - def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]: + def parse(self, signature: Type[Signature], completion: str) -> Dict[str, Any]: fields = json_repair.loads(completion) fields = {k: v for k, v in fields.items() if k in signature.output_fields} @@ -108,7 +124,12 @@ def parse(self, signature: Type[Signature], completion: str) -> dict[str, Any]: return fields - def format_fields(self, signature: Type[Signature], values: dict[str, Any], role: str) -> str: + def format_fields( + self, + signature: Type[Signature], + values: Dict[str, Any], + role: str + ) -> str: fields_with_values = { FieldInfoWithName(name=field_name, info=field_info): values.get( field_name, "Not supplied for this particular example." @@ -119,10 +140,22 @@ def format_fields(self, signature: Type[Signature], values: dict[str, Any], role return format_fields(role=role, fields_with_values=fields_with_values) - def format_turn(self, signature: Type[Signature], values, role: str, incomplete: bool = False, is_conversation_history: bool = False) -> dict[str, Any]: + def format_turn( + self, + signature: Type[Signature], + values, role: str, + incomplete: bool = False, + is_conversation_history: bool = False + ) -> Dict[str, Any]: return format_turn(signature, values, role, incomplete, is_conversation_history) - def format_finetune_data(self, signature: Type[Signature], demos: list[dict[str, Any]], inputs: dict[str, Any], outputs: dict[str, Any]) -> dict[str, list[Any]]: + def format_finetune_data( + self, + signature: Type[Signature], + demos: List[Dict[str, Any]], + inputs: Dict[str, Any], + outputs: Dict[str, Any] + ) -> Dict[str, List[Any]]: # TODO: implement format_finetune_data method in JSONAdapter raise NotImplementedError @@ -132,7 +165,7 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str: Formats the value of the specified field according to the field's DSPy type (input or output), annotation (e.g. str, int, etc.), and the type of the value itself. - Args: + Parameters: field_info: Information about the field, including its DSPy field type and annotation. value: The value of the field. @@ -146,15 +179,54 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str: return format_field_value(field_info=field_info, value=value) +def _format_constraint(name: str, value: Union[str, float]) -> str: + constraints = { + 'gt': f"greater than {value}", + 'lt': f"less than {value}", + 'ge': f"greater than or equal to {value}", + 'le': f"less than or equal to {value}", + 'multiple_of': f"a multiple of {value}", + 'allow_inf_nan': "allows infinite and NaN values" if value else "no infinite or NaN values allowed" + } + return constraints.get(name, f"{name}={value}") + + +def format_metadata_summary(field: pydantic.fields.FieldInfo) -> str: + if not hasattr(field, 'metadata') or not field.metadata: + return "" + metadata_parts = [str(meta) for meta in field.metadata] + if metadata_parts: + return f" [Metadata: {'; '.join(metadata_parts)}]" + return "" + + +def format_metadata_constraints(field: FieldInfo) -> str: + if not hasattr(field, 'metadata') or not field.metadata: + return "" + formatted_constraints = [] + for meta in field.metadata: + constraint_names = [name for name in dir(meta) if not name.startswith('_')] + for name in constraint_names: + if hasattr(meta, name) and name in PERMITTED_CONSTRAINTS: + value = getattr(meta, name) + formatted_constraints.append(_format_constraint(name, value)) + if not formatted_constraints: + return "" + elif len(formatted_constraints) == 1: + return f" that is {formatted_constraints[0]}." + else: + *front, last = formatted_constraints + return f" that is {', '.join(front)} and {last}." + + def format_fields(role: str, fields_with_values: Dict[FieldInfoWithName, Any]) -> str: """ Formats the values of the specified fields according to the field's DSPy type (input or output), annotation (e.g. str, int, etc.), and the type of the value itself. Joins the formatted values into a single string, which is is a multiline string if there are multiple fields. - Args: - fields_with_values: A dictionary mapping information about a field to its corresponding - value. + Parameters: + fields_with_values: A dictionary mapping information about a field to its corresponding value. Returns: The joined formatted values of the fields, represented as a string. """ @@ -183,7 +255,7 @@ def format_turn( Constructs a new message ("turn") to append to a chat thread. The message is carefully formatted so that it can instruct an LLM to generate responses conforming to the specified DSPy signature. - Args: + Parameters: signature: The DSPy signature to which future LLM responses should conform. values: A dictionary mapping field names (from the DSPy signature) to corresponding values that should be included in the message. @@ -245,13 +317,15 @@ def type_info(v): return {"role": role, "content": "\n\n".join(content).strip()} -def enumerate_fields(fields): +def enumerate_fields(fields: Dict) -> str: parts = [] for idx, (k, v) in enumerate(fields.items()): parts.append(f"{idx+1}. `{k}`") parts[-1] += f" ({get_annotation_name(v.annotation)})" parts[-1] += f": {v.json_schema_extra['desc']}" if v.json_schema_extra["desc"] != f"${{{k}}}" else "" - + metadata_info = format_metadata_summary(v) + if metadata_info: + parts[-1] += metadata_info return "\n".join(parts).strip() @@ -261,7 +335,7 @@ def prepare_instructions(signature: SignatureMeta): parts.append("Your output fields are:\n" + enumerate_fields(signature.output_fields)) parts.append("All interactions will be structured in the following way, with the appropriate values filled in.") - def field_metadata(field_name, field_info): + def field_metadata(field_name: str, field_info: FieldInfo) -> str: type_ = field_info.annotation if get_dspy_field_type(field_info) == "input" or type_ is str: @@ -270,6 +344,9 @@ def field_metadata(field_name, field_info): desc = "must be True or False" elif type_ in (int, float): desc = f"must be a single {type_.__name__} value" + metadata_info = format_metadata_constraints(field_info) + if metadata_info: + desc += metadata_info elif inspect.isclass(type_) and issubclass(type_, enum.Enum): desc = f"must be one of: {'; '.join(type_.__members__)}" elif hasattr(type_, "__origin__") and type_.__origin__ is Literal: @@ -313,7 +390,7 @@ def _get_structured_outputs_response_format(signature: SignatureMeta) -> pydanti Obtains the LiteLLM / OpenAI `response_format` parameter for generating structured outputs from an LM request, based on the output fields of the specified DSPy signature. - Args: + Parameters: signature: The DSPy signature for which to obtain the `response_format` request parameter. Returns: A Pydantic model representing the `response_format` parameter for the LM request.