Skip to content

Commit

Permalink
community[patch]: fix qianfan chat stream calling caused exception (#…
Browse files Browse the repository at this point in the history
…13800)

- **Description:** 
`QianfanChatEndpoint` extends `BaseChatModel` as its superclass, whose
default stream implementation may concatenate `MessageChunk`s with
`__add__`. When `stream()` is called, a `ValueError` for a duplicated key
is raised.
  - **Issues:** 
     * #13546  
     * #13548
     * Merge the two separate test files related to Qianfan into one.
  - **Dependencies:** no
  - **Tag maintainer:**

---------

Co-authored-by: root <[email protected]>
Co-authored-by: Harrison Chase <[email protected]>
  • Loading branch information
3 people authored Jan 9, 2024
1 parent 656e87b commit 70b6315
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 76 deletions.
20 changes: 5 additions & 15 deletions docs/docs/integrations/chat/baidu_qianfan_endpoint.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,9 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[INFO] [09-15 20:00:29] logging.py:55 [t:139698882193216]: requesting llm api endpoint: /chat/eb-instant\n"
]
}
],
"outputs": [],
"source": [
"\"\"\"For basic init and call\"\"\"\n",
"import os\n",
Expand Down Expand Up @@ -126,9 +118,7 @@
"from langchain.schema import HumanMessage\n",
"from langchain_community.chat_models import QianfanChatEndpoint\n",
"\n",
"chatLLM = QianfanChatEndpoint(\n",
" streaming=True,\n",
")\n",
"chatLLM = QianfanChatEndpoint()\n",
"res = chatLLM.stream([HumanMessage(content=\"hi\")], streaming=True)\n",
"for r in res:\n",
" print(\"chat resp:\", r)\n",
Expand Down Expand Up @@ -260,11 +250,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.11.5"
},
"vscode": {
"interpreter": {
"hash": "6fa70026b407ae751a5c9e6bd7f7d482379da8ad616f98512780b705c84ee157"
"hash": "58f7cb64c3a06383b7f18d2a11305edccbad427293a2b4afa7abe8bfc810d4bb"
}
}
},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

import logging
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, cast

Expand Down Expand Up @@ -244,7 +242,14 @@ def _generate(
"""
if self.streaming:
completion = ""
token_usage = {}
chat_generation_info: Dict = {}
for chunk in self._stream(messages, stop, run_manager, **kwargs):
chat_generation_info = (
chunk.generation_info
if chunk.generation_info is not None
else chat_generation_info
)
completion += chunk.text
lc_msg = AIMessage(content=completion, additional_kwargs={})
gen = ChatGeneration(
Expand All @@ -253,7 +258,10 @@ def _generate(
)
return ChatResult(
generations=[gen],
llm_output={"token_usage": {}, "model_name": self.model},
llm_output={
"token_usage": chat_generation_info.get("usage", {}),
"model_name": self.model,
},
)
params = self._convert_prompt_msg_params(messages, **kwargs)
response_payload = self.client.do(**params)
Expand All @@ -279,7 +287,13 @@ async def _agenerate(
if self.streaming:
completion = ""
token_usage = {}
chat_generation_info: Dict = {}
async for chunk in self._astream(messages, stop, run_manager, **kwargs):
chat_generation_info = (
chunk.generation_info
if chunk.generation_info is not None
else chat_generation_info
)
completion += chunk.text

lc_msg = AIMessage(content=completion, additional_kwargs={})
Expand All @@ -289,7 +303,10 @@ async def _agenerate(
)
return ChatResult(
generations=[gen],
llm_output={"token_usage": {}, "model_name": self.model},
llm_output={
"token_usage": chat_generation_info.get("usage", {}),
"model_name": self.model,
},
)
params = self._convert_prompt_msg_params(messages, **kwargs)
response_payload = await self.client.ado(**params)
Expand All @@ -315,16 +332,19 @@ def _stream(
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._convert_prompt_msg_params(messages, **kwargs)
params["stream"] = True
for res in self.client.do(**params):
if res:
msg = _convert_dict_to_message(res)
additional_kwargs = msg.additional_kwargs.get("function_call", {})
chunk = ChatGenerationChunk(
text=res["result"],
message=AIMessageChunk(
content=msg.content,
role="assistant",
additional_kwargs=msg.additional_kwargs,
additional_kwargs=additional_kwargs,
),
generation_info=msg.additional_kwargs,
)
yield chunk
if run_manager:
Expand All @@ -338,16 +358,19 @@ async def _astream(
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params = self._convert_prompt_msg_params(messages, **kwargs)
params["stream"] = True
async for res in await self.client.ado(**params):
if res:
msg = _convert_dict_to_message(res)
additional_kwargs = msg.additional_kwargs.get("function_call", {})
chunk = ChatGenerationChunk(
text=res["result"],
message=AIMessageChunk(
content=msg.content,
role="assistant",
additional_kwargs=msg.additional_kwargs,
additional_kwargs=additional_kwargs,
),
generation_info=msg.additional_kwargs,
)
yield chunk
if run_manager:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,53 +0,0 @@
from typing import cast

from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch

from langchain_community.chat_models.baidu_qianfan_endpoint import (
QianfanChatEndpoint,
)


def test_qianfan_key_masked_when_passed_from_env(
    monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
    """Keys read from environment variables must print as a masked string."""
    monkeypatch.setenv("QIANFAN_AK", "test-api-key")
    monkeypatch.setenv("QIANFAN_SK", "test-secret-key")

    chat = QianfanChatEndpoint()
    # Print the access key first, then the secret key, checking stdout each time.
    for field_name in ("qianfan_ak", "qianfan_sk"):
        print(getattr(chat, field_name), end="")
        assert capsys.readouterr().out == "**********"


def test_qianfan_key_masked_when_passed_via_constructor(
    capsys: CaptureFixture,
) -> None:
    """Keys handed to the initializer must print as a masked string."""
    credentials = {"qianfan_ak": "test-api-key", "qianfan_sk": "test-secret-key"}
    chat = QianfanChatEndpoint(**credentials)
    # Dict order is insertion order: access key first, then secret key.
    for field_name in credentials:
        print(getattr(chat, field_name), end="")
        assert capsys.readouterr().out == "**********"


def test_uses_actual_secret_value_from_secret_str() -> None:
    """`.get_secret_value()` on each key must return the original plaintext."""
    chat = QianfanChatEndpoint(
        qianfan_ak="test-api-key",
        qianfan_sk="test-secret-key",
    )
    access_key = cast(SecretStr, chat.qianfan_ak)
    assert access_key.get_secret_value() == "test-api-key"
    secret_key = cast(SecretStr, chat.qianfan_sk)
    assert secret_key.get_secret_value() == "test-secret-key"
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
"""Test Baidu Qianfan Chat Endpoint."""

from typing import Any
from typing import Any, cast

import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import (
AIMessage,
BaseMessage,
BaseMessageChunk,
FunctionMessage,
HumanMessage,
)
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture, MonkeyPatch

from langchain_community.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain_community.chat_models.baidu_qianfan_endpoint import (
QianfanChatEndpoint,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

_FUNCTIONS: Any = [
Expand Down Expand Up @@ -139,6 +145,25 @@ def test_multiple_history() -> None:
assert isinstance(response.content, str)


def test_chat_generate() -> None:
    """Check that `generate` returns an `LLMResult` of textual chat generations."""
    conversation = [
        HumanMessage(content="Hello."),
        AIMessage(content="Hello!"),
        HumanMessage(content="How are you doing?"),
    ]
    response = QianfanChatEndpoint().generate([conversation])
    assert isinstance(response, LLMResult)
    # Flatten the per-prompt batches so each generation is checked once.
    flattened = (gen for batch in response.generations for gen in batch)
    for generation in flattened:
        assert isinstance(generation, ChatGeneration)
        assert isinstance(generation.text, str)


def test_stream() -> None:
"""Test that stream works."""
chat = QianfanChatEndpoint(streaming=True)
Expand All @@ -156,6 +181,57 @@ def test_stream() -> None:
assert callback_handler.llm_streams > 0
assert isinstance(response.content, str)

res = chat.stream(
[
HumanMessage(content="Hello."),
AIMessage(content="Hello!"),
HumanMessage(content="Who are you?"),
]
)

assert len(list(res)) >= 1


@pytest.mark.asyncio
async def test_async_invoke() -> None:
    """Check that `ainvoke` resolves to a message with non-empty content."""
    prompt = [HumanMessage(content="Hello")]
    reply = await QianfanChatEndpoint().ainvoke(prompt)
    assert isinstance(reply, BaseMessage)
    assert reply.content != ""


@pytest.mark.asyncio
async def test_async_generate() -> None:
    """Check that `agenerate` returns an `LLMResult` of textual chat generations."""
    conversation = [
        HumanMessage(content="Hello."),
        AIMessage(content="Hello!"),
        HumanMessage(content="How are you doing?"),
    ]
    response = await QianfanChatEndpoint().agenerate([conversation])
    assert isinstance(response, LLMResult)
    # Flatten the per-prompt batches so each generation is checked once.
    flattened = (gen for batch in response.generations for gen in batch)
    for generation in flattened:
        assert isinstance(generation, ChatGeneration)
        assert isinstance(generation.text, str)


@pytest.mark.asyncio
async def test_async_stream() -> None:
    """Check that `astream` yields at least one `BaseMessageChunk`.

    The chunk counter keeps the test from passing vacuously when the stream
    is empty, matching the ``>= 1`` check in the synchronous ``test_stream``.
    """
    chat = QianfanChatEndpoint(streaming=True)
    chunk_count = 0
    async for token in chat.astream(
        [
            HumanMessage(content="Hello."),
            AIMessage(content="Hello!"),
            HumanMessage(content="Who are you?"),
        ]
    ):
        assert isinstance(token, BaseMessageChunk)
        chunk_count += 1
    assert chunk_count >= 1


def test_multiple_messages() -> None:
"""Tests multiple messages works."""
Expand Down Expand Up @@ -232,3 +308,48 @@ def test_rate_limit() -> None:
for res in responses:
assert isinstance(res, BaseMessage)
assert isinstance(res.content, str)


def test_qianfan_key_masked_when_passed_from_env(
    monkeypatch: MonkeyPatch, capsys: CaptureFixture
) -> None:
    """Keys read from environment variables must print as a masked string."""
    monkeypatch.setenv("QIANFAN_AK", "test-api-key")
    monkeypatch.setenv("QIANFAN_SK", "test-secret-key")

    chat = QianfanChatEndpoint()
    # Print the access key first, then the secret key, checking stdout each time.
    for field_name in ("qianfan_ak", "qianfan_sk"):
        print(getattr(chat, field_name), end="")
        assert capsys.readouterr().out == "**********"


def test_qianfan_key_masked_when_passed_via_constructor(
    capsys: CaptureFixture,
) -> None:
    """Keys handed to the initializer must print as a masked string."""
    credentials = {"qianfan_ak": "test-api-key", "qianfan_sk": "test-secret-key"}
    chat = QianfanChatEndpoint(**credentials)
    # Dict order is insertion order: access key first, then secret key.
    for field_name in credentials:
        print(getattr(chat, field_name), end="")
        assert capsys.readouterr().out == "**********"


def test_uses_actual_secret_value_from_secret_str() -> None:
    """`.get_secret_value()` on each key must return the original plaintext."""
    chat = QianfanChatEndpoint(
        qianfan_ak="test-api-key",
        qianfan_sk="test-secret-key",
    )
    access_key = cast(SecretStr, chat.qianfan_ak)
    assert access_key.get_secret_value() == "test-api-key"
    secret_key = cast(SecretStr, chat.qianfan_sk)
    assert secret_key.get_secret_value() == "test-secret-key"

0 comments on commit 70b6315

Please sign in to comment.