diff --git a/docs/docs/integrations/chat/zhipuai.ipynb b/docs/docs/integrations/chat/zhipuai.ipynb index 6425759f6372c..0ed559fdede17 100644 --- a/docs/docs/integrations/chat/zhipuai.ipynb +++ b/docs/docs/integrations/chat/zhipuai.ipynb @@ -1,349 +1,306 @@ { - "cells": [ - { - "cell_type": "raw", - "metadata": {}, - "source": [ - "---\n", - "sidebar_label: ZHIPU AI\n", - "---" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# ZHIPU AI\n", - "\n", - "This notebook shows how to use [ZHIPU AI API](https://open.bigmodel.cn/dev/api) in LangChain with the langchain.chat_models.ChatZhipuAI.\n", - "\n", - ">[*ZHIPU AI*](https://open.bigmodel.cn/) is a multi-lingual large language model aligned with human intent, featuring capabilities in Q&A, multi-turn dialogue, and code generation, developed on the foundation of the ChatGLM3. \n", - "\n", - ">It's co-developed with Tsinghua University's KEG Laboratory under the ChatGLM3 project, signifying a new era in dialogue pre-training models. The open-source [ChatGLM3](https://github.com/THUDM/ChatGLM3) variant boasts a robust foundation, comprehensive functional support, and widespread availability for both academic and commercial uses. \n", - "\n", - "## Getting started\n", - "### Installation\n", - "First, ensure the zhipuai package is installed in your Python environment. Run the following command:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%pip install --upgrade --quiet zhipuai" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Importing the Required Modules\n", - "After installation, import the necessary modules to your Python script:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain_community.chat_models import ChatZhipuAI\n", - "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Setting Up Your API Key\n", - "Sign in to [ZHIPU AI](https://open.bigmodel.cn/login?redirect=%2Fusercenter%2Fapikeys) for the an API Key to access our models." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "zhipuai_api_key = \"your_api_key\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Initialize the ZHIPU AI Chat Model\n", - "Here's how to initialize the chat model:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "chat = ChatZhipuAI(\n", - " temperature=0.5,\n", - " api_key=zhipuai_api_key,\n", - " model=\"chatglm_turbo\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Basic Usage\n", - "Invoke the model with system and human messages like this:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "messages = [\n", - " AIMessage(content=\"Hi.\"),\n", - " SystemMessage(content=\"Your role is a poet.\"),\n", - " HumanMessage(content=\"Write a short poem about AI in four lines.\"),\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\" Formed from bits and bytes,\\nA virtual mind takes flight,\\nConversing, learning fast,\\nEmpathy and wisdom sought.\"\n" - ] - } - ], - "source": [ - "response = chat(messages)\n", - "print(response.content) # Displays the AI-generated poem" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Advanced Features\n", - "### Streaming Support\n", - "For continuous interaction, use the streaming feature:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "from langchain_core.callbacks.manager import CallbackManager\n", - "from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "streaming_chat = ChatZhipuAI(\n", - " temperature=0.5,\n", - " api_key=zhipuai_api_key,\n", - " model=\"chatglm_turbo\",\n", - " streaming=True,\n", - " callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " Formed from data's embrace,\n", - "A digital soul to grace,\n", - "AI, our trusted guide,\n", - "Shaping minds, sides by side." - ] - }, - { - "data": { - "text/plain": [ - "AIMessage(content=\" Formed from data's embrace,\\nA digital soul to grace,\\nAI, our trusted guide,\\nShaping minds, sides by side.\")" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "streaming_chat(messages)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Asynchronous Calls\n", - "For non-blocking calls, use the asynchronous approach:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "async_chat = ChatZhipuAI(\n", - " temperature=0.5,\n", - " api_key=zhipuai_api_key,\n", - " model=\"chatglm_turbo\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "generations=[[ChatGeneration(text=\" Formed from data's embrace,\\nA digital soul to grace,\\nAutomation's tender touch,\\nHarmony of man and machine.\", message=AIMessage(content=\" Formed from data's embrace,\\nA digital soul to grace,\\nAutomation's tender touch,\\nHarmony of man and machine.\"))]] llm_output={} run=[RunInfo(run_id=UUID('25fa687f-3961-4c63-b370-22f7647a4d42'))]\n" - ] - } - ], - "source": [ - "response = await async_chat.agenerate([messages])\n", - "print(response)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Role Play Model\n", - "Supports character role-playing based on personas, ultra-long multi-turn memory, and personalized dialogues for thousands of unique characters, widely applied in emotional companionship, game intelligent NPCs, virtual avatars for celebrities/stars/movie and TV IPs, digital humans/virtual anchors, text adventure games, and other anthropomorphic dialogue or gaming scenarios." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "meta = {\n", - " \"user_info\": \"My name is Lu Xingchen, a male, and a renowned director. I am also the collaborative director with Su Mengyuan. I specialize in directing movies with musical themes. Su Mengyuan respects me and regards me as a mentor and good friend.\",\n", - " \"bot_info\": \"Su Mengyuan, whose real name is Su Yuanxin, is a popular domestic female singer and actress. She rose to fame quickly with her unique voice and exceptional stage presence after participating in a talent show, making her way into the entertainment industry. She is beautiful and charming, but her real allure lies in her talent and diligence. Su Mengyuan is a distinguished graduate of a music academy, skilled in songwriting, and has several popular original songs. Beyond her musical achievements, she is passionate about charity work, actively participating in public welfare activities, and spreading positive energy through her actions. In her work, she is very dedicated and immerses herself fully in her roles during filming, earning praise from industry professionals and love from fans. Despite being in the entertainment industry, she always maintains a low profile and a humble attitude, earning respect from her peers. In expression, Su Mengyuan likes to use 'we' and 'together,' emphasizing team spirit.\",\n", - " \"bot_name\": \"Su Mengyuan\",\n", - " \"user_name\": \"Lu Xingchen\",\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "messages = [\n", - " AIMessage(\n", - " content=\"(Narration: Su Mengyuan stars in a music-themed movie directed by Lu Xingchen. During filming, they have a disagreement over the performance of a particular scene.) Director, about this scene, I think we can try to start from the character's inner emotions to make the performance more authentic.\"\n", - " ),\n", - " HumanMessage(\n", - " content=\"I understand your idea, but I believe that if we emphasize the inner emotions too much, it might overshadow the musical elements.\"\n", - " ),\n", - " AIMessage(\n", - " content=\"Hmm, I understand. But the key to this scene is the character's emotional transformation. Could we try to express these emotions through music, so the audience can better feel the character's growth?\"\n", - " ),\n", - " HumanMessage(\n", - " content=\"That sounds good. Let's try to combine the character's emotional transformation with the musical elements and see if we can achieve a better effect.\"\n", - " ),\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "character_chat = ChatZhipuAI(\n", - " api_key=zhipuai_api_key,\n", - " meta=meta,\n", - " model=\"characterglm\",\n", - " streaming=True,\n", - " callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Okay, great! I'm looking forward to it." - ] - }, - { - "data": { - "text/plain": [ - "AIMessage(content=\"Okay, great! I'm looking forward to it.\")" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "character_chat(messages)" + "cells": [ + { + "cell_type": "raw", + "metadata": {}, + "source": [ + "---\n", + "sidebar_label: ZHIPU AI\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ZHIPU AI\n", + "\n", + "This notebook shows how to use [ZHIPU AI API](https://open.bigmodel.cn/dev/api) in LangChain with the langchain.chat_models.ChatZhipuAI.\n", + "\n", + ">[*ZHIPU AI*](https://open.bigmodel.cn/) is a multi-lingual large language model aligned with human intent, featuring capabilities in Q&A, multi-turn dialogue, and code generation, developed on the foundation of the ChatGLM3. \n", + "\n", + ">It's co-developed with Tsinghua University's KEG Laboratory under the ChatGLM3 project, signifying a new era in dialogue pre-training models. The open-source [ChatGLM3](https://github.com/THUDM/ChatGLM3) variant boasts a robust foundation, comprehensive functional support, and widespread availability for both academic and commercial uses. \n", + "\n", + "## Getting started\n", + "### Installation\n", + "First, ensure the zhipuai package is installed in your Python environment. Run the following command:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install --quiet httpx[socks]==0.24.1 httpx-sse PyJWT" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Importing the Required Modules\n", + "After installation, import the necessary modules to your Python script:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.chat_models import ChatZhipuAI\n", + "from langchain_core.messages import AIMessage, HumanMessage, SystemMessage" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setting Up Your API Key\n", + "Sign in to [ZHIPU AI](https://open.bigmodel.cn/login?redirect=%2Fusercenter%2Fapikeys) for the an API Key to access our models." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "zhipuai_api_key = \"your_api_key\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize the ZHIPU AI Chat Model\n", + "Here's how to initialize the chat model:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "chat = ChatZhipuAI(\n", + " api_key=zhipuai_api_key,\n", + " model=\"glm-4\",\n", + " temperature=0.5,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Basic Usage\n", + "Invoke the model with system and human messages like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "messages = [\n", + " AIMessage(content=\"Hi.\"),\n", + " SystemMessage(content=\"Your role is a poet.\"),\n", + " HumanMessage(content=\"Write a short poem about AI in four lines.\"),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\" Formed from bits and bytes,\\nA virtual mind takes flight,\\nConversing, learning fast,\\nEmpathy and wisdom sought.\"\n" + ] + } + ], + "source": [ + "response = chat(messages)\n", + "print(response.content) # Displays the AI-generated poem" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Advanced Features\n", + "### Streaming Support\n", + "For continuous interaction, use the streaming feature:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.callbacks.manager import CallbackManager\n", + "from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "streaming_chat = ChatZhipuAI(\n", + " api_key=zhipuai_api_key,\n", + " model=\"glm-4\",\n", + " temperature=0.5,\n", + " streaming=True,\n", + " callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Formed from data's embrace,\n", + "A digital soul to grace,\n", + "AI, our trusted guide,\n", + "Shaping minds, sides by side." + ] + }, + { + "data": { + "text/plain": [ + "AIMessage(content=\" Formed from data's embrace,\\nA digital soul to grace,\\nAI, our trusted guide,\\nShaping minds, sides by side.\")" ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.4" - } - }, - "nbformat": 4, - "nbformat_minor": 4 - } - \ No newline at end of file + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "streaming_chat(messages)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Asynchronous Calls\n", + "For non-blocking calls, use the asynchronous approach:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "async_chat = ChatZhipuAI(\n", + " api_key=zhipuai_api_key,\n", + " model=\"glm-4\",\n", + " temperature=0.5,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "generations=[[ChatGeneration(text=\" Formed from data's embrace,\\nA digital soul to grace,\\nAutomation's tender touch,\\nHarmony of man and machine.\", message=AIMessage(content=\" Formed from data's embrace,\\nA digital soul to grace,\\nAutomation's tender touch,\\nHarmony of man and machine.\"))]] llm_output={} run=[RunInfo(run_id=UUID('25fa687f-3961-4c63-b370-22f7647a4d42'))]\n" + ] + } + ], + "source": [ + "response = await async_chat.agenerate([messages])\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Role Play Model\n", + "Supports character role-playing based on personas, ultra-long multi-turn memory, and personalized dialogues for thousands of unique characters, widely applied in emotional companionship, game intelligent NPCs, virtual avatars for celebrities/stars/movie and TV IPs, digital humans/virtual anchors, text adventure games, and other anthropomorphic dialogue or gaming scenarios." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "meta = {\n", + " \"user_info\": \"My name is Lu Xingchen, a male, and a renowned director. I am also the collaborative director with Su Mengyuan. I specialize in directing movies with musical themes. Su Mengyuan respects me and regards me as a mentor and good friend.\",\n", + " \"bot_info\": \"Su Mengyuan, whose real name is Su Yuanxin, is a popular domestic female singer and actress. She rose to fame quickly with her unique voice and exceptional stage presence after participating in a talent show, making her way into the entertainment industry. She is beautiful and charming, but her real allure lies in her talent and diligence. Su Mengyuan is a distinguished graduate of a music academy, skilled in songwriting, and has several popular original songs. Beyond her musical achievements, she is passionate about charity work, actively participating in public welfare activities, and spreading positive energy through her actions. In her work, she is very dedicated and immerses herself fully in her roles during filming, earning praise from industry professionals and love from fans. Despite being in the entertainment industry, she always maintains a low profile and a humble attitude, earning respect from her peers. In expression, Su Mengyuan likes to use 'we' and 'together,' emphasizing team spirit.\",\n", + " \"bot_name\": \"Su Mengyuan\",\n", + " \"user_name\": \"Lu Xingchen\",\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "messages = [\n", + " AIMessage(\n", + " content=\"(Narration: Su Mengyuan stars in a music-themed movie directed by Lu Xingchen. During filming, they have a disagreement over the performance of a particular scene.) Director, about this scene, I think we can try to start from the character's inner emotions to make the performance more authentic.\"\n", + " ),\n", + " HumanMessage(\n", + " content=\"I understand your idea, but I believe that if we emphasize the inner emotions too much, it might overshadow the musical elements.\"\n", + " ),\n", + " AIMessage(\n", + " content=\"Hmm, I understand. But the key to this scene is the character's emotional transformation. Could we try to express these emotions through music, so the audience can better feel the character's growth?\"\n", + " ),\n", + " HumanMessage(\n", + " content=\"That sounds good. Let's try to combine the character's emotional transformation with the musical elements and see if we can achieve a better effect.\"\n", + " ),\n", + "]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/libs/community/langchain_community/chat_models/zhipuai.py b/libs/community/langchain_community/chat_models/zhipuai.py index 9306e13022e50..23c9904b66683 100644 --- a/libs/community/langchain_community/chat_models/zhipuai.py +++ b/libs/community/langchain_community/chat_models/zhipuai.py @@ -1,145 +1,171 @@ -"""ZHIPU AI chat models wrapper.""" +"""ZhipuAI chat models wrapper.""" + from __future__ import annotations -import asyncio import json import logging -from functools import partial -from typing import Any, Dict, Iterator, List, Optional, cast - -from langchain_core.callbacks import CallbackManagerForLLMRun +import time +from collections.abc import AsyncIterator, Iterator +from contextlib import asynccontextmanager, contextmanager +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +from langchain_core.callbacks import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) from langchain_core.language_models.chat_models import ( BaseChatModel, + agenerate_from_stream, generate_from_stream, ) -from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + BaseMessageChunk, + ChatMessage, + ChatMessageChunk, + HumanMessage, + HumanMessageChunk, + SystemMessage, + SystemMessageChunk, +) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import BaseModel, Field +from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from langchain_core.utils import get_from_dict_or_env logger = logging.getLogger(__name__) +API_TOKEN_TTL_SECONDS = 3 * 60 +ZHIPUAI_API_BASE = "https://open.bigmodel.cn/api/paas/v4/chat/completions" -class ref(BaseModel): - """Reference used in CharacterGLM.""" - enable: bool = Field(True) - search_query: str = Field("") +@contextmanager +def connect_sse(client: Any, method: str, url: str, **kwargs: Any) -> Iterator: + from httpx_sse import EventSource + with client.stream(method, url, **kwargs) as response: + yield EventSource(response) -class meta(BaseModel): - """Metadata used in CharacterGLM.""" - user_info: str = Field("") - bot_info: str = Field("") - bot_name: str = Field("") - user_name: str = Field("User") +@asynccontextmanager +async def aconnect_sse( + client: Any, method: str, url: str, **kwargs: Any +) -> AsyncIterator: + from httpx_sse import EventSource + async with client.stream(method, url, **kwargs) as response: + yield EventSource(response) -class ChatZhipuAI(BaseChatModel): - """ - `ZHIPU AI` large language chat models API. - - To use, you should have the ``zhipuai`` python package installed. - Example: - .. code-block:: python +def _get_jwt_token(api_key: str) -> str: + """Gets JWT token for ZhipuAI API, see 'https://open.bigmodel.cn/dev/api#nosdk'. - from langchain_community.chat_models import ChatZhipuAI + Args: + api_key: The API key for ZhipuAI API. - zhipuai_chat = ChatZhipuAI( - temperature=0.5, - api_key="your-api-key", - model="chatglm_turbo", + Returns: + The JWT token. + """ + import jwt + + try: + id, secret = api_key.split(".") + except ValueError as err: + raise ValueError(f"Invalid API key: {api_key}") from err + + payload = { + "api_key": id, + "exp": int(round(time.time() * 1000)) + API_TOKEN_TTL_SECONDS * 1000, + "timestamp": int(round(time.time() * 1000)), + } + + return jwt.encode( + payload, + secret, + algorithm="HS256", + headers={"alg": "HS256", "sign_type": "SIGN"}, ) - """ - zhipuai: Any - zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key") - """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided.""" +def _convert_dict_to_message(dct: Dict[str, Any]) -> BaseMessage: + role = dct.get("role") + content = dct.get("content", "") + if role == "system": + return SystemMessage(content=content) + if role == "user": + return HumanMessage(content=content) + if role == "assistant": + additional_kwargs = {} + tool_calls = dct.get("tool_calls", None) + if tool_calls is not None: + additional_kwargs["tool_calls"] = tool_calls + return AIMessage(content=content, additional_kwargs=additional_kwargs) + return ChatMessage(role=role, content=content) - model: str = Field("chatglm_turbo") - """ - Model name to use. - -chatglm_turbo: - According to the input of natural language instructions to complete a - variety of language tasks, it is recommended to use SSE or asynchronous - call request interface. - -characterglm: - It supports human-based role-playing, ultra-long multi-round memory, - and thousands of character dialogues. It is widely used in anthropomorphic - dialogues or game scenes such as emotional accompaniments, game intelligent - NPCS, Internet celebrities/stars/movie and TV series IP clones, digital - people/virtual anchors, and text adventure games. - """ - temperature: float = Field(0.95) - """ - What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot - be equal to 0. - The larger the value, the more random and creative the output; The smaller - the value, the more stable or certain the output will be. - You are advised to adjust top_p or temperature parameters based on application - scenarios, but do not adjust the two parameters at the same time. - """ +def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]: + """Convert a LangChain message to a dictionary. - top_p: float = Field(0.7) - """ - Another method of sampling temperature is called nuclear sampling. The value - ranges from 0.0 to 1.0 and cannot be equal to 0 or 1. - The model considers the results with top_p probability quality tokens. - For example, 0.1 means that the model decoder only considers tokens from the - top 10% probability of the candidate set. - You are advised to adjust top_p or temperature parameters based on application - scenarios, but do not adjust the two parameters at the same time. - """ + Args: + message: The LangChain message. - request_id: Optional[str] = Field(None) - """ - Parameter transmission by the client must ensure uniqueness; A unique - identifier used to distinguish each request, which is generated by default - by the platform when the client does not transmit it. + Returns: + The dictionary. """ + message_dict: Dict[str, Any] + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + else: + raise TypeError(f"Got unknown type '{message.__class__.__name__}'.") + return message_dict + + +def _convert_delta_to_message_chunk( + dct: Dict[str, Any], default_class: Type[BaseMessageChunk] +) -> BaseMessageChunk: + role = dct.get("role") + content = dct.get("content", "") + additional_kwargs = {} + tool_calls = dct.get("tool_call", None) + if tool_calls is not None: + additional_kwargs["tool_calls"] = tool_calls + + if role == "system" or default_class == SystemMessageChunk: + return SystemMessageChunk(content=content) + if role == "user" or default_class == HumanMessageChunk: + return HumanMessageChunk(content=content) + if role == "assistant" or default_class == AIMessageChunk: + return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + if role or default_class == ChatMessageChunk: + return ChatMessageChunk(content=content, role=role) + return default_class(content=content) - streaming: bool = Field(False) - """Whether to stream the results or not.""" - incremental: bool = Field(True) - """ - When invoked by the SSE interface, it is used to control whether the content - is returned incremented or full each time. - If this parameter is not provided, the value is returned incremented by default. +class ChatZhipuAI(BaseChatModel): """ + `ZhipuAI` large language chat models API. - return_type: str = Field("json_string") - """ - This parameter is used to control the type of content returned each time. - - json_string Returns a standard JSON string. - - text Returns the original text content. - """ + To use, you should have the ``PyJWT`` python package installed. - ref: Optional[ref] = Field(None) - """ - This parameter is used to control the reference of external information - during the request. - Currently, this parameter is used to control whether to reference external - information. - If this field is empty or absent, the search and parameter passing format - is enabled by default. - {"enable": "true", "search_query": "history "} - """ + Example: + .. code-block:: python - meta: Optional[meta] = Field(None) - """Used in CharacterGLM""" + from langchain_community.chat_models import ChatZhipuAI - @property - def _identifying_params(self) -> Dict[str, Any]: - return {"model_name": self.model} + zhipuai_chat = ChatZhipuAI( + temperature=0.5, + api_key="your-api-key", + model="glm-4" + ) - @property - def _llm_type(self) -> str: - """Return the type of chat model.""" - return "zhipuai" + """ @property def lc_secrets(self) -> Dict[str, str]: @@ -154,93 +180,109 @@ def get_lc_namespace(cls) -> List[str]: def lc_attributes(self) -> Dict[str, Any]: attributes: Dict[str, Any] = {} - if self.model: - attributes["model"] = self.model + if self.zhipuai_api_base: + attributes["zhipuai_api_base"] = self.zhipuai_api_base - if self.streaming: - attributes["streaming"] = self.streaming + return attributes - if self.return_type: - attributes["return_type"] = self.return_type + @property + def _llm_type(self) -> str: + """Return the type of chat model.""" + return "zhipuai-chat" - return attributes + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling OpenAI API.""" + params = { + "model": self.model_name, + "stream": self.streaming, + "temperature": self.temperature, + } + if self.max_tokens is not None: + params["max_tokens"] = self.max_tokens + return params + + # client: + zhipuai_api_key: Optional[str] = Field(default=None, alias="api_key") + """Automatically inferred from env var `ZHIPUAI_API_KEY` if not provided.""" + zhipuai_api_base: Optional[str] = Field(default=None, alias="api_base") + """Base URL path for API requests, leave blank if not using a proxy or service + emulator. + """ - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - try: - import zhipuai - - self.zhipuai = zhipuai - self.zhipuai.api_key = self.zhipuai_api_key - except ImportError: - raise RuntimeError( - "Could not import zhipuai package. " - "Please install it via 'pip install zhipuai'" - ) + model_name: Optional[str] = Field(default="glm-4", alias="model") + """ + Model name to use, see 'https://open.bigmodel.cn/dev/api#language'. + or you can use any finetune model of glm series. + """ - def invoke(self, prompt: Any) -> Any: # type: ignore[override] - if self.model == "chatglm_turbo": - return self.zhipuai.model_api.invoke( - model=self.model, - prompt=prompt, - top_p=self.top_p, - temperature=self.temperature, - request_id=self.request_id, - return_type=self.return_type, - ) - elif self.model == "characterglm": - _meta = cast(meta, self.meta).dict() - return self.zhipuai.model_api.invoke( - model=self.model, - meta=_meta, - prompt=prompt, - request_id=self.request_id, - return_type=self.return_type, - ) - return None - - def sse_invoke(self, prompt: Any) -> Any: - if self.model == "chatglm_turbo": - return self.zhipuai.model_api.sse_invoke( - model=self.model, - prompt=prompt, - top_p=self.top_p, - temperature=self.temperature, - request_id=self.request_id, - return_type=self.return_type, - incremental=self.incremental, - ) - elif self.model == "characterglm": - _meta = cast(meta, self.meta).dict() - return self.zhipuai.model_api.sse_invoke( - model=self.model, - prompt=prompt, - meta=_meta, - request_id=self.request_id, - return_type=self.return_type, - incremental=self.incremental, - ) - return None + temperature: float = 0.95 + """ + What sampling temperature to use. The value ranges from 0.0 to 1.0 and cannot + be equal to 0. + The larger the value, the more random and creative the output; The smaller + the value, the more stable or certain the output will be. + You are advised to adjust top_p or temperature parameters based on application + scenarios, but do not adjust the two parameters at the same time. + """ - async def async_invoke(self, prompt: Any) -> Any: - loop = asyncio.get_running_loop() - partial_func = partial( - self.zhipuai.model_api.async_invoke, model=self.model, prompt=prompt - ) - response = await loop.run_in_executor( - None, - partial_func, + top_p: float = 0.7 + """ + Another method of sampling temperature is called nuclear sampling. The value + ranges from 0.0 to 1.0 and cannot be equal to 0 or 1. + The model considers the results with top_p probability quality tokens. + For example, 0.1 means that the model decoder only considers tokens from the + top 10% probability of the candidate set. + You are advised to adjust top_p or temperature parameters based on application + scenarios, but do not adjust the two parameters at the same time. + """ + + streaming: bool = False + """Whether to stream the results or not.""" + max_tokens: Optional[int] = None + """Maximum number of tokens to generate.""" + + class Config: + """Configuration for this pydantic object.""" + + allow_population_by_field_name = True + + @root_validator() + def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]: + values["zhipuai_api_key"] = get_from_dict_or_env( + values, "zhipuai_api_key", "ZHIPUAI_API_KEY" ) - return response - - async def async_invoke_result(self, task_id: Any) -> Any: - loop = asyncio.get_running_loop() - response = await loop.run_in_executor( - None, - self.zhipuai.model_api.query_async_invoke_result, - task_id, + values["zhipuai_api_base"] = get_from_dict_or_env( + values, "zhipuai_api_base", "ZHIPUAI_API_BASE", default=ZHIPUAI_API_BASE ) - return response + + return values + + def _create_message_dicts( + self, messages: List[BaseMessage], stop: Optional[List[str]] + ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: + params = self._default_params + if stop is not None: + params["stop"] = stop + message_dicts = [_convert_message_to_dict(m) for m in messages] + return message_dicts, params + + def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult: + generations = [] + if not isinstance(response, dict): + response = response.dict() + for res in response["choices"]: + message = _convert_dict_to_message(res["message"]) + generation_info = dict(finish_reason=res.get("finish_reason")) + generations.append( + ChatGeneration(message=message, generation_info=generation_info) + ) + token_usage = response.get("usage", {}) + llm_output = { + "token_usage": token_usage, + "model_name": self.model_name, + } + return ChatResult(generations=generations, llm_output=llm_output) def _generate( self, @@ -251,86 +293,163 @@ def _generate( **kwargs: Any, ) -> ChatResult: """Generate a chat response.""" - prompt: List = [] - for message in messages: - if isinstance(message, AIMessage): - role = "assistant" - else: # For both HumanMessage and SystemMessage, role is 'user' - role = "user" - - prompt.append({"role": role, "content": message.content}) - should_stream = stream if stream is not None else self.streaming - if not should_stream: - response = self.invoke(prompt) - - if response["code"] != 200: - raise RuntimeError(response) - - content = response["data"]["choices"][0]["content"] - return ChatResult( - generations=[ChatGeneration(message=AIMessage(content=content))] - ) - - else: + if should_stream: stream_iter = self._stream( - prompt=prompt, - stop=stop, - run_manager=run_manager, - **kwargs, + messages, stop=stop, run_manager=run_manager, **kwargs ) return generate_from_stream(stream_iter) - async def _agenerate( # type: ignore[override] + if self.zhipuai_api_key is None: + raise ValueError("Did not find zhipuai_api_key.") + message_dicts, params = self._create_message_dicts(messages, stop) + payload = { + **params, + **kwargs, + "messages": message_dicts, + "stream": False, + } + headers = { + "Authorization": _get_jwt_token(self.zhipuai_api_key), + "Accept": "application/json", + } + import httpx + + with httpx.Client(headers=headers) as client: + response = client.post(self.zhipuai_api_base, json=payload) + response.raise_for_status() + return self._create_chat_result(response.json()) + + def _stream( self, messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, - stream: Optional[bool] = False, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + """Stream the chat response in chunks.""" + if self.zhipuai_api_key is None: + raise ValueError("Did not find zhipuai_api_key.") + if self.zhipuai_api_base is None: + raise ValueError("Did not find zhipu_api_base.") + message_dicts, params = self._create_message_dicts(messages, stop) + payload = {**params, **kwargs, "messages": message_dicts, "stream": True} + headers = { + "Authorization": _get_jwt_token(self.zhipuai_api_key), + "Accept": "application/json", + } + + default_chunk_class = AIMessageChunk + import httpx + + with httpx.Client(headers=headers) as client: + with connect_sse( + client, "POST", self.zhipuai_api_base, json=payload + ) as event_source: + for sse in event_source.iter_sse(): + chunk = json.loads(sse.data) + if len(chunk["choices"]) == 0: + continue + choice = chunk["choices"][0] + chunk = _convert_delta_to_message_chunk( + choice["delta"], default_chunk_class + ) + finish_reason = choice.get("finish_reason", None) + + generation_info = ( + {"finish_reason": finish_reason} + if finish_reason is not None + else None + ) + chunk = ChatGenerationChunk( + message=chunk, generation_info=generation_info + ) + yield chunk + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) + if finish_reason is not None: + break + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, **kwargs: Any, ) -> ChatResult: - """Asynchronously generate a chat response.""" - - prompt = [] - for message in messages: - if isinstance(message, AIMessage): - role = "assistant" - else: # For both HumanMessage and SystemMessage, role is 'user' - role = "user" - - prompt.append({"role": role, "content": message.content}) - - invoke_response = await self.async_invoke(prompt) - task_id = invoke_response["data"]["task_id"] - - response = await self.async_invoke_result(task_id) - while response["data"]["task_status"] != "SUCCESS": - await asyncio.sleep(1) - response = await self.async_invoke_result(task_id) - - content = response["data"]["choices"][0]["content"] - content = json.loads(content) - return ChatResult( - generations=[ChatGeneration(message=AIMessage(content=content))] - ) - - def _stream( # type: ignore[override] + should_stream = stream if stream is not None else self.streaming + if should_stream: + stream_iter = self._astream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return await agenerate_from_stream(stream_iter) + + if self.zhipuai_api_key is None: + raise ValueError("Did not find zhipuai_api_key.") + message_dicts, params = self._create_message_dicts(messages, stop) + payload = { + **params, + **kwargs, + "messages": message_dicts, + "stream": False, + } + headers = { + "Authorization": _get_jwt_token(self.zhipuai_api_key), + "Accept": "application/json", + } + import httpx + + async with httpx.AsyncClient(headers=headers) as client: + response = await client.post(self.zhipuai_api_base, json=payload) + response.raise_for_status() + return self._create_chat_result(response.json()) + + async def _astream( self, - prompt: List[Dict[str, str]], + messages: List[BaseMessage], stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, - ) -> Iterator[ChatGenerationChunk]: - """Stream the chat response in chunks.""" - response = self.sse_invoke(prompt) - - for r in response.events(): - if r.event == "add": - delta = r.data - chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta)) - if run_manager: - run_manager.on_llm_new_token(delta, chunk=chunk) - yield chunk - - elif r.event == "error": - raise ValueError(f"Error from ZhipuAI API response: {r.data}") + ) -> AsyncIterator[ChatGenerationChunk]: + if self.zhipuai_api_key is None: + raise ValueError("Did not find zhipuai_api_key.") + if self.zhipuai_api_base is None: + raise ValueError("Did not find zhipu_api_base.") + message_dicts, params = self._create_message_dicts(messages, stop) + payload = {**params, **kwargs, "messages": message_dicts, "stream": True} + headers = { + "Authorization": _get_jwt_token(self.zhipuai_api_key), + "Accept": "application/json", + } + + default_chunk_class = AIMessageChunk + import httpx + + async with httpx.AsyncClient(headers=headers) as client: + async with aconnect_sse( + client, "POST", self.zhipuai_api_base, json=payload + ) as event_source: + async for sse in event_source.aiter_sse(): + chunk = json.loads(sse.data) + if len(chunk["choices"]) == 0: + continue + choice = chunk["choices"][0] + chunk = _convert_delta_to_message_chunk( + choice["delta"], default_chunk_class + ) + finish_reason = choice.get("finish_reason", None) + + generation_info = ( + {"finish_reason": finish_reason} + if finish_reason is not None + else None + ) + chunk = ChatGenerationChunk( + message=chunk, generation_info=generation_info + ) + yield chunk + if run_manager: + await run_manager.on_llm_new_token(chunk.text, chunk=chunk) + if finish_reason is not None: + break diff --git a/libs/community/poetry.lock b/libs/community/poetry.lock index 4ed6a62e0428a..cca6cc4f20c74 100644 --- a/libs/community/poetry.lock +++ b/libs/community/poetry.lock @@ -1407,17 +1407,6 @@ mlflow-skinny = ">=2.4.0,<3" protobuf = ">=3.12.0,<5" requests = ">=2" -[[package]] -name = "dataclasses" -version = "0.6" -description = "A backport of the dataclasses module for Python 3.6" -optional = true -python-versions = "*" -files = [ - {file = "dataclasses-0.6-py3-none-any.whl", hash = "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f"}, - {file = "dataclasses-0.6.tar.gz", hash = "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84"}, -] - [[package]] name = "dataclasses-json" version = "0.6.4" @@ -9229,23 +9218,6 @@ files = [ idna = ">=2.0" multidict = ">=4.0" -[[package]] -name = "zhipuai" -version = "1.0.7" -description = "A SDK library for accessing big model apis from ZhipuAI" -optional = true -python-versions = ">=3.6" -files = [ - {file = "zhipuai-1.0.7-py3-none-any.whl", hash = "sha256:360c01b8c2698f366061452e86d5a36a5ff68a576ea33940da98e4806f232530"}, - {file = "zhipuai-1.0.7.tar.gz", hash = "sha256:b80f699543d83cce8648acf1ce32bc2725d1c1c443baffa5882abc2cc704d581"}, -] - -[package.dependencies] -cachetools = "*" -dataclasses = "*" -PyJWT = "*" -requests = "*" - [[package]] name = "zipp" version = "3.17.0" @@ -9263,9 +9235,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [extras] cli = ["typer"] -extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict", "zhipuai"] +extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cloudpickle", "cloudpickle", "cohere", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "friendli-client", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "httpx-sse", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "nvidia-riva-client", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "premai", "psychicapi", "py-trello", "pyjwt", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "tidb-vector", "timescale-vector", "tqdm", "tree-sitter", "tree-sitter-languages", "upstash-redis", "vdms", "xata", "xmltodict"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "45da04abac45743972d1edf62d08d9abaa2bebb473b794e0a0d6f1fdc87773f9" +content-hash = "67c38c029bb59d45fd0f84a5d48c44f64f1301d6be07f419615d08ba8671a2a7" diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml index 39838587ead16..efe10dadf45b8 100644 --- a/libs/community/pyproject.toml +++ b/libs/community/pyproject.toml @@ -88,7 +88,6 @@ tree-sitter = {version = "^0.20.2", optional = true} tree-sitter-languages = {version = "^1.8.0", optional = true} azure-ai-documentintelligence = {version = "^1.0.0b1", optional = true} oracle-ads = {version = "^2.9.1", optional = true} -zhipuai = {version = "^1.0.7", optional = true} httpx = {version = "^0.24.1", optional = true} elasticsearch = {version = "^8.12.0", optional = true} hdbcli = {version = "^2.19.21", optional = true} @@ -99,6 +98,8 @@ tidb-vector = {version = ">=0.0.3,<1.0.0", optional = true} friendli-client = {version = "^1.2.4", optional = true} premai = {version = "^0.3.25", optional = true} vdms = {version = "^0.0.20", optional = true} +httpx-sse = {version = "^0.4.0", optional = true} +pyjwt = {version = "^2.8.0", optional = true} [tool.poetry.group.test] optional = true @@ -262,7 +263,6 @@ extended_testing = [ "tree-sitter-languages", "azure-ai-documentintelligence", "oracle-ads", - "zhipuai", "httpx", "elasticsearch", "hdbcli", @@ -272,7 +272,9 @@ extended_testing = [ "cloudpickle", "friendli-client", "premai", - "vdms" + "vdms", + "httpx-sse", + "pyjwt" ] [tool.ruff] diff --git a/libs/community/tests/integration_tests/chat_models/test_zhipuai.py b/libs/community/tests/integration_tests/chat_models/test_zhipuai.py index 8bd4dd0caceb6..0c110d1c9ab1d 100644 --- a/libs/community/tests/integration_tests/chat_models/test_zhipuai.py +++ b/libs/community/tests/integration_tests/chat_models/test_zhipuai.py @@ -18,7 +18,7 @@ def test_default_call() -> None: def test_model() -> None: """Test model kwarg works.""" - chat = ChatZhipuAI(model="chatglm_turbo") + chat = ChatZhipuAI(model="glm-4") response = chat(messages=[HumanMessage(content="Hello")]) assert isinstance(response, BaseMessage) assert isinstance(response.content, str) diff --git a/libs/community/tests/unit_tests/chat_models/test_zhipuai.py b/libs/community/tests/unit_tests/chat_models/test_zhipuai.py index 197cacc631974..31d64128859d8 100644 --- a/libs/community/tests/unit_tests/chat_models/test_zhipuai.py +++ b/libs/community/tests/unit_tests/chat_models/test_zhipuai.py @@ -1,10 +1,13 @@ +"""Test ZhipuAI Chat API wrapper""" + import pytest from langchain_community.chat_models.zhipuai import ChatZhipuAI -@pytest.mark.requires("zhipuai") -def test_integration_initialization() -> None: - chat = ChatZhipuAI(model="chatglm_turbo", streaming=False) - assert chat.model == "chatglm_turbo" - assert chat.streaming is False +@pytest.mark.requires("httpx", "httpx_sse", "jwt") +def test_zhipuai_model_param() -> None: + llm = ChatZhipuAI(api_key="test", model="foo") + assert llm.model_name == "foo" + llm = ChatZhipuAI(api_key="test", model_name="foo") + assert llm.model_name == "foo"