diff --git a/docs/docs/integrations/chat/baidu_qianfan_endpoint.ipynb b/docs/docs/integrations/chat/baidu_qianfan_endpoint.ipynb index decf441360193..43a1336c1b99a 100644 --- a/docs/docs/integrations/chat/baidu_qianfan_endpoint.ipynb +++ b/docs/docs/integrations/chat/baidu_qianfan_endpoint.ipynb @@ -55,17 +55,9 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[INFO] [09-15 20:00:29] logging.py:55 [t:139698882193216]: requesting llm api endpoint: /chat/eb-instant\n" - ] - } - ], + "outputs": [], "source": [ "\"\"\"For basic init and call\"\"\"\n", "import os\n", @@ -126,9 +118,7 @@ "from langchain.schema import HumanMessage\n", "from langchain_community.chat_models import QianfanChatEndpoint\n", "\n", - "chatLLM = QianfanChatEndpoint(\n", - " streaming=True,\n", - ")\n", + "chatLLM = QianfanChatEndpoint()\n", "res = chatLLM.stream([HumanMessage(content=\"hi\")], streaming=True)\n", "for r in res:\n", " print(\"chat resp:\", r)\n", @@ -260,11 +250,11 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.11.5" }, "vscode": { "interpreter": { - "hash": "6fa70026b407ae751a5c9e6bd7f7d482379da8ad616f98512780b705c84ee157" + "hash": "58f7cb64c3a06383b7f18d2a11305edccbad427293a2b4afa7abe8bfc810d4bb" } } }, diff --git a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py index ecf00a3982aa8..cf656817a9d1a 100644 --- a/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py +++ b/libs/community/langchain_community/chat_models/baidu_qianfan_endpoint.py @@ -1,5 +1,3 @@ -from __future__ import annotations - import logging from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, cast @@ -244,7 +242,14 @@ def _generate( """ if self.streaming: completion = "" + token_usage = {} + chat_generation_info: Dict = {} for chunk in self._stream(messages, stop, run_manager, **kwargs): + chat_generation_info = ( + chunk.generation_info + if chunk.generation_info is not None + else chat_generation_info + ) completion += chunk.text lc_msg = AIMessage(content=completion, additional_kwargs={}) gen = ChatGeneration( @@ -253,7 +258,10 @@ def _generate( ) return ChatResult( generations=[gen], - llm_output={"token_usage": {}, "model_name": self.model}, + llm_output={ + "token_usage": chat_generation_info.get("usage", {}), + "model_name": self.model, + }, ) params = self._convert_prompt_msg_params(messages, **kwargs) response_payload = self.client.do(**params) @@ -279,7 +287,13 @@ async def _agenerate( if self.streaming: completion = "" token_usage = {} + chat_generation_info: Dict = {} async for chunk in self._astream(messages, stop, run_manager, **kwargs): + chat_generation_info = ( + chunk.generation_info + if chunk.generation_info is not None + else chat_generation_info + ) completion += chunk.text lc_msg = AIMessage(content=completion, additional_kwargs={}) @@ -289,7 +303,10 @@ async def _agenerate( ) return ChatResult( generations=[gen], - llm_output={"token_usage": {}, "model_name": self.model}, + llm_output={ + "token_usage": chat_generation_info.get("usage", {}), + "model_name": self.model, + }, ) params = self._convert_prompt_msg_params(messages, **kwargs) response_payload = await self.client.ado(**params) @@ -315,16 +332,19 @@ def _stream( **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: params = self._convert_prompt_msg_params(messages, **kwargs) + params["stream"] = True for res in self.client.do(**params): if res: msg = _convert_dict_to_message(res) + additional_kwargs = msg.additional_kwargs.get("function_call", {}) chunk = ChatGenerationChunk( text=res["result"], message=AIMessageChunk( content=msg.content, role="assistant", - additional_kwargs=msg.additional_kwargs, + additional_kwargs=additional_kwargs, ), + generation_info=msg.additional_kwargs, ) yield chunk if run_manager: @@ -338,16 +358,19 @@ async def _astream( **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: params = self._convert_prompt_msg_params(messages, **kwargs) + params["stream"] = True async for res in await self.client.ado(**params): if res: msg = _convert_dict_to_message(res) + additional_kwargs = msg.additional_kwargs.get("function_call", {}) chunk = ChatGenerationChunk( text=res["result"], message=AIMessageChunk( content=msg.content, role="assistant", - additional_kwargs=msg.additional_kwargs, + additional_kwargs=additional_kwargs, ), + generation_info=msg.additional_kwargs, ) yield chunk if run_manager: diff --git a/libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py b/libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py index c059693761869..e69de29bb2d1d 100644 --- a/libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py +++ b/libs/community/tests/integration_tests/chat_models/test_baiduqianfan.py @@ -1,53 +0,0 @@ -from typing import cast - -from langchain_core.pydantic_v1 import SecretStr -from pytest import CaptureFixture, MonkeyPatch - -from langchain_community.chat_models.baidu_qianfan_endpoint import ( - QianfanChatEndpoint, -) - - -def test_qianfan_key_masked_when_passed_from_env( - monkeypatch: MonkeyPatch, capsys: CaptureFixture -) -> None: - """Test initialization with an API key provided via an env variable""" - monkeypatch.setenv("QIANFAN_AK", "test-api-key") - monkeypatch.setenv("QIANFAN_SK", "test-secret-key") - - chat = QianfanChatEndpoint() - print(chat.qianfan_ak, end="") - captured = capsys.readouterr() - assert captured.out == "**********" - - print(chat.qianfan_sk, end="") - captured = capsys.readouterr() - assert captured.out == "**********" - - -def test_qianfan_key_masked_when_passed_via_constructor( - capsys: CaptureFixture, -) -> None: - """Test initialization with an API key provided via the initializer""" - chat = QianfanChatEndpoint( - qianfan_ak="test-api-key", - qianfan_sk="test-secret-key", - ) - print(chat.qianfan_ak, end="") - captured = capsys.readouterr() - assert captured.out == "**********" - - print(chat.qianfan_sk, end="") - captured = capsys.readouterr() - - assert captured.out == "**********" - - -def test_uses_actual_secret_value_from_secret_str() -> None: - """Test that actual secret is retrieved using `.get_secret_value()`.""" - chat = QianfanChatEndpoint( - qianfan_ak="test-api-key", - qianfan_sk="test-secret-key", - ) - assert cast(SecretStr, chat.qianfan_ak).get_secret_value() == "test-api-key" - assert cast(SecretStr, chat.qianfan_sk).get_secret_value() == "test-secret-key" diff --git a/libs/community/tests/integration_tests/chat_models/test_qianfan_endpoint.py b/libs/community/tests/integration_tests/chat_models/test_qianfan_endpoint.py index f3bc4bb774616..caa5ef20eb4cc 100644 --- a/libs/community/tests/integration_tests/chat_models/test_qianfan_endpoint.py +++ b/libs/community/tests/integration_tests/chat_models/test_qianfan_endpoint.py @@ -1,18 +1,24 @@ """Test Baidu Qianfan Chat Endpoint.""" -from typing import Any +from typing import Any, cast +import pytest from langchain_core.callbacks import CallbackManager from langchain_core.messages import ( AIMessage, BaseMessage, + BaseMessageChunk, FunctionMessage, HumanMessage, ) from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate +from langchain_core.pydantic_v1 import SecretStr +from pytest import CaptureFixture, MonkeyPatch -from langchain_community.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint +from langchain_community.chat_models.baidu_qianfan_endpoint import ( + QianfanChatEndpoint, +) from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler _FUNCTIONS: Any = [ @@ -139,6 +145,25 @@ def test_multiple_history() -> None: assert isinstance(response.content, str) +def test_chat_generate() -> None: + """Tests chat generate works.""" + chat = QianfanChatEndpoint() + response = chat.generate( + [ + [ + HumanMessage(content="Hello."), + AIMessage(content="Hello!"), + HumanMessage(content="How are you doing?"), + ] + ] + ) + assert isinstance(response, LLMResult) + for generations in response.generations: + for generation in generations: + assert isinstance(generation, ChatGeneration) + assert isinstance(generation.text, str) + + def test_stream() -> None: """Test that stream works.""" chat = QianfanChatEndpoint(streaming=True) @@ -156,6 +181,57 @@ def test_stream() -> None: assert callback_handler.llm_streams > 0 assert isinstance(response.content, str) + res = chat.stream( + [ + HumanMessage(content="Hello."), + AIMessage(content="Hello!"), + HumanMessage(content="Who are you?"), + ] + ) + + assert len(list(res)) >= 1 + + +@pytest.mark.asyncio +async def test_async_invoke() -> None: + chat = QianfanChatEndpoint() + res = await chat.ainvoke([HumanMessage(content="Hello")]) + assert isinstance(res, BaseMessage) + assert res.content != "" + + +@pytest.mark.asyncio +async def test_async_generate() -> None: + """Tests chat agenerate works.""" + chat = QianfanChatEndpoint() + response = await chat.agenerate( + [ + [ + HumanMessage(content="Hello."), + AIMessage(content="Hello!"), + HumanMessage(content="How are you doing?"), + ] + ] + ) + assert isinstance(response, LLMResult) + for generations in response.generations: + for generation in generations: + assert isinstance(generation, ChatGeneration) + assert isinstance(generation.text, str) + + +@pytest.mark.asyncio +async def test_async_stream() -> None: + chat = QianfanChatEndpoint(streaming=True) + async for token in chat.astream( + [ + HumanMessage(content="Hello."), + AIMessage(content="Hello!"), + HumanMessage(content="Who are you?"), + ] + ): + assert isinstance(token, BaseMessageChunk) + def test_multiple_messages() -> None: """Tests multiple messages works.""" @@ -232,3 +308,48 @@ def test_rate_limit() -> None: for res in responses: assert isinstance(res, BaseMessage) assert isinstance(res.content, str) + + +def test_qianfan_key_masked_when_passed_from_env( + monkeypatch: MonkeyPatch, capsys: CaptureFixture +) -> None: + """Test initialization with an API key provided via an env variable""" + monkeypatch.setenv("QIANFAN_AK", "test-api-key") + monkeypatch.setenv("QIANFAN_SK", "test-secret-key") + + chat = QianfanChatEndpoint() + print(chat.qianfan_ak, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + print(chat.qianfan_sk, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + +def test_qianfan_key_masked_when_passed_via_constructor( + capsys: CaptureFixture, +) -> None: + """Test initialization with an API key provided via the initializer""" + chat = QianfanChatEndpoint( + qianfan_ak="test-api-key", + qianfan_sk="test-secret-key", + ) + print(chat.qianfan_ak, end="") + captured = capsys.readouterr() + assert captured.out == "**********" + + print(chat.qianfan_sk, end="") + captured = capsys.readouterr() + + assert captured.out == "**********" + + +def test_uses_actual_secret_value_from_secret_str() -> None: + """Test that actual secret is retrieved using `.get_secret_value()`.""" + chat = QianfanChatEndpoint( + qianfan_ak="test-api-key", + qianfan_sk="test-secret-key", + ) + assert cast(SecretStr, chat.qianfan_ak).get_secret_value() == "test-api-key" + assert cast(SecretStr, chat.qianfan_sk).get_secret_value() == "test-secret-key"