Skip to content

Commit ebf741f

Browse files
Support async in dspy.Tool (#8106)
* async Tool * fix tests
1 parent aa46018 commit ebf741f

File tree

2 files changed

+178
-18
lines changed

2 files changed

+178
-18
lines changed

dspy/primitives/tool.py

+31-18
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import inspect
23
from typing import Any, Callable, Optional, get_origin, get_type_hints
34

@@ -132,10 +133,26 @@ def _parse_function(self, func: Callable, arg_desc: dict[str, str] = None):
132133
self.arg_types = self.arg_types or arg_types
133134
self.has_kwargs = any([param.kind == param.VAR_KEYWORD for param in sig.parameters.values()])
134135

135-
def _parse_args(self, **kwargs):
136+
def _validate_and_parse_args(self, **kwargs):
137+
# Validate the args value comply to the json schema.
138+
for k, v in kwargs.items():
139+
if k not in self.args:
140+
if self.has_kwargs:
141+
continue
142+
else:
143+
raise ValueError(f"Arg {k} is not in the tool's args.")
144+
try:
145+
instance = v.model_dump() if hasattr(v, "model_dump") else v
146+
type_str = self.args[k].get("type")
147+
if type_str is not None and type_str != "Any":
148+
validate(instance=instance, schema=self.args[k])
149+
except ValidationError as e:
150+
raise ValueError(f"Arg {k} is invalid: {e.message}")
151+
152+
# Parse the args to the correct type.
136153
parsed_kwargs = {}
137154
for k, v in kwargs.items():
138-
if k in self.arg_types and self.arg_types[k] != any:
155+
if k in self.arg_types and self.arg_types[k] != Any:
139156
# Create a pydantic model wrapper with a dummy field `value` to parse the arg to the correct type.
140157
# This is specifically useful for handling nested Pydantic models like `list[list[MyPydanticModel]]`
141158
pydantic_wrapper = create_model("Wrapper", value=(self.arg_types[k], ...))
@@ -147,19 +164,15 @@ def _parse_args(self, **kwargs):
147164

148165
@with_callbacks
149166
def __call__(self, **kwargs):
150-
for k, v in kwargs.items():
151-
if k not in self.args:
152-
if self.has_kwargs:
153-
# If the tool has kwargs, skip validation for unknown args
154-
continue
155-
raise ValueError(f"Arg {k} is not in the tool's args.")
156-
try:
157-
instance = v.model_dump() if hasattr(v, "model_dump") else v
158-
type_str = self.args[k].get("type")
159-
if type_str is not None and type_str != "Any":
160-
validate(instance=instance, schema=self.args[k])
161-
except ValidationError as e:
162-
raise ValueError(f"Arg {k} is invalid: {e.message}")
163-
164-
parsed_kwargs = self._parse_args(**kwargs)
165-
return self.func(**parsed_kwargs)
167+
parsed_kwargs = self._validate_and_parse_args(**kwargs)
168+
result = self.func(**parsed_kwargs)
169+
if asyncio.iscoroutine(result):
170+
raise ValueError("You are calling `__call__` on an async tool, please use `acall` instead.")
171+
return result
172+
173+
async def acall(self, **kwargs):
174+
parsed_kwargs = self._validate_and_parse_args(**kwargs)
175+
result = self.func(**parsed_kwargs)
176+
if not asyncio.iscoroutine(result):
177+
raise ValueError("You are calling `acall` on a non-async tool, please use `__call__` instead.")
178+
return await result

tests/primitives/test_tool.py

+147
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pydantic import BaseModel
33
from dspy.primitives.tool import Tool
44
from typing import Any, Optional
5+
import asyncio
56

67

78
# Test fixtures
@@ -67,6 +68,52 @@ def complex_dummy_function(profile: UserProfile, priority: int, notes: Optional[
6768
}
6869

6970

71+
async def async_dummy_function(x: int, y: str = "hello") -> str:
72+
"""An async dummy function for testing.
73+
74+
Args:
75+
x: An integer parameter
76+
y: A string parameter
77+
"""
78+
await asyncio.sleep(0.1) # Simulate some async work
79+
return f"{y} {x}"
80+
81+
82+
async def async_dummy_with_pydantic(model: DummyModel) -> str:
83+
"""An async dummy function that accepts a Pydantic model."""
84+
await asyncio.sleep(0.1) # Simulate some async work
85+
return f"{model.field1} {model.field2}"
86+
87+
88+
async def async_complex_dummy_function(
89+
profile: UserProfile, priority: int, notes: Optional[str] = None
90+
) -> dict[str, Any]:
91+
"""Process user profile with complex nested structure asynchronously.
92+
93+
Args:
94+
profile: User profile containing nested contact and address information
95+
priority: Priority level of the processing
96+
notes: Optional processing notes
97+
"""
98+
# Simulate some async processing work
99+
await asyncio.sleep(0.1)
100+
101+
primary_address = next(
102+
(addr for addr in profile.contact.addresses if addr.is_primary), profile.contact.addresses[0]
103+
)
104+
105+
# Simulate more async work after finding primary address
106+
await asyncio.sleep(0.1)
107+
108+
return {
109+
"user_id": profile.user_id,
110+
"name": profile.name,
111+
"priority": priority,
112+
"primary_address": primary_address.model_dump(),
113+
"notes": notes,
114+
}
115+
116+
70117
def test_basic_initialization():
71118
tool = Tool(name="test_tool", desc="A test tool", args={"param1": {"type": "string"}}, func=lambda x: x)
72119
assert tool.name == "test_tool"
@@ -198,7 +245,107 @@ def dummy_function(x: list[list[DummyModel]]):
198245
def test_tool_call_kwarg():
199246
def fn(x: int, **kwargs):
200247
return kwargs
248+
201249
tool = Tool(fn)
202250

203251
assert tool(x=1, y=2, z=3) == {"y": 2, "z": 3}
204252

253+
254+
@pytest.mark.asyncio
255+
async def test_async_tool_from_function():
256+
tool = Tool(async_dummy_function)
257+
258+
assert tool.name == "async_dummy_function"
259+
assert "An async dummy function for testing" in tool.desc
260+
assert "x" in tool.args
261+
assert "y" in tool.args
262+
assert tool.args["x"]["type"] == "integer"
263+
assert tool.args["y"]["type"] == "string"
264+
assert tool.args["y"]["default"] == "hello"
265+
266+
# Test async call
267+
result = await tool.acall(x=42, y="hello")
268+
assert result == "hello 42"
269+
270+
271+
@pytest.mark.asyncio
272+
async def test_async_tool_with_pydantic():
273+
tool = Tool(async_dummy_with_pydantic)
274+
275+
assert tool.name == "async_dummy_with_pydantic"
276+
assert "model" in tool.args
277+
assert tool.args["model"]["type"] == "object"
278+
assert "field1" in tool.args["model"]["properties"]
279+
assert "field2" in tool.args["model"]["properties"]
280+
281+
# Test async call with pydantic model
282+
model = DummyModel(field1="test", field2=123)
283+
result = await tool.acall(model=model)
284+
assert result == "test 123"
285+
286+
# Test async call with dict
287+
result = await tool.acall(model={"field1": "test", "field2": 123})
288+
assert result == "test 123"
289+
290+
291+
@pytest.mark.asyncio
292+
async def test_async_tool_with_complex_pydantic():
293+
tool = Tool(async_complex_dummy_function)
294+
295+
profile = UserProfile(
296+
user_id=1,
297+
name="Test User",
298+
contact=ContactInfo(
299+
300+
addresses=[
301+
Address(street="123 Main St", city="Test City", zip_code="12345", is_primary=True),
302+
Address(street="456 Side St", city="Test City", zip_code="12345"),
303+
],
304+
),
305+
)
306+
307+
result = await tool.acall(profile=profile, priority=1, notes="Test note")
308+
assert result["user_id"] == 1
309+
assert result["name"] == "Test User"
310+
assert result["priority"] == 1
311+
assert result["notes"] == "Test note"
312+
assert result["primary_address"]["street"] == "123 Main St"
313+
314+
315+
@pytest.mark.asyncio
316+
async def test_async_tool_invalid_call():
317+
tool = Tool(async_dummy_function)
318+
with pytest.raises(ValueError):
319+
await tool.acall(x="not an integer", y="hello")
320+
321+
322+
@pytest.mark.asyncio
323+
async def test_async_tool_with_kwargs():
324+
async def fn(x: int, **kwargs):
325+
return kwargs
326+
327+
tool = Tool(fn)
328+
329+
result = await tool.acall(x=1, y=2, z=3)
330+
assert result == {"y": 2, "z": 3}
331+
332+
333+
@pytest.mark.asyncio
334+
async def test_async_concurrent_calls():
335+
"""Test that multiple async tools can run concurrently."""
336+
tool = Tool(async_dummy_function)
337+
338+
# Create multiple concurrent calls
339+
tasks = [tool.acall(x=i, y=f"hello{i}") for i in range(5)]
340+
341+
# Run them concurrently and measure time
342+
start_time = asyncio.get_event_loop().time()
343+
results = await asyncio.gather(*tasks)
344+
end_time = asyncio.get_event_loop().time()
345+
346+
# Verify results, `asyncio.gather` returns results in the order of the tasks
347+
assert results == [f"hello{i} {i}" for i in range(5)]
348+
349+
# Check that it ran concurrently (should take ~0.1s, not ~0.5s)
350+
# We use 0.3s as threshold to account for some overhead
351+
assert end_time - start_time < 0.3

0 commit comments

Comments
 (0)