Skip to content

Commit

Permalink
added history and support for system_message as param (#14824)
Browse files Browse the repository at this point in the history
- **Description:** added support for chat_history for Google
GenerativeAI (to actually use the `chat` API) plus since Gemini
currently doesn't have a support for SystemMessage, added support for it
only if a user provides additional `convert_system_message_to_human`
flag during model initialization (in this case, SystemMessage would be
prepanded to the first HumanMessage)
  - **Issue:** #14710 
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
  - **Twitter handle:** lkuligin

---------

Co-authored-by: William FH <[email protected]>
  • Loading branch information
lkuligin and hinthornw authored Dec 19, 2023
1 parent 2861766 commit 2d0f1ca
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 64 deletions.
26 changes: 26 additions & 0 deletions docs/docs/integrations/chat/google_generative_ai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,32 @@
"print(result.content)"
]
},
{
"cell_type": "markdown",
"id": "9e55d043-bb2f-44e3-9134-c39a1abe3a9e",
"metadata": {},
"source": [
"Gemini doesn't support `SystemMessage` at the moment, but it can be added to the first human message in the row. If you want such behavior, just set the `convert_system_message_to_human` to True:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a64b523-9710-4d15-9944-1e3cc567a52b",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema.messages import HumanMessage, SystemMessage\n",
"\n",
"model = ChatGoogleGenerativeAI(model=\"gemini-pro\", convert_system_message_to_human=True)\n",
"model(\n",
" [\n",
" SystemMessage(content=\"Answer only yes or no.\"),\n",
" HumanMessage(content=\"Is apple a fruit?\"),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"id": "40773fac-b24d-476d-91c8-2da8fed99b53",
Expand Down
141 changes: 79 additions & 62 deletions libs/partners/google-genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
ChatMessageChunk,
HumanMessage,
HumanMessageChunk,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
Expand Down Expand Up @@ -106,7 +107,7 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
)


def _chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
"""
Executes a chat generation method with retry logic using tenacity.
Expand Down Expand Up @@ -139,7 +140,7 @@ def _chat_with_retry(**kwargs: Any) -> Any:
return _chat_with_retry(**kwargs)


async def _achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
"""
Executes a chat generation method with retry logic using tenacity.
Expand Down Expand Up @@ -172,26 +173,6 @@ async def _achat_with_retry(**kwargs: Any) -> Any:
return await _achat_with_retry(**kwargs)


def _get_role(message: BaseMessage) -> str:
if isinstance(message, ChatMessage):
if message.role not in ("user", "model"):
raise ChatGoogleGenerativeAIError(
"Gemini only supports user and model roles when"
" providing it with Chat messages."
)
return message.role
elif isinstance(message, HumanMessage):
return "user"
elif isinstance(message, AIMessage):
return "model"
else:
# TODO: Gemini doesn't seem to have a concept of system messages yet.
raise ChatGoogleGenerativeAIError(
f"Message of '{message.type}' type not supported by Gemini."
" Please only provide it with Human or AI (user/assistant) messages."
)


def _is_openai_parts_format(part: dict) -> bool:
return "type" in part

Expand Down Expand Up @@ -266,13 +247,14 @@ def _url_to_pil(image_source: str) -> Image:


def _convert_to_parts(
content: Sequence[Union[str, dict]],
raw_content: Union[str, Sequence[Union[str, dict]]],
) -> List[genai.types.PartType]:
"""Converts a list of LangChain messages into a google parts."""
parts = []
content = [raw_content] if isinstance(raw_content, str) else raw_content
for part in content:
if isinstance(part, str):
parts.append(genai.types.PartDict(text=part, inline_data=None))
parts.append(genai.types.PartDict(text=part))
elif isinstance(part, Mapping):
# OpenAI Format
if _is_openai_parts_format(part):
Expand Down Expand Up @@ -304,27 +286,49 @@ def _convert_to_parts(
return parts


def _messages_to_genai_contents(
input_messages: Sequence[BaseMessage],
def _parse_chat_history(
input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
) -> List[genai.types.ContentDict]:
"""Converts a list of messages into a Gemini API google content dicts."""

messages: List[genai.types.MessageDict] = []

raw_system_message: Optional[SystemMessage] = None
for i, message in enumerate(input_messages):
role = _get_role(message)
if isinstance(message.content, str):
parts = [message.content]
if (
i == 0
and isinstance(message, SystemMessage)
and not convert_system_message_to_human
):
raise ValueError(
"""SystemMessages are not yet supported!
To automatically convert the leading SystemMessage to a HumanMessage,
set `convert_system_message_to_human` to True. Example:
llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
"""
)
elif i == 0 and isinstance(message, SystemMessage):
raw_system_message = message
continue
elif isinstance(message, AIMessage):
role = "model"
elif isinstance(message, HumanMessage):
role = "user"
else:
parts = _convert_to_parts(message.content)
messages.append({"role": role, "parts": parts})
if i > 0:
# Cannot have multiple messages from the same role in a row.
if role == messages[-2]["role"]:
raise ChatGoogleGenerativeAIError(
"Cannot have multiple messages from the same role in a row."
" Consider merging them into a single message with multiple"
f" parts.\nReceived: {messages}"
raise ValueError(
f"Unexpected message with type {type(message)} at the position {i}."
)

parts = _convert_to_parts(message.content)
if raw_system_message:
if role == "model":
raise ValueError(
"SystemMessage should be followed by a HumanMessage and "
"not by AIMessage."
)
parts = _convert_to_parts(raw_system_message.content) + parts
raw_system_message = None
messages.append({"role": role, "parts": parts})
return messages


Expand Down Expand Up @@ -457,8 +461,11 @@ class ChatGoogleGenerativeAI(BaseChatModel):
n: int = Field(default=1, alias="candidate_count")
"""Number of chat completions to generate for each prompt. Note that the API may
not return the full n completions if duplicates are generated."""

_generative_model: Any #: :meta private:
convert_system_message_to_human: bool = False
"""Whether to merge any leading SystemMessage into the following HumanMessage.
Gemini does not support system messages; any unsupported messages will
raise an error."""

class Config:
allow_population_by_field_name = True
Expand Down Expand Up @@ -499,7 +506,7 @@ def validate_environment(cls, values: Dict) -> Dict:
if values.get("top_k") is not None and values["top_k"] <= 0:
raise ValueError("top_k must be positive")
model = values["model"]
values["_generative_model"] = genai.GenerativeModel(model_name=model)
values["client"] = genai.GenerativeModel(model_name=model)
return values

@property
Expand All @@ -512,18 +519,9 @@ def _identifying_params(self) -> Dict[str, Any]:
"n": self.n,
}

@property
def _generation_method(self) -> Callable:
return self._generative_model.generate_content

@property
def _async_generation_method(self) -> Callable:
return self._generative_model.generate_content_async

def _prepare_params(
self, messages: Sequence[BaseMessage], stop: Optional[List[str]], **kwargs: Any
self, stop: Optional[List[str]], **kwargs: Any
) -> Dict[str, Any]:
contents = _messages_to_genai_contents(messages)
gen_config = {
k: v
for k, v in {
Expand All @@ -538,7 +536,7 @@ def _prepare_params(
}
if "generation_config" in kwargs:
gen_config = {**gen_config, **kwargs.pop("generation_config")}
params = {"generation_config": gen_config, "contents": contents, **kwargs}
params = {"generation_config": gen_config, **kwargs}
return params

def _generate(
Expand All @@ -548,10 +546,11 @@ def _generate(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = self._prepare_params(messages, stop, **kwargs)
params, chat, message = self._prepare_chat(messages, stop=stop)
response: genai.types.GenerateContentResponse = _chat_with_retry(
content=message,
**params,
generation_method=self._generation_method,
generation_method=chat.send_message,
)
return _response_to_result(response)

Expand All @@ -562,10 +561,11 @@ async def _agenerate(
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
params = self._prepare_params(messages, stop, **kwargs)
params, chat, message = self._prepare_chat(messages, stop=stop)
response: genai.types.GenerateContentResponse = await _achat_with_retry(
content=message,
**params,
generation_method=self._async_generation_method,
generation_method=chat.send_message_async,
)
return _response_to_result(response)

Expand All @@ -576,10 +576,11 @@ def _stream(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._prepare_params(messages, stop, **kwargs)
params, chat, message = self._prepare_chat(messages, stop=stop)
response: genai.types.GenerateContentResponse = _chat_with_retry(
content=message,
**params,
generation_method=self._generation_method,
generation_method=chat.send_message,
stream=True,
)
for chunk in response:
Expand All @@ -602,10 +603,11 @@ async def _astream(
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params = self._prepare_params(messages, stop, **kwargs)
params, chat, message = self._prepare_chat(messages, stop=stop)
async for chunk in await _achat_with_retry(
content=message,
**params,
generation_method=self._async_generation_method,
generation_method=chat.send_message_async,
stream=True,
):
_chat_result = _response_to_result(
Expand All @@ -619,3 +621,18 @@ async def _astream(
yield gen
if run_manager:
await run_manager.on_llm_new_token(gen.text)

def _prepare_chat(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
params = self._prepare_params(stop, **kwargs)
history = _parse_chat_history(
messages,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history.pop()
chat = self.client.start_chat(history=history)
return params, chat, message
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test ChatGoogleGenerativeAI chat model."""
import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_google_genai.chat_models import (
ChatGoogleGenerativeAI,
Expand Down Expand Up @@ -147,3 +147,40 @@ def test_chat_google_genai_invoke_multimodal_invalid_model() -> None:
llm = ChatGoogleGenerativeAI(model=_MODEL)
with pytest.raises(ChatGoogleGenerativeAIError):
llm.invoke(messages)


def test_chat_google_genai_single_call_with_history() -> None:
model = ChatGoogleGenerativeAI(model=_MODEL)
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
response = model([message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)


def test_chat_google_genai_system_message_error() -> None:
model = ChatGoogleGenerativeAI(model=_MODEL)
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content="You're supposed to answer math questions.")
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
with pytest.raises(ValueError):
model([system_message, message1, message2, message3])


def test_chat_google_genai_system_message() -> None:
model = ChatGoogleGenerativeAI(model=_MODEL, convert_system_message_to_human=True)
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content="You're supposed to answer math questions.")
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
response = model([system_message, message1, message2, message3])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
24 changes: 23 additions & 1 deletion libs/partners/google-genai/tests/unit_tests/test_chat_models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""Test chat model integration."""
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture

from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
from langchain_google_genai.chat_models import (
ChatGoogleGenerativeAI,
_parse_chat_history,
)


def test_integration_initialization() -> None:
Expand Down Expand Up @@ -36,3 +40,21 @@ def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> N
captured = capsys.readouterr()

assert captured.out == "**********"


def test_parse_history() -> None:
system_input = "You're supposed to answer math questions."
text_question1, text_answer1 = "How much is 2+2?", "4"
text_question2 = "How much is 3+3?"
system_message = SystemMessage(content=system_input)
message1 = HumanMessage(content=text_question1)
message2 = AIMessage(content=text_answer1)
message3 = HumanMessage(content=text_question2)
messages = [system_message, message1, message2, message3]
history = _parse_chat_history(messages, convert_system_message_to_human=True)
assert len(history) == 3
assert history[0] == {
"role": "user",
"parts": [{"text": system_input}, {"text": text_question1}],
}
assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]}

0 comments on commit 2d0f1ca

Please sign in to comment.