Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate tool call id if not present #1229

Merged
merged 8 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ with capture_run_messages() as messages: # (2)!
ToolCallPart(
tool_name='calc_volume',
args={'size': 6},
tool_call_id=None,
tool_call_id='pyd_ai_tool_call_id',
part_kind='tool-call',
)
],
Expand All @@ -761,7 +761,7 @@ with capture_run_messages() as messages: # (2)!
RetryPromptPart(
content='Please try again.',
tool_name='calc_volume',
tool_call_id=None,
tool_call_id='pyd_ai_tool_call_id',
timestamp=datetime.datetime(...),
part_kind='retry-prompt',
)
Expand All @@ -773,7 +773,7 @@ with capture_run_messages() as messages: # (2)!
ToolCallPart(
tool_name='calc_volume',
args={'size': 6},
tool_call_id=None,
tool_call_id='pyd_ai_tool_call_id',
part_kind='tool-call',
)
],
Expand Down
6 changes: 3 additions & 3 deletions docs/testing-evals.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ Here's how we would write tests using [`TestModel`][pydantic_ai.models.test.Test
from datetime import timezone
import pytest

from dirty_equals import IsNow
from dirty_equals import IsNow, IsStr

from pydantic_ai import models, capture_run_messages
from pydantic_ai.models.test import TestModel
Expand Down Expand Up @@ -146,7 +146,7 @@ async def test_forecast():
'location': 'a',
'forecast_date': '2024-01-01', # (8)!
},
tool_call_id=None,
tool_call_id=IsStr(),
)
],
model_name='test',
Expand All @@ -157,7 +157,7 @@ async def test_forecast():
ToolReturnPart(
tool_name='weather_forecast',
content='Sunny with a chance of rain',
tool_call_id=None,
tool_call_id=IsStr(),
timestamp=IsNow(tz=timezone.utc),
),
],
Expand Down
11 changes: 7 additions & 4 deletions docs/tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,10 @@ print(dice_result.all_messages())
ModelResponse(
parts=[
ToolCallPart(
tool_name='roll_die', args={}, tool_call_id=None, part_kind='tool-call'
tool_name='roll_die',
args={},
tool_call_id='pyd_ai_tool_call_id',
part_kind='tool-call',
)
],
model_name='gemini-1.5-flash',
Expand All @@ -99,7 +102,7 @@ print(dice_result.all_messages())
ToolReturnPart(
tool_name='roll_die',
content='4',
tool_call_id=None,
tool_call_id='pyd_ai_tool_call_id',
timestamp=datetime.datetime(...),
part_kind='tool-return',
)
Expand All @@ -111,7 +114,7 @@ print(dice_result.all_messages())
ToolCallPart(
tool_name='get_player_name',
args={},
tool_call_id=None,
tool_call_id='pyd_ai_tool_call_id',
part_kind='tool-call',
)
],
Expand All @@ -124,7 +127,7 @@ print(dice_result.all_messages())
ToolReturnPart(
tool_name='get_player_name',
content='Anne',
tool_call_id=None,
tool_call_id='pyd_ai_tool_call_id',
timestamp=datetime.datetime(...),
part_kind='tool-return',
)
Expand Down
8 changes: 7 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_parts_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
ToolCallPartDelta,
)

from ._utils import generate_tool_call_id as _generate_tool_call_id

VendorId = Hashable
"""
Type alias for a vendor identifier, which can be any hashable type (e.g., a string, UUID, etc.)
Expand Down Expand Up @@ -221,7 +223,11 @@ def handle_tool_call_part(
ModelResponseStreamEvent: A `PartStartEvent` indicating that a new tool call part
has been added to the manager, or replaced an existing part.
"""
new_part = ToolCallPart(tool_name=tool_name, args=args, tool_call_id=tool_call_id)
new_part = ToolCallPart(
tool_name=tool_name,
args=args,
tool_call_id=tool_call_id or _generate_tool_call_id(),
)
if vendor_part_id is None:
# vendor_part_id is None, so we unconditionally append a new ToolCallPart to the end of the list
new_part_index = len(self._parts)
Expand Down
18 changes: 12 additions & 6 deletions pydantic_ai_slim/pydantic_ai/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import time
import uuid
from collections.abc import AsyncIterable, AsyncIterator, Iterator
from contextlib import asynccontextmanager, suppress
from dataclasses import dataclass, is_dataclass
Expand Down Expand Up @@ -195,12 +196,17 @@ def now_utc() -> datetime:
return datetime.now(tz=timezone.utc)


def guard_tool_call_id(
t: _messages.ToolCallPart | _messages.ToolReturnPart | _messages.RetryPromptPart, model_source: str
) -> str:
"""Type guard that checks a `tool_call_id` is not None both for static typing and runtime."""
assert t.tool_call_id is not None, f'{model_source} requires `tool_call_id` to be set: {t}'
return t.tool_call_id
def guard_tool_call_id(t: _messages.ToolCallPart | _messages.ToolReturnPart | _messages.RetryPromptPart) -> str:
"""Type guard that either returns the tool call id or generates a new one if it's None."""
return t.tool_call_id or generate_tool_call_id()


def generate_tool_call_id() -> str:
"""Generate a tool call id.

Ensure that the tool call id is unique.
"""
return f'pyd_ai_{uuid.uuid4().hex}'


class PeekableAsyncStream(Generic[T]):
Expand Down
42 changes: 15 additions & 27 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from opentelemetry._events import Event
from typing_extensions import TypeAlias

from ._utils import now_utc as _now_utc
from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
from .exceptions import UnexpectedModelBehavior


Expand Down Expand Up @@ -268,8 +268,8 @@ class ToolReturnPart:
content: Any
"""The return value."""

tool_call_id: str | None = None
"""Optional tool call identifier, this is used by some models including OpenAI."""
tool_call_id: str
"""The tool call identifier, this is used by some models including OpenAI."""

timestamp: datetime = field(default_factory=_now_utc)
"""The timestamp, when the tool returned."""
Expand Down Expand Up @@ -328,8 +328,11 @@ class RetryPromptPart:
tool_name: str | None = None
"""The name of the tool that was called, if any."""

tool_call_id: str | None = None
"""Optional tool call identifier, this is used by some models including OpenAI."""
tool_call_id: str = field(default_factory=_generate_tool_call_id)
"""The tool call identifier, this is used by some models including OpenAI.

In case the tool call id is not provided by the model, PydanticAI will generate a random one.
"""

timestamp: datetime = field(default_factory=_now_utc)
"""The timestamp, when the retry was triggered."""
Expand Down Expand Up @@ -406,8 +409,11 @@ class ToolCallPart:
This is stored either as a JSON string or a Python dictionary depending on how data was received.
"""

tool_call_id: str | None = None
"""Optional tool call identifier, this is used by some models including OpenAI."""
tool_call_id: str = field(default_factory=_generate_tool_call_id)
"""The tool call identifier, this is used by some models including OpenAI.

In case the tool call id is not provided by the model, PydanticAI will generate a random one.
"""

part_kind: Literal['tool-call'] = 'tool-call'
"""Part type identifier, this is available on all parts as a discriminator."""
Expand Down Expand Up @@ -564,11 +570,7 @@ def as_part(self) -> ToolCallPart | None:
if self.tool_name_delta is None or self.args_delta is None:
return None

return ToolCallPart(
self.tool_name_delta,
self.args_delta,
self.tool_call_id,
)
return ToolCallPart(self.tool_name_delta, self.args_delta, self.tool_call_id or _generate_tool_call_id())

@overload
def apply(self, part: ModelResponsePart) -> ToolCallPart: ...
Expand Down Expand Up @@ -620,20 +622,11 @@ def _apply_to_delta(self, delta: ToolCallPartDelta) -> ToolCallPart | ToolCallPa
delta = replace(delta, args_delta=updated_args_delta)

if self.tool_call_id:
# Set the tool_call_id if it wasn't present, otherwise error if it has changed
if delta.tool_call_id is not None and delta.tool_call_id != self.tool_call_id:
raise UnexpectedModelBehavior(
f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({delta=}, {self=})'
)
delta = replace(delta, tool_call_id=self.tool_call_id)

# If we now have enough data to create a full ToolCallPart, do so
if delta.tool_name_delta is not None and delta.args_delta is not None:
return ToolCallPart(
delta.tool_name_delta,
delta.args_delta,
delta.tool_call_id,
)
return ToolCallPart(delta.tool_name_delta, delta.args_delta, delta.tool_call_id or _generate_tool_call_id())

return delta

Expand All @@ -656,11 +649,6 @@ def _apply_to_part(self, part: ToolCallPart) -> ToolCallPart:
part = replace(part, args=updated_dict)

if self.tool_call_id:
# Replace the tool_call_id entirely if given
if part.tool_call_id is not None and part.tool_call_id != self.tool_call_id:
raise UnexpectedModelBehavior(
f'Cannot apply a new tool_call_id to a ToolCallPartDelta that already has one ({part=}, {self=})'
)
part = replace(part, tool_call_id=self.tool_call_id)
return part

Expand Down
6 changes: 3 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Me
user_content_params.append(content)
elif isinstance(request_part, ToolReturnPart):
tool_result_block_param = ToolResultBlockParam(
tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
tool_use_id=_guard_tool_call_id(t=request_part),
type='tool_result',
content=request_part.model_response_str(),
is_error=False,
Expand All @@ -337,7 +337,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Me
retry_param = TextBlockParam(type='text', text=request_part.model_response())
else:
retry_param = ToolResultBlockParam(
tool_use_id=_guard_tool_call_id(t=request_part, model_source='Anthropic'),
tool_use_id=_guard_tool_call_id(t=request_part),
type='tool_result',
content=request_part.model_response(),
is_error=True,
Expand All @@ -351,7 +351,7 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Me
assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
else:
tool_use_block_param = ToolUseBlockParam(
id=_guard_tool_call_id(t=response_part, model_source='Anthropic'),
id=_guard_tool_call_id(t=response_part),
type='tool_use',
name=response_part.tool_name,
input=response_part.args_as_dict(),
Expand Down
3 changes: 1 addition & 2 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,10 +394,9 @@ async def _map_user_prompt(part: UserPromptPart) -> list[MessageUnionTypeDef]:

@staticmethod
def _map_tool_call(t: ToolCallPart) -> ContentBlockOutputTypeDef:
assert t.tool_call_id is not None
return {
'toolUse': {
'toolUseId': t.tool_call_id,
'toolUseId': _utils.guard_tool_call_id(t=t),
'name': t.tool_name,
'input': t.args_as_dict(),
}
Expand Down
10 changes: 5 additions & 5 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing_extensions import assert_never, deprecated

from .. import ModelHTTPError, result
from .._utils import guard_tool_call_id as _guard_tool_call_id
from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id
from ..messages import (
ModelMessage,
ModelRequest,
Expand Down Expand Up @@ -225,7 +225,7 @@ def _process_response(self, response: ChatResponse) -> ModelResponse:
ToolCallPart(
tool_name=c.function.name,
args=c.function.arguments,
tool_call_id=c.id,
tool_call_id=c.id or _generate_tool_call_id(),
)
)
return ModelResponse(parts=parts, model_name=self._model_name)
Expand Down Expand Up @@ -262,7 +262,7 @@ def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[T
@staticmethod
def _map_tool_call(t: ToolCallPart) -> ToolCallV2:
return ToolCallV2(
id=_guard_tool_call_id(t=t, model_source='Cohere'),
id=_guard_tool_call_id(t=t),
type='function',
function=ToolCallV2Function(
name=t.tool_name,
Expand Down Expand Up @@ -294,7 +294,7 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
elif isinstance(part, ToolReturnPart):
yield ToolChatMessageV2(
role='tool',
tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
tool_call_id=_guard_tool_call_id(t=part),
content=part.model_response_str(),
)
elif isinstance(part, RetryPromptPart):
Expand All @@ -303,7 +303,7 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[ChatMessageV2]:
else:
yield ToolChatMessageV2(
role='tool',
tool_call_id=_guard_tool_call_id(t=part, model_source='Cohere'),
tool_call_id=_guard_tool_call_id(t=part),
content=part.model_response(),
)
else:
Expand Down
8 changes: 1 addition & 7 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ async def _make_request(
url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'

request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)

async with self.client.stream(
'POST',
url,
Expand Down Expand Up @@ -603,12 +602,7 @@ def _process_response_from_parts(
if 'text' in part:
items.append(TextPart(content=part['text']))
elif 'function_call' in part:
items.append(
ToolCallPart(
tool_name=part['function_call']['name'],
args=part['function_call']['args'],
)
)
items.append(ToolCallPart(tool_name=part['function_call']['name'], args=part['function_call']['args']))
elif 'function_response' in part:
raise UnexpectedModelBehavior(
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
Expand Down
6 changes: 3 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def _map_message(self, message: ModelMessage) -> Iterable[chat.ChatCompletionMes
@staticmethod
def _map_tool_call(t: ToolCallPart) -> chat.ChatCompletionMessageToolCallParam:
return chat.ChatCompletionMessageToolCallParam(
id=_guard_tool_call_id(t=t, model_source='Groq'),
id=_guard_tool_call_id(t=t),
type='function',
function={'name': t.tool_name, 'arguments': t.args_as_json_str()},
)
Expand All @@ -335,7 +335,7 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletio
elif isinstance(part, ToolReturnPart):
yield chat.ChatCompletionToolMessageParam(
role='tool',
tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
tool_call_id=_guard_tool_call_id(t=part),
content=part.model_response_str(),
)
elif isinstance(part, RetryPromptPart):
Expand All @@ -344,7 +344,7 @@ def _map_user_message(cls, message: ModelRequest) -> Iterable[chat.ChatCompletio
else:
yield chat.ChatCompletionToolMessageParam(
role='tool',
tool_call_id=_guard_tool_call_id(t=part, model_source='Groq'),
tool_call_id=_guard_tool_call_id(t=part),
content=part.model_response(),
)

Expand Down
Loading