Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Thinking part #1142

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Support Thinking part
  • Loading branch information
Kludex committed Mar 22, 2025
commit edffe70e24dcde4197b09a87587bbefa3e67a11a
31 changes: 31 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import asyncio

from pydantic_ai import Agent

# Module-level agent pinned to Anthropic's Claude 3.7 Sonnet; tools are
# registered on it via the decorator below.
agent = Agent(model='anthropic:claude-3-7-sonnet-latest')


@agent.tool_plain
def sum(a: int, b: int) -> int:
    """Add two numbers together.

    Args:
        a: The first addend.
        b: The second addend.

    Returns:
        The total of the two addends.
    """
    total = b + a
    return total


async def main():
    """Run the agent once over its graph, printing every node as it executes."""
    prompt = 'Get me the sum of 1 and 2, using the sum tool.'
    async with agent.iter(prompt) as agent_run:
        # Iterate the run node-by-node so each graph step is visible.
        async for node in agent_run:
            print(node)
            print()
        # The final result is available once the node stream is exhausted.
        print(agent_run.result)


if __name__ == '__main__':
    # Script entry point: run the async driver on a fresh event loop.
    asyncio.run(main())
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,8 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
texts.append(part.content)
elif isinstance(part, _messages.ToolCallPart):
tool_calls.append(part)
elif isinstance(part, _messages.ThinkingPart):
...
else:
assert_never(part)

Expand Down
16 changes: 15 additions & 1 deletion pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,20 @@ def has_content(self) -> bool:
return bool(self.content)


@dataclass
class ThinkingPart:
    """A thinking (extended reasoning) response part from a model.

    Carries the model's reasoning text and, for providers that sign
    thinking blocks, an opaque signature required when replaying the
    part back to the provider.
    """

    content: str
    """The thinking content of the response."""

    signature: str | None = None
    """The signature of the thinking, if the provider supplies one."""

    part_kind: Literal['thinking'] = 'thinking'
    """Part type identifier, this is available on all parts as a discriminator."""

    def has_content(self) -> bool:
        """Return `True` if the thinking content is non-empty.

        Mirrors the `has_content` method on the sibling response-part
        classes so callers can treat all parts uniformly.
        """
        return bool(self.content)


@dataclass
class ToolCallPart:
"""A tool call from a model."""
Expand Down Expand Up @@ -442,7 +456,7 @@ def has_content(self) -> bool:
return bool(self.args)


ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart, ThinkingPart], pydantic.Discriminator('part_kind')]
"""A message part returned by a model."""


Expand Down
17 changes: 14 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from json import JSONDecodeError, loads as json_loads
from typing import Any, Literal, Union, cast, overload

from anthropic.types import DocumentBlockParam
from anthropic.types import DocumentBlockParam, ThinkingBlock, ThinkingBlockParam
from httpx import AsyncClient as AsyncHTTPClient
from typing_extensions import assert_never, deprecated

Expand All @@ -27,6 +27,7 @@
RetryPromptPart,
SystemPromptPart,
TextPart,
ThinkingPart,
ToolCallPart,
ToolReturnPart,
UserPromptPart,
Expand Down Expand Up @@ -256,13 +257,14 @@ async def _messages_create(

try:
return await self.client.messages.create(
max_tokens=model_settings.get('max_tokens', 1024),
max_tokens=model_settings.get('max_tokens', 2048),
system=system_prompt or NOT_GIVEN,
messages=anthropic_messages,
model=self._model_name,
tools=tools or NOT_GIVEN,
tool_choice=tool_choice or NOT_GIVEN,
stream=stream,
thinking={'budget_tokens': 1024, 'type': 'enabled'},
temperature=model_settings.get('temperature', NOT_GIVEN),
top_p=model_settings.get('top_p', NOT_GIVEN),
timeout=model_settings.get('timeout', NOT_GIVEN),
Expand All @@ -279,6 +281,8 @@ def _process_response(self, response: AnthropicMessage) -> ModelResponse:
for item in response.content:
if isinstance(item, TextBlock):
items.append(TextPart(content=item.text))
elif isinstance(item, ThinkingBlock):
items.append(ThinkingPart(content=item.thinking, signature=item.signature))
else:
assert isinstance(item, ToolUseBlock), 'unexpected item type'
items.append(
Expand Down Expand Up @@ -345,10 +349,17 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Me
user_content_params.append(retry_param)
anthropic_messages.append(MessageParam(role='user', content=user_content_params))
elif isinstance(m, ModelResponse):
assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
assistant_content_params: list[TextBlockParam | ToolUseBlockParam | ThinkingBlockParam] = []
for response_part in m.parts:
if isinstance(response_part, TextPart):
assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
elif isinstance(response_part, ThinkingPart):
assert response_part.signature is not None, 'Thinking part must have a signature'
assistant_content_params.append(
ThinkingBlockParam(
thinking=response_part.content, signature=response_part.signature, type='thinking'
)
)
else:
tool_use_block_param = ToolUseBlockParam(
id=_guard_tool_call_id(t=response_part, model_source='Anthropic'),
Expand Down
Loading