Skip to content

Commit edffe70

Browse files
committed
Support Thinking part
1 parent d8effbd commit edffe70

File tree

4 files changed

+62
-4
lines changed

4 files changed

+62
-4
lines changed

main.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import asyncio
2+
3+
from pydantic_ai import Agent
4+
5+
agent = Agent(model='anthropic:claude-3-7-sonnet-latest')
6+
7+
8+
@agent.tool_plain
9+
def sum(a: int, b: int) -> int:
10+
"""Get the sum of two numbers.
11+
12+
Args:
13+
a: The first number.
14+
b: The second number.
15+
16+
Returns:
17+
The sum of the two numbers.
18+
"""
19+
return a + b
20+
21+
22+
async def main():
23+
async with agent.iter('Get me the sum of 1 and 2, using the sum tool.') as agent_run:
24+
async for node in agent_run:
25+
print(node)
26+
print()
27+
print(agent_run.result)
28+
29+
30+
if __name__ == '__main__':
31+
asyncio.run(main())

pydantic_ai_slim/pydantic_ai/_agent_graph.py

+2
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,8 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]:
406406
texts.append(part.content)
407407
elif isinstance(part, _messages.ToolCallPart):
408408
tool_calls.append(part)
409+
elif isinstance(part, _messages.ThinkingPart):
410+
...
409411
else:
410412
assert_never(part)
411413

pydantic_ai_slim/pydantic_ai/messages.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,20 @@ def has_content(self) -> bool:
393393
return bool(self.content)
394394

395395

396+
@dataclass
397+
class ThinkingPart:
398+
"""A thinking response from a model."""
399+
400+
content: str
401+
"""The thinking content of the response."""
402+
403+
signature: str | None = None
404+
"""The signature of the thinking."""
405+
406+
part_kind: Literal['thinking'] = 'thinking'
407+
"""Part type identifier, this is available on all parts as a discriminator."""
408+
409+
396410
@dataclass
397411
class ToolCallPart:
398412
"""A tool call from a model."""
@@ -442,7 +456,7 @@ def has_content(self) -> bool:
442456
return bool(self.args)
443457

444458

445-
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart], pydantic.Discriminator('part_kind')]
459+
ModelResponsePart = Annotated[Union[TextPart, ToolCallPart, ThinkingPart], pydantic.Discriminator('part_kind')]
446460
"""A message part returned by a model."""
447461

448462

pydantic_ai_slim/pydantic_ai/models/anthropic.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from json import JSONDecodeError, loads as json_loads
1010
from typing import Any, Literal, Union, cast, overload
1111

12-
from anthropic.types import DocumentBlockParam
12+
from anthropic.types import DocumentBlockParam, ThinkingBlock, ThinkingBlockParam
1313
from httpx import AsyncClient as AsyncHTTPClient
1414
from typing_extensions import assert_never, deprecated
1515

@@ -27,6 +27,7 @@
2727
RetryPromptPart,
2828
SystemPromptPart,
2929
TextPart,
30+
ThinkingPart,
3031
ToolCallPart,
3132
ToolReturnPart,
3233
UserPromptPart,
@@ -256,13 +257,14 @@ async def _messages_create(
256257

257258
try:
258259
return await self.client.messages.create(
259-
max_tokens=model_settings.get('max_tokens', 1024),
260+
max_tokens=model_settings.get('max_tokens', 2048),
260261
system=system_prompt or NOT_GIVEN,
261262
messages=anthropic_messages,
262263
model=self._model_name,
263264
tools=tools or NOT_GIVEN,
264265
tool_choice=tool_choice or NOT_GIVEN,
265266
stream=stream,
267+
thinking={'budget_tokens': 1024, 'type': 'enabled'},
266268
temperature=model_settings.get('temperature', NOT_GIVEN),
267269
top_p=model_settings.get('top_p', NOT_GIVEN),
268270
timeout=model_settings.get('timeout', NOT_GIVEN),
@@ -279,6 +281,8 @@ def _process_response(self, response: AnthropicMessage) -> ModelResponse:
279281
for item in response.content:
280282
if isinstance(item, TextBlock):
281283
items.append(TextPart(content=item.text))
284+
elif isinstance(item, ThinkingBlock):
285+
items.append(ThinkingPart(content=item.thinking, signature=item.signature))
282286
else:
283287
assert isinstance(item, ToolUseBlock), 'unexpected item type'
284288
items.append(
@@ -345,10 +349,17 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Me
345349
user_content_params.append(retry_param)
346350
anthropic_messages.append(MessageParam(role='user', content=user_content_params))
347351
elif isinstance(m, ModelResponse):
348-
assistant_content_params: list[TextBlockParam | ToolUseBlockParam] = []
352+
assistant_content_params: list[TextBlockParam | ToolUseBlockParam | ThinkingBlockParam] = []
349353
for response_part in m.parts:
350354
if isinstance(response_part, TextPart):
351355
assistant_content_params.append(TextBlockParam(text=response_part.content, type='text'))
356+
elif isinstance(response_part, ThinkingPart):
357+
assert response_part.signature is not None, 'Thinking part must have a signature'
358+
assistant_content_params.append(
359+
ThinkingBlockParam(
360+
thinking=response_part.content, signature=response_part.signature, type='thinking'
361+
)
362+
)
352363
else:
353364
tool_use_block_param = ToolUseBlockParam(
354365
id=_guard_tool_call_id(t=response_part, model_source='Anthropic'),

0 commit comments

Comments
 (0)