Added history and support for system_message as a param #14824

Merged: 5 commits, Dec 19, 2023
26 changes: 26 additions & 0 deletions docs/docs/integrations/chat/google_generative_ai.ipynb
@@ -135,6 +135,32 @@
"print(result.content)"
]
},
{
"cell_type": "markdown",
"id": "9e55d043-bb2f-44e3-9134-c39a1abe3a9e",
"metadata": {},
"source": [
"Gemini doesn't support `SystemMessage` at the moment, but it can be added to the first human message in the row. If you want such behavior, just set the `convert_system_message_to_human` to True:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a64b523-9710-4d15-9944-1e3cc567a52b",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema.messages import HumanMessage, SystemMessage\n",
"\n",
"model = ChatGoogleGenerativeAI(model=\"gemini-pro\", convert_system_message_to_human=True)\n",
"model(\n",
" [\n",
" SystemMessage(content=\"Answer only yes or no.\"),\n",
" HumanMessage(content=\"Is apple a fruit?\"),\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"id": "40773fac-b24d-476d-91c8-2da8fed99b53",
141 changes: 79 additions & 62 deletions libs/partners/google-genai/langchain_google_genai/chat_models.py
@@ -37,6 +37,7 @@
    ChatMessageChunk,
    HumanMessage,
    HumanMessageChunk,
+    SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
@@ -106,7 +107,7 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
    )


-def _chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
+def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
"""
Executes a chat generation method with retry logic using tenacity.

@@ -139,7 +140,7 @@ def _chat_with_retry(**kwargs: Any) -> Any:
    return _chat_with_retry(**kwargs)
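
For intuition, the retry wrapper follows the standard tenacity decorator pattern. Below is a minimal standalone sketch; the retry settings are assumed for illustration rather than taken from `_create_retry_decorator`:

```python
from typing import Any, Callable

from tenacity import retry, stop_after_attempt, wait_exponential


# Hypothetical settings; the real decorator is built by _create_retry_decorator().
@retry(stop=stop_after_attempt(6), wait=wait_exponential(multiplier=2, min=1, max=60))
def call_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
    # Re-invokes the SDK call on failure, backing off exponentially between attempts.
    return generation_method(**kwargs)
```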


-async def _achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any:
+async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
Collaborator: Kinda nice to keep the `*` to require keyword args, but since it's a private method it's not that big a deal.

"""
Executes a chat generation method with retry logic using tenacity.

@@ -172,26 +173,6 @@ async def _achat_with_retry(**kwargs: Any) -> Any:
    return await _achat_with_retry(**kwargs)
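
As the reviewer notes above, a bare `*` in a signature makes every following parameter keyword-only. A quick illustrative sketch (not code from this diff):

```python
# Illustration only: the bare * forces generation_method to be passed by keyword.
def chat_keyword_only(*, generation_method, **kwargs):
    return generation_method(**kwargs)


chat_keyword_only(generation_method=print, sep=" ")  # OK
# chat_keyword_only(print)  # TypeError: takes 0 positional arguments
```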


-def _get_role(message: BaseMessage) -> str:
-    if isinstance(message, ChatMessage):
-        if message.role not in ("user", "model"):
-            raise ChatGoogleGenerativeAIError(
-                "Gemini only supports user and model roles when"
-                " providing it with Chat messages."
-            )
-        return message.role
-    elif isinstance(message, HumanMessage):
-        return "user"
-    elif isinstance(message, AIMessage):
-        return "model"
-    else:
-        # TODO: Gemini doesn't seem to have a concept of system messages yet.
-        raise ChatGoogleGenerativeAIError(
-            f"Message of '{message.type}' type not supported by Gemini."
-            " Please only provide it with Human or AI (user/assistant) messages."
-        )

def _is_openai_parts_format(part: dict) -> bool:
return "type" in part

@@ -266,13 +247,14 @@ def _url_to_pil(image_source: str) -> Image:


def _convert_to_parts(
-    content: Sequence[Union[str, dict]],
+    raw_content: Union[str, Sequence[Union[str, dict]]],
) -> List[genai.types.PartType]:
    """Converts a list of LangChain messages into a google parts."""
    parts = []
+    content = [raw_content] if isinstance(raw_content, str) else raw_content
    for part in content:
        if isinstance(part, str):
-            parts.append(genai.types.PartDict(text=part, inline_data=None))
+            parts.append(genai.types.PartDict(text=part))
        elif isinstance(part, Mapping):
            # OpenAI Format
            if _is_openai_parts_format(part):
@@ -304,27 +286,49 @@ def _convert_to_parts(
    return parts
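
With this change `_convert_to_parts` also accepts a bare string. A small behavioral sketch, assuming `genai.types.PartDict` renders as a plain dict (the unit tests below suggest it does):

```python
# Illustrative only; mirrors the string branch of _convert_to_parts.
_convert_to_parts("hello")
# -> [{"text": "hello"}]

_convert_to_parts(["hello", "world"])
# -> [{"text": "hello"}, {"text": "world"}]
```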


-def _messages_to_genai_contents(
-    input_messages: Sequence[BaseMessage],
+def _parse_chat_history(
+    input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
) -> List[genai.types.ContentDict]:
    """Converts a list of messages into a Gemini API google content dicts."""

    messages: List[genai.types.MessageDict] = []

+    raw_system_message: Optional[SystemMessage] = None
    for i, message in enumerate(input_messages):
-        role = _get_role(message)
-        if isinstance(message.content, str):
-            parts = [message.content]
+        if (
+            i == 0
+            and isinstance(message, SystemMessage)
+            and not convert_system_message_to_human
+        ):
+            raise ValueError(
+                """SystemMessages are not yet supported!
+
+To automatically convert the leading SystemMessage to a HumanMessage,
+set `convert_system_message_to_human` to True. Example:
+
+llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
+"""
+            )
+        elif i == 0 and isinstance(message, SystemMessage):
+            raw_system_message = message
+            continue
+        elif isinstance(message, AIMessage):
+            role = "model"
+        elif isinstance(message, HumanMessage):
+            role = "user"
        else:
-            parts = _convert_to_parts(message.content)
-        messages.append({"role": role, "parts": parts})
-        if i > 0:
-            # Cannot have multiple messages from the same role in a row.
-            if role == messages[-2]["role"]:
-                raise ChatGoogleGenerativeAIError(
-                    "Cannot have multiple messages from the same role in a row."
-                    " Consider merging them into a single message with multiple"
-                    f" parts.\nReceived: {messages}"
-                )
+            raise ValueError(
+                f"Unexpected message with type {type(message)} at the position {i}."
+            )
+
+        parts = _convert_to_parts(message.content)
+        if raw_system_message:
+            if role == "model":
+                raise ValueError(
+                    "SystemMessage should be followed by a HumanMessage and "
+                    "not by AIMessage."
+                )
+            parts = _convert_to_parts(raw_system_message.content) + parts
+            raw_system_message = None
+        messages.append({"role": role, "parts": parts})
    return messages
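
To make the merging concrete, here is roughly what `_parse_chat_history` returns for a conversation with a leading system message; the expected shapes are taken from the unit test at the end of this diff:

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

history = _parse_chat_history(
    [
        SystemMessage(content="You're supposed to answer math questions."),
        HumanMessage(content="How much is 2+2?"),
        AIMessage(content="4"),
    ],
    convert_system_message_to_human=True,
)
# The system text becomes an extra part on the first user turn:
# [{"role": "user", "parts": [{"text": "You're supposed to answer math questions."},
#                             {"text": "How much is 2+2?"}]},
#  {"role": "model", "parts": [{"text": "4"}]}]
```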


@@ -457,8 +461,11 @@ class ChatGoogleGenerativeAI(BaseChatModel):
    n: int = Field(default=1, alias="candidate_count")
    """Number of chat completions to generate for each prompt. Note that the API may
    not return the full n completions if duplicates are generated."""

-    _generative_model: Any  #: :meta private:
+    convert_system_message_to_human: bool = False
+    """Whether to merge any leading SystemMessage into the following HumanMessage.
+
+    Gemini does not support system messages; any unsupported messages will
+    raise an error."""

    class Config:
        allow_population_by_field_name = True
@@ -499,7 +506,7 @@ def validate_environment(cls, values: Dict) -> Dict:
        if values.get("top_k") is not None and values["top_k"] <= 0:
            raise ValueError("top_k must be positive")
        model = values["model"]
values["_generative_model"] = genai.GenerativeModel(model_name=model)
values["client"] = genai.GenerativeModel(model_name=model)
Collaborator: Out of curiosity, why change the name?

Collaborator (author): I just wanted it to be unified with the other integrations (it makes it a little easier to debug). If you have any reason or a strong opinion, we can revert this particular change :)

        return values

    @property
@@ -512,18 +519,9 @@ def _identifying_params(self) -> Dict[str, Any]:
"n": self.n,
}

-    @property
-    def _generation_method(self) -> Callable:
-        return self._generative_model.generate_content
-
-    @property
-    def _async_generation_method(self) -> Callable:
-        return self._generative_model.generate_content_async
-
    def _prepare_params(
-        self, messages: Sequence[BaseMessage], stop: Optional[List[str]], **kwargs: Any
+        self, stop: Optional[List[str]], **kwargs: Any
    ) -> Dict[str, Any]:
-        contents = _messages_to_genai_contents(messages)
        gen_config = {
            k: v
            for k, v in {
@@ -538,7 +536,7 @@ def _prepare_params(
        }
        if "generation_config" in kwargs:
            gen_config = {**gen_config, **kwargs.pop("generation_config")}
params = {"generation_config": gen_config, "contents": contents, **kwargs}
params = {"generation_config": gen_config, **kwargs}
        return params

def _generate(
@@ -548,10 +546,11 @@ def _generate(
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
        response: genai.types.GenerateContentResponse = _chat_with_retry(
+            content=message,
            **params,
-            generation_method=self._generation_method,
+            generation_method=chat.send_message,
        )
        return _response_to_result(response)

@@ -562,10 +561,11 @@ async def _agenerate(
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
        response: genai.types.GenerateContentResponse = await _achat_with_retry(
+            content=message,
            **params,
-            generation_method=self._async_generation_method,
+            generation_method=chat.send_message_async,
        )
        return _response_to_result(response)

@@ -576,10 +576,11 @@ def _stream(
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
        response: genai.types.GenerateContentResponse = _chat_with_retry(
+            content=message,
            **params,
-            generation_method=self._generation_method,
+            generation_method=chat.send_message,
            stream=True,
        )
        for chunk in response:
@@ -602,10 +603,11 @@ async def _astream(
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
-        params = self._prepare_params(messages, stop, **kwargs)
+        params, chat, message = self._prepare_chat(messages, stop=stop)
        async for chunk in await _achat_with_retry(
+            content=message,
            **params,
-            generation_method=self._async_generation_method,
+            generation_method=chat.send_message_async,
            stream=True,
        ):
            _chat_result = _response_to_result(
@@ -619,3 +621,18 @@ async def _astream(
            yield gen
            if run_manager:
                await run_manager.on_llm_new_token(gen.text)

    def _prepare_chat(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
        params = self._prepare_params(stop, **kwargs)
        history = _parse_chat_history(
            messages,
            convert_system_message_to_human=self.convert_system_message_to_human,
        )
        message = history.pop()
        chat = self.client.start_chat(history=history)
Collaborator: This will throw an error if candidate_count > 1, right? Is there any real benefit to using this, given that the chat history is managed externally anyway?

Collaborator (author): candidate_count > 1 is unsupported at the moment anyway; generate_content would throw the same error.

        return params, chat, message
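
`_prepare_chat` pops the final message off the parsed history and sends it through a `ChatSession` seeded with the earlier turns. In raw `google.generativeai` terms the flow looks roughly like this (a sketch; the model name and message contents are assumed):

```python
import google.generativeai as genai

# Sketch of the equivalent SDK flow; not code from this diff.
model = genai.GenerativeModel(model_name="gemini-pro")
chat = model.start_chat(
    history=[
        {"role": "user", "parts": [{"text": "How much is 2+2?"}]},
        {"role": "model", "parts": [{"text": "4"}]},
    ]
)
# Only the last user turn is sent; the session already carries the history.
response = chat.send_message({"role": "user", "parts": [{"text": "How much is 3+3?"}]})
print(response.text)
```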
@@ -1,6 +1,6 @@
"""Test ChatGoogleGenerativeAI chat model."""
import pytest
-from langchain_core.messages import AIMessage, HumanMessage
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_google_genai.chat_models import (
    ChatGoogleGenerativeAI,
@@ -147,3 +147,40 @@ def test_chat_google_genai_invoke_multimodal_invalid_model() -> None:
    llm = ChatGoogleGenerativeAI(model=_MODEL)
    with pytest.raises(ChatGoogleGenerativeAIError):
        llm.invoke(messages)


def test_chat_google_genai_single_call_with_history() -> None:
    model = ChatGoogleGenerativeAI(model=_MODEL)
    text_question1, text_answer1 = "How much is 2+2?", "4"
    text_question2 = "How much is 3+3?"
    message1 = HumanMessage(content=text_question1)
    message2 = AIMessage(content=text_answer1)
    message3 = HumanMessage(content=text_question2)
    response = model([message1, message2, message3])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)


def test_chat_google_genai_system_message_error() -> None:
    model = ChatGoogleGenerativeAI(model=_MODEL)
    text_question1, text_answer1 = "How much is 2+2?", "4"
    text_question2 = "How much is 3+3?"
    system_message = SystemMessage(content="You're supposed to answer math questions.")
    message1 = HumanMessage(content=text_question1)
    message2 = AIMessage(content=text_answer1)
    message3 = HumanMessage(content=text_question2)
    with pytest.raises(ValueError):
        model([system_message, message1, message2, message3])


def test_chat_google_genai_system_message() -> None:
    model = ChatGoogleGenerativeAI(model=_MODEL, convert_system_message_to_human=True)
    text_question1, text_answer1 = "How much is 2+2?", "4"
    text_question2 = "How much is 3+3?"
    system_message = SystemMessage(content="You're supposed to answer math questions.")
    message1 = HumanMessage(content=text_question1)
    message2 = AIMessage(content=text_answer1)
    message3 = HumanMessage(content=text_question2)
    response = model([system_message, message1, message2, message3])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
24 changes: 23 additions & 1 deletion libs/partners/google-genai/tests/unit_tests/test_chat_models.py
@@ -1,8 +1,12 @@
"""Test chat model integration."""
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.pydantic_v1 import SecretStr
from pytest import CaptureFixture

-from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
+from langchain_google_genai.chat_models import (
+    ChatGoogleGenerativeAI,
+    _parse_chat_history,
+)


def test_integration_initialization() -> None:
@@ -36,3 +40,21 @@ def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> None:
    captured = capsys.readouterr()

    assert captured.out == "**********"


def test_parse_history() -> None:
    system_input = "You're supposed to answer math questions."
    text_question1, text_answer1 = "How much is 2+2?", "4"
    text_question2 = "How much is 3+3?"
    system_message = SystemMessage(content=system_input)
    message1 = HumanMessage(content=text_question1)
    message2 = AIMessage(content=text_answer1)
    message3 = HumanMessage(content=text_question2)
    messages = [system_message, message1, message2, message3]
    history = _parse_chat_history(messages, convert_system_message_to_human=True)
    assert len(history) == 3
    assert history[0] == {
        "role": "user",
        "parts": [{"text": system_input}, {"text": text_question1}],
    }
    assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]}