Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update schema #11320

Merged
merged 12 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 57 additions & 10 deletions cookbook/openai_v1_cookbook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,74 @@
"!pip install \"openai>=1\""
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c3e067ce-7a43-47a7-bc89-41f1de4cf136",
"metadata": {},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema.messages import HumanMessage, SystemMessage"
]
},
{
"cell_type": "markdown",
"id": "71c34763-d1e7-4b9a-a9d7-3e4cc0dfc2c4",
"id": "fa7e7e95-90a1-4f73-98fe-10c4b4e0951b",
"metadata": {},
"source": [
"## [JSON mode](https://platform.openai.com/docs/guides/text-generation/json-mode)\n",
"\n",
"Constrain the model to only generate valid JSON. Note that you must include a system message with instructions to use JSON for this mode to work.\n",
"## [Vision](https://platform.openai.com/docs/guides/vision)\n",
"\n",
"Only works with certain models. "
"OpenAI released multi-modal models, which can take a sequence of text and images as input."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "c3e067ce-7a43-47a7-bc89-41f1de4cf136",
"execution_count": 2,
"id": "1c8c3965-d3c9-4186-b5f3-5e67855ef916",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='The image appears to be a diagram illustrating the architecture or components of a software system')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema.messages import HumanMessage, SystemMessage"
"chat = ChatOpenAI(model=\"gpt-4-vision-preview\")\n",
"chat.invoke(\n",
" [\n",
" HumanMessage(\n",
" content=[\n",
" {\"type\": \"text\", \"text\": \"What is this image showing\"},\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": \"https://python.langchain.com/assets/images/langchain_stack-da369071b058555da3d491a695651f15.jpg\",\n",
" \"detail\": \"auto\",\n",
" },\n",
" },\n",
" ]\n",
" )\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"id": "71c34763-d1e7-4b9a-a9d7-3e4cc0dfc2c4",
"metadata": {},
"source": [
"## [JSON mode](https://platform.openai.com/docs/guides/text-generation/json-mode)\n",
"\n",
"Constrain the model to only generate valid JSON. Note that you must include a system message with instructions to use JSON for this mode to work.\n",
"\n",
"Only works with certain models. "
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ def _parse_ai_message(message: BaseMessage) -> Union[List[AgentAction], AgentFin
final_tools.append(_tool)
return final_tools

return AgentFinish(return_values={"output": message.content}, log=message.content)
return AgentFinish(
return_values={"output": message.content}, log=str(message.content)
)


class OpenAIMultiFunctionsAgent(BaseMultiActionAgent):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _parse_ai_message(message: BaseMessage) -> Union[AgentAction, AgentFinish]:
)

return AgentFinish(
return_values={"output": message.content}, log=message.content
return_values={"output": message.content}, log=str(message.content)
)

def parse_result(
Expand Down
10 changes: 7 additions & 3 deletions libs/langchain/langchain/callbacks/infino_callback.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult
Expand Down Expand Up @@ -232,7 +232,9 @@ def on_chat_model_start(
self.chat_openai_model_name = model_name
prompt_tokens = 0
for message_list in messages:
message_string = " ".join(msg.content for msg in message_list)
message_string = " ".join(
cast(str, msg.content) for msg in message_list
)
num_tokens = get_num_tokens(
message_string,
openai_model_name=self.chat_openai_model_name,
Expand All @@ -249,7 +251,9 @@ def on_chat_model_start(
)

# Send the prompt to infino
prompt = " ".join(msg.content for sublist in messages for msg in sublist)
prompt = " ".join(
cast(str, msg.content) for sublist in messages for msg in sublist
)
self._send_to_infino("prompt", prompt, is_ts=False)

# Set the error flag to indicate no error (this will get overridden
Expand Down
10 changes: 10 additions & 0 deletions libs/langchain/langchain/chat_loaders/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def merge_chat_runs_in_session(
"""
messages: List[BaseMessage] = []
for message in chat_session["messages"]:
if not isinstance(message.content, str):
raise ValueError(
"Chat Loaders only support messages with content type string, "
f"got {message.content}"
)
if not messages:
messages.append(deepcopy(message))
elif (
Expand All @@ -29,6 +34,11 @@ def merge_chat_runs_in_session(
and messages[-1].additional_kwargs["sender"]
== message.additional_kwargs.get("sender")
):
if not isinstance(messages[-1].content, str):
raise ValueError(
"Chat Loaders only support messages with content type string, "
f"got {messages[-1].content}"
)
messages[-1].content = (
messages[-1].content + delimiter + message.content
).strip()
Expand Down
11 changes: 6 additions & 5 deletions libs/langchain/langchain/chat_models/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast

from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
Expand Down Expand Up @@ -27,14 +27,15 @@ def _convert_one_message_to_text(
human_prompt: str,
ai_prompt: str,
) -> str:
content = cast(str, message.content)
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
message_text = f"\n\n{message.role.capitalize()}: {content}"
elif isinstance(message, HumanMessage):
message_text = f"{human_prompt} {message.content}"
message_text = f"{human_prompt} {content}"
elif isinstance(message, AIMessage):
message_text = f"{ai_prompt} {message.content}"
message_text = f"{ai_prompt} {content}"
elif isinstance(message, SystemMessage):
message_text = message.content
message_text = content
else:
raise ValueError(f"Got unknown type {message}")
return message_text
Expand Down
19 changes: 6 additions & 13 deletions libs/langchain/langchain/chat_models/azureml_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
Expand All @@ -23,36 +23,29 @@ class LlamaContentFormatter(ContentFormatterBase):
@staticmethod
def _convert_message_to_dict(message: BaseMessage) -> Dict:
"""Converts message to a dict according to role"""
content = cast(str, message.content)
if isinstance(message, HumanMessage):
return {
"role": "user",
"content": ContentFormatterBase.escape_special_characters(
message.content
),
"content": ContentFormatterBase.escape_special_characters(content),
}
elif isinstance(message, AIMessage):
return {
"role": "assistant",
"content": ContentFormatterBase.escape_special_characters(
message.content
),
"content": ContentFormatterBase.escape_special_characters(content),
}
elif isinstance(message, SystemMessage):
return {
"role": "system",
"content": ContentFormatterBase.escape_special_characters(
message.content
),
"content": ContentFormatterBase.escape_special_characters(content),
}
elif (
isinstance(message, ChatMessage)
and message.role in LlamaContentFormatter.SUPPORTED_ROLES
):
return {
"role": message.role,
"content": ContentFormatterBase.escape_special_characters(
message.content
),
"content": ContentFormatterBase.escape_special_characters(content),
}
else:
supported = ",".join(
Expand Down
12 changes: 2 additions & 10 deletions libs/langchain/langchain/chat_models/baidu_qianfan_endpoint.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,7 @@
from __future__ import annotations

import logging
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Mapping,
Optional,
)
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional, cast

from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
Expand Down Expand Up @@ -211,7 +203,7 @@ def _convert_prompt_msg_params(
for i in [i for i, m in enumerate(messages) if isinstance(m, SystemMessage)]:
if "system" not in messages_dict:
messages_dict["system"] = ""
messages_dict["system"] += messages[i].content + "\n"
messages_dict["system"] += cast(str, messages[i].content) + "\n"

return {
**messages_dict,
Expand Down
10 changes: 8 additions & 2 deletions libs/langchain/langchain/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,10 @@ def predict(
else:
_stop = list(stop)
result = self([HumanMessage(content=text)], stop=_stop, **kwargs)
return result.content
if isinstance(result.content, str):
return result.content
else:
raise ValueError("Cannot use predict when output is not a string.")

def predict_messages(
self,
Expand All @@ -659,7 +662,10 @@ async def apredict(
result = await self._call_async(
[HumanMessage(content=text)], stop=_stop, **kwargs
)
return result.content
if isinstance(result.content, str):
return result.content
else:
raise ValueError("Cannot use predict when output is not a string.")

async def apredict_messages(
self,
Expand Down
4 changes: 2 additions & 2 deletions libs/langchain/langchain/chat_models/google_palm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, cast

from tenacity import (
before_sleep_log,
Expand Down Expand Up @@ -114,7 +114,7 @@ def _messages_to_prompt_dict(
if isinstance(input_message, SystemMessage):
if index != 0:
raise ChatGooglePalmError("System message must be first input message.")
context = input_message.content
context = cast(str, input_message.content)
elif isinstance(input_message, HumanMessage) and input_message.example:
if messages:
raise ChatGooglePalmError(
Expand Down
5 changes: 4 additions & 1 deletion libs/langchain/langchain/chat_models/human.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ def _collect_yaml_input(
if message is None:
return HumanMessage(content="")
if stop:
message.content = enforce_stop_tokens(message.content, stop)
if isinstance(message.content, str):
message.content = enforce_stop_tokens(message.content, stop)
else:
raise ValueError("Cannot use when output is not a string.")
return message
except yaml.YAMLError:
raise ValueError("Invalid YAML string entered.")
Expand Down
7 changes: 4 additions & 3 deletions libs/langchain/langchain/chat_models/minimax.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Wrapper around Minimax chat models."""
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast

from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
Expand All @@ -27,10 +27,11 @@ def _parse_chat_history(history: List[BaseMessage]) -> List:
"""Parse a sequence of messages into history."""
chat_history = []
for message in history:
content = cast(str, message.content)
if isinstance(message, HumanMessage):
chat_history.append(_parse_message("USER", message.content))
chat_history.append(_parse_message("USER", content))
if isinstance(message, AIMessage):
chat_history.append(_parse_message("BOT", message.content))
chat_history.append(_parse_message("BOT", content))
return chat_history


Expand Down
6 changes: 4 additions & 2 deletions libs/langchain/langchain/chat_models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,14 +316,16 @@ def validate_environment(cls, values: Dict) -> Dict:
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
params = {
"model": self.model_name,
"max_tokens": self.max_tokens,
"stream": self.streaming,
"n": self.n,
"temperature": self.temperature,
**self.model_kwargs,
}
if "vision" not in self.model_name:
params["max_tokens"] = self.max_tokens
return params

def completion_with_retry(
self, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any
Expand Down
Loading
Loading