Skip to content

Commit

Permalink
mistralai[patch]: streaming tool calls (#19469)
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis authored Mar 23, 2024
1 parent b43a9d5 commit b617085
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 7 deletions.
8 changes: 2 additions & 6 deletions libs/partners/mistralai/langchain_mistralai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ def _convert_delta_to_message_chunk(
_delta: Dict, default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = _delta.get("role")
content = _delta.get("content", "")
content = _delta.get("content") or ""
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
additional_kwargs: Dict = {}
if tool_calls := _delta.get("tool_calls"):
additional_kwargs["tool_calls"] = [tc.model_dump() for tc in tool_calls]
additional_kwargs["tool_calls"] = tool_calls
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
Expand Down Expand Up @@ -355,8 +355,6 @@ def _stream(
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0]["delta"]
if not delta["content"]:
continue
new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
# make future chunks same type as first chunk
default_chunk_class = new_chunk.__class__
Expand Down Expand Up @@ -384,8 +382,6 @@ async def _astream(
if len(chunk["choices"]) == 0:
continue
delta = chunk["choices"][0]["delta"]
if not delta["content"]:
continue
new_chunk = _convert_delta_to_message_chunk(delta, default_chunk_class)
# make future chunks same type as first chunk
default_chunk_class = new_chunk.__class__
Expand Down
2 changes: 1 addition & 1 deletion libs/partners/mistralai/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-mistralai"
version = "0.1.0rc1"
version = "0.1.0rc2"
description = "An integration package connecting Mistral and LangChain"
authors = []
readme = "README.md"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
"""Test ChatMistral chat model."""

import json
from typing import Any

from langchain_core.messages import AIMessageChunk
from langchain_core.pydantic_v1 import BaseModel

from langchain_mistralai.chat_models import ChatMistralAI


Expand Down Expand Up @@ -83,3 +89,58 @@ def test_structured_output() -> None:
"What weighs more a pound of bricks or a pound of feathers"
)
assert isinstance(result, dict)


def test_streaming_structured_output() -> None:
llm = ChatMistralAI(model="mistral-large", temperature=0)

class Person(BaseModel):
name: str
age: int

structured_llm = llm.with_structured_output(Person)
strm = structured_llm.stream("Erick, 27 years old")
chunk_num = 0
for chunk in strm:
assert chunk_num == 0, "should only have one chunk with model"
assert isinstance(chunk, Person)
assert chunk.name == "Erick"
assert chunk.age == 27
chunk_num += 1


def test_streaming_tool_call() -> None:
llm = ChatMistralAI(model="mistral-large", temperature=0)

class Person(BaseModel):
name: str
age: int

tool_llm = llm.bind_tools([Person])

# where it calls the tool
strm = tool_llm.stream("Erick, 27 years old")

additional_kwargs = None
for chunk in strm:
assert isinstance(chunk, AIMessageChunk)
assert chunk.content == ""
additional_kwargs = chunk.additional_kwargs

assert additional_kwargs is not None
assert "tool_calls" in additional_kwargs
assert len(additional_kwargs["tool_calls"]) == 1
assert additional_kwargs["tool_calls"][0]["function"]["name"] == "Person"
assert json.loads(additional_kwargs["tool_calls"][0]["function"]["arguments"]) == {
"name": "Erick",
"age": 27,
}

# where it doesn't call the tool
strm = tool_llm.stream("What is 2+2?")
acc: Any = None
for chunk in strm:
assert isinstance(chunk, AIMessageChunk)
acc = chunk if acc is None else acc + chunk
assert acc.content != ""
assert "tool_calls" not in acc.additional_kwargs

0 comments on commit b617085

Please sign in to comment.