added history and support for system_message as param #14824
Changes from all commits
@@ -37,6 +37,7 @@
     ChatMessageChunk,
     HumanMessage,
     HumanMessageChunk,
+    SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
@@ -106,7 +107,7 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
     )


-def _chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
+def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
    """
    Executes a chat generation method with retry logic using tenacity.
@@ -139,7 +140,7 @@ def _chat_with_retry(**kwargs: Any) -> Any:
     return _chat_with_retry(**kwargs)


-async def _achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
+async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
    """
    Executes a chat generation method with retry logic using tenacity.
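For readers unfamiliar with the wrapper being changed here, a minimal sketch of the pattern (simplified, not the exact decorator from this file): tenacity re-invokes the generation call on failure, and this PR drops the leading `*` so `generation_method` can also be passed positionally.

```python
from typing import Any, Callable

from tenacity import retry, stop_after_attempt, wait_exponential


def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
    """Call generation_method(**kwargs), retrying transient failures."""

    # Illustrative retry policy only; the real decorator is built in
    # _create_retry_decorator and targets Google API error types.
    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, max=10))
    def _call(**inner_kwargs: Any) -> Any:
        return generation_method(**inner_kwargs)

    return _call(**kwargs)
```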
@@ -172,26 +173,6 @@ async def _achat_with_retry(**kwargs: Any) -> Any:
     return await _achat_with_retry(**kwargs)


-def _get_role(message: BaseMessage) -> str:
-    if isinstance(message, ChatMessage):
-        if message.role not in ("user", "model"):
-            raise ChatGoogleGenerativeAIError(
-                "Gemini only supports user and model roles when"
-                " providing it with Chat messages."
-            )
-        return message.role
-    elif isinstance(message, HumanMessage):
-        return "user"
-    elif isinstance(message, AIMessage):
-        return "model"
-    else:
-        # TODO: Gemini doesn't seem to have a concept of system messages yet.
-        raise ChatGoogleGenerativeAIError(
-            f"Message of '{message.type}' type not supported by Gemini."
-            " Please only provide it with Human or AI (user/assistant) messages."
-        )
-
-
 def _is_openai_parts_format(part: dict) -> bool:
     return "type" in part
@@ -266,13 +247,14 @@ def _url_to_pil(image_source: str) -> Image:


 def _convert_to_parts(
-    content: Sequence[Union[str, dict]],
+    raw_content: Union[str, Sequence[Union[str, dict]]],
 ) -> List[genai.types.PartType]:
     """Converts a list of LangChain messages into a google parts."""
     parts = []
+    content = [raw_content] if isinstance(raw_content, str) else raw_content
     for part in content:
         if isinstance(part, str):
-            parts.append(genai.types.PartDict(text=part, inline_data=None))
+            parts.append(genai.types.PartDict(text=part))
         elif isinstance(part, Mapping):
             # OpenAI Format
             if _is_openai_parts_format(part):
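A quick illustration (inputs invented for this example) of the two shapes `_convert_to_parts` accepts after this change: a bare string, or a sequence mixing plain strings with OpenAI-style content parts.

```python
# A bare string is wrapped into a single text part.
_convert_to_parts("hello")
# -> roughly [{"text": "hello"}]

# A sequence of OpenAI-style parts: each dict with a "type" key goes through the
# OpenAI-format branch (text or image_url).
_convert_to_parts(
    [
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ]
)
```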
@@ -304,27 +286,49 @@ def _convert_to_parts(
     return parts


-def _messages_to_genai_contents(
-    input_messages: Sequence[BaseMessage],
+def _parse_chat_history(
+    input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
 ) -> List[genai.types.ContentDict]:
     """Converts a list of messages into a Gemini API google content dicts."""
     messages: List[genai.types.MessageDict] = []
+    raw_system_message: Optional[SystemMessage] = None
     for i, message in enumerate(input_messages):
-        role = _get_role(message)
-        if isinstance(message.content, str):
-            parts = [message.content]
+        if (
+            i == 0
+            and isinstance(message, SystemMessage)
+            and not convert_system_message_to_human
+        ):
+            raise ValueError(
+                """SystemMessages are not yet supported!
+
+To automatically convert the leading SystemMessage to a HumanMessage,
+set `convert_system_message_to_human` to True. Example:
+
+llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
+"""
+            )
+        elif i == 0 and isinstance(message, SystemMessage):
+            raw_system_message = message
+            continue
+        elif isinstance(message, AIMessage):
+            role = "model"
+        elif isinstance(message, HumanMessage):
+            role = "user"
         else:
-            parts = _convert_to_parts(message.content)
-        messages.append({"role": role, "parts": parts})
-        if i > 0:
-            # Cannot have multiple messages from the same role in a row.
-            if role == messages[-2]["role"]:
-                raise ChatGoogleGenerativeAIError(
-                    "Cannot have multiple messages from the same role in a row."
-                    " Consider merging them into a single message with multiple"
-                    f" parts.\nReceived: {messages}"
-                )
+            raise ValueError(
+                f"Unexpected message with type {type(message)} at the position {i}."
+            )
+
+        parts = _convert_to_parts(message.content)
+        if raw_system_message:
+            if role == "model":
+                raise ValueError(
+                    "SystemMessage should be followed by a HumanMessage and "
+                    "not by AIMessage."
+                )
+            parts = _convert_to_parts(raw_system_message.content) + parts
+            raw_system_message = None
+        messages.append({"role": role, "parts": parts})
     return messages
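To make the new behavior concrete, a rough illustration (messages and output invented, not taken from the PR's tests) of what `_parse_chat_history` produces when the conversion flag is on: the leading SystemMessage is folded into the parts of the first HumanMessage.

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

history = _parse_chat_history(
    [
        SystemMessage(content="You are a terse assistant."),
        HumanMessage(content="Hi, who are you?"),
        AIMessage(content="A terse assistant."),
    ],
    convert_system_message_to_human=True,
)
# history would look roughly like:
# [
#     {"role": "user", "parts": [{"text": "You are a terse assistant."},
#                                {"text": "Hi, who are you?"}]},
#     {"role": "model", "parts": [{"text": "A terse assistant."}]},
# ]
```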
@@ -457,8 +461,11 @@ class ChatGoogleGenerativeAI(BaseChatModel):
     n: int = Field(default=1, alias="candidate_count")
     """Number of chat completions to generate for each prompt. Note that the API may
     not return the full n completions if duplicates are generated."""

-    _generative_model: Any #: :meta private:
+    convert_system_message_to_human: bool = False
+    """Whether to merge any leading SystemMessage into the following HumanMessage.
+
+    Gemini does not support system messages; any unsupported messages will
+    raise an error."""

     class Config:
         allow_population_by_field_name = True
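A short usage sketch of the new flag (import path and model name assumed, mirroring the example embedded in the error message above):

```python
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI

# The leading SystemMessage is merged into the HumanMessage that follows it,
# since Gemini has no native system role.
llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
llm.invoke(
    [
        SystemMessage(content="Answer in one short sentence."),
        HumanMessage(content="What is Gemini?"),
    ]
)
```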
@@ -499,7 +506,7 @@ def validate_environment(cls, values: Dict) -> Dict:
         if values.get("top_k") is not None and values["top_k"] <= 0:
             raise ValueError("top_k must be positive")
         model = values["model"]
-        values["_generative_model"] = genai.GenerativeModel(model_name=model)
+        values["client"] = genai.GenerativeModel(model_name=model)

Review comment: Out of curiosity, why change the name?
Reply: I just wanted it to be unified with other integrations (it makes it a little bit easier to debug).

         return values

     @property
@@ -512,18 +519,9 @@ def _identifying_params(self) -> Dict[str, Any]:
             "n": self.n,
         }

-    @property
-    def _generation_method(self) -> Callable:
-        return self._generative_model.generate_content
-
-    @property
-    def _async_generation_method(self) -> Callable:
-        return self._generative_model.generate_content_async
-
     def _prepare_params(
-        self, messages: Sequence[BaseMessage], stop: Optional[List[str]], **kwargs: Any
+        self, stop: Optional[List[str]], **kwargs: Any
     ) -> Dict[str, Any]:
-        contents = _messages_to_genai_contents(messages)
         gen_config = {
             k: v
             for k, v in {
@@ -538,7 +536,7 @@ def _prepare_params(
         }
         if "generation_config" in kwargs:
             gen_config = {**gen_config, **kwargs.pop("generation_config")}
-        params = {"generation_config": gen_config, "contents": contents, **kwargs}
+        params = {"generation_config": gen_config, **kwargs}
         return params

     def _generate(
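To make the change above concrete, a rough sketch (field names and values assumed, not taken from this diff) of what `_prepare_params` might now return for a model configured with a temperature and one stop sequence. Note there is no longer a `contents` key; the latest message is delivered through the chat session instead.

```python
# Illustrative output only; the exact keys depend on which model fields are set.
{
    "generation_config": {
        "candidate_count": 1,
        "temperature": 0.7,
        "stop_sequences": ["\n\n"],
    },
}
```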
@@ -548,10 +546,11 @@ def _generate(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
         response: genai.types.GenerateContentResponse = _chat_with_retry(
+            content=message,
             **params,
-            generation_method=self._generation_method,
+            generation_method=chat.send_message,
         )
         return _response_to_result(response)
@@ -562,10 +561,11 @@ async def _agenerate(
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
         response: genai.types.GenerateContentResponse = await _achat_with_retry(
+            content=message,
             **params,
-            generation_method=self._async_generation_method,
+            generation_method=chat.send_message_async,
         )
         return _response_to_result(response)
@@ -576,10 +576,11 @@ def _stream(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
         response: genai.types.GenerateContentResponse = _chat_with_retry(
+            content=message,
             **params,
-            generation_method=self._generation_method,
+            generation_method=chat.send_message,
             stream=True,
         )
         for chunk in response:
@@ -602,10 +603,11 @@ async def _astream(
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
         async for chunk in await _achat_with_retry(
+            content=message,
             **params,
-            generation_method=self._async_generation_method,
+            generation_method=chat.send_message_async,
             stream=True,
         ):
             _chat_result = _response_to_result(
@@ -619,3 +621,18 @@ async def _astream(
             yield gen
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)
+
+    def _prepare_chat(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
+        params = self._prepare_params(stop, **kwargs)
+        history = _parse_chat_history(
+            messages,
+            convert_system_message_to_human=self.convert_system_message_to_human,
+        )
+        message = history.pop()
+        chat = self.client.start_chat(history=history)
+
+        return params, chat, message

Review comment on `chat = self.client.start_chat(history=history)`: This will throw an error if
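For orientation, a hedged sketch (messages and model name invented) of the flow this helper sets up: everything before the final message becomes the ChatSession history, and the final message is the one actually sent.

```python
from langchain_core.messages import AIMessage, HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro")

# _prepare_chat is private; it is called here only to illustrate the flow.
params, chat, message = llm._prepare_chat(
    [
        HumanMessage(content="Hello"),
        AIMessage(content="Hi! How can I help?"),
        HumanMessage(content="Tell me a joke."),
    ],
    stop=None,
)
# chat now holds the first two messages as ChatSession history; message is the
# last HumanMessage, which _generate sends via chat.send_message(message, **params).
response = chat.send_message(message, **params)
```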
Review comment: Kinda nice to keep the * to require keyword args, but since it's a private method it's not that big a deal.