Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
eyurtsev committed Jul 25, 2024
1 parent 0e2be55 commit f7ea8bf
Showing 1 changed file with 26 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -198,16 +198,16 @@ def test_tool_calling(
) -> None:
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
model_with_tools = model.bind_tools([magic_function])
model_ = rate_limiter | model.bind_tools([magic_function_no_args])

# Test invoke
query = "What is the value of magic_function(3)? Use the tool."
result = model_with_tools.invoke(query)
result = model_.invoke(query)
_validate_tool_call_message(result)

# Test stream
full: Optional[BaseMessageChunk] = None
for chunk in model_with_tools.stream(query):
for chunk in model_.stream(query):
full = chunk if full is None else full + chunk # type: ignore
assert isinstance(full, AIMessage)
_validate_tool_call_message(full)
Expand All @@ -218,13 +218,13 @@ def test_tool_calling_with_no_arguments(
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")

model_with_tools = model.bind_tools([magic_function_no_args])
model_ = rate_limiter | model.bind_tools([magic_function_no_args])
query = "What is the value of magic_function()? Use the tool."
result = model_with_tools.invoke(query)
result = model_.invoke(query)
_validate_tool_call_message_no_args(result)

full: Optional[BaseMessageChunk] = None
for chunk in model_with_tools.stream(query):
for chunk in model_.stream(query):
full = chunk if full is None else full + chunk # type: ignore
assert isinstance(full, AIMessage)
_validate_tool_call_message_no_args(full)
Expand All @@ -245,8 +245,9 @@ def test_bind_runnables_as_tools(
description="Generate a greeting in a particular style of speaking.",
)
model_with_tools = model.bind_tools([tool_])
model_ = rate_limiter | model_with_tools
query = "Using the tool, generate a Pirate greeting."
result = model_with_tools.invoke(query)
result = model_.invoke(query)
assert isinstance(result, AIMessage)
assert result.tool_calls
tool_call = result.tool_calls[0]
Expand Down Expand Up @@ -274,19 +275,21 @@ class Joke(BaseModelProper):
# or pydantic.v1.BaseModel but not pydantic.BaseModel from pydantic 2.
# We'll need to do a pass updating the type signatures.
chat = model.with_structured_output(Joke) # type: ignore[arg-type]
result = chat.invoke("Tell me a joke about cats.")
model_ = rate_limiter | chat
result = model_.invoke("Tell me a joke about cats.")
assert isinstance(result, Joke)

for chunk in chat.stream("Tell me a joke about cats."):
for chunk in model_.stream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)

# Schema
chat = model.with_structured_output(Joke.schema())
result = chat.invoke("Tell me a joke about cats.")
model_ = rate_limiter | chat
result = model_.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}

for chunk in chat.stream("Tell me a joke about cats."):
for chunk in model_.stream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}
Expand All @@ -309,20 +312,20 @@ class Joke(BaseModel): # Uses langchain_core.pydantic_v1.BaseModel
punchline: str = Field(description="answer to resolve the joke")

# Pydantic class
chat = model.with_structured_output(Joke)
result = chat.invoke("Tell me a joke about cats.")
model_ = rate_limiter | model.with_structured_output(Joke)
result = model_.invoke("Tell me a joke about cats.")
assert isinstance(result, Joke)

for chunk in chat.stream("Tell me a joke about cats."):
for chunk in model_.stream("Tell me a joke about cats."):
assert isinstance(chunk, Joke)

# Schema
chat = model.with_structured_output(Joke.schema())
result = chat.invoke("Tell me a joke about cats.")
model_ = rate_limiter | model.with_structured_output(Joke.schema())
result = model_.invoke("Tell me a joke about cats.")
assert isinstance(result, dict)
assert set(result.keys()) == {"setup", "punchline"}

for chunk in chat.stream("Tell me a joke about cats."):
for chunk in model_.stream("Tell me a joke about cats."):
assert isinstance(chunk, dict)
assert isinstance(chunk, dict) # for mypy
assert set(chunk.keys()) == {"setup", "punchline"}
Expand All @@ -338,7 +341,7 @@ def test_tool_message_histories_string_content(
"""
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
model_with_tools = model.bind_tools([my_adder_tool])
model_with_tools = rate_limiter | model.bind_tools([my_adder_tool])
function_name = "my_adder_tool"
function_args = {"a": "1", "b": "2"}

Expand Down Expand Up @@ -376,7 +379,7 @@ def test_tool_message_histories_list_content(
"""
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
model_with_tools = model.bind_tools([my_adder_tool])
model_with_tools = rate_limiter | model.bind_tools([my_adder_tool])
function_name = "my_adder_tool"
function_args = {"a": 1, "b": 2}

Expand Down Expand Up @@ -419,7 +422,9 @@ def test_structured_few_shot_examples(
"""
if not self.has_tool_calling:
pytest.skip("Test requires tool calling.")
model_with_tools = model.bind_tools([my_adder_tool], tool_choice="any")
model_with_tools = rate_limiter | model.bind_tools(
[my_adder_tool], tool_choice="any"
)
function_name = "my_adder_tool"
function_args = {"a": 1, "b": 2}
function_result = json.dumps({"result": 3})
Expand Down Expand Up @@ -526,4 +531,4 @@ class color_picker(BaseModel):
]
),
]
model.bind_tools([color_picker]).invoke(messages)
(rate_limiter | model.bind_tools([color_picker])).invoke(messages)

0 comments on commit f7ea8bf

Please sign in to comment.