diff --git a/docs/docs/integrations/chat/google_generative_ai.ipynb b/docs/docs/integrations/chat/google_generative_ai.ipynb index d69a2dfbf2bcf..90cb17ce50f6d 100644 --- a/docs/docs/integrations/chat/google_generative_ai.ipynb +++ b/docs/docs/integrations/chat/google_generative_ai.ipynb @@ -135,6 +135,32 @@ "print(result.content)" ] }, + { + "cell_type": "markdown", + "id": "9e55d043-bb2f-44e3-9134-c39a1abe3a9e", + "metadata": {}, + "source": [ + "Gemini doesn't support `SystemMessage` at the moment, but its content can be prepended to the first human message in the conversation. To enable this behavior, set `convert_system_message_to_human` to `True`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a64b523-9710-4d15-9944-1e3cc567a52b", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.schema.messages import HumanMessage, SystemMessage\n", + "\n", + "model = ChatGoogleGenerativeAI(model=\"gemini-pro\", convert_system_message_to_human=True)\n", + "model(\n", + " [\n", + " SystemMessage(content=\"Answer only yes or no.\"),\n", + " HumanMessage(content=\"Is apple a fruit?\"),\n", + " ]\n", + ")" + ] + }, { "cell_type": "markdown", "id": "40773fac-b24d-476d-91c8-2da8fed99b53", diff --git a/libs/partners/google-genai/langchain_google_genai/chat_models.py b/libs/partners/google-genai/langchain_google_genai/chat_models.py index d4724f9d892a4..b1f289095fb21 100644 --- a/libs/partners/google-genai/langchain_google_genai/chat_models.py +++ b/libs/partners/google-genai/langchain_google_genai/chat_models.py @@ -37,6 +37,7 @@ ChatMessageChunk, HumanMessage, HumanMessageChunk, + SystemMessage, ) from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.pydantic_v1 import Field, SecretStr, root_validator @@ -106,7 +107,7 @@ def _create_retry_decorator() -> Callable[[Any], Any]: ) -def _chat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any: +def _chat_with_retry(generation_method: Callable, 
**kwargs: Any) -> Any: """ Executes a chat generation method with retry logic using tenacity. @@ -139,7 +140,7 @@ def _chat_with_retry(**kwargs: Any) -> Any: return _chat_with_retry(**kwargs) -async def _achat_with_retry(*, generation_method: Callable, **kwargs: Any) -> Any: +async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any: """ Executes a chat generation method with retry logic using tenacity. @@ -172,26 +173,6 @@ async def _achat_with_retry(**kwargs: Any) -> Any: return await _achat_with_retry(**kwargs) -def _get_role(message: BaseMessage) -> str: - if isinstance(message, ChatMessage): - if message.role not in ("user", "model"): - raise ChatGoogleGenerativeAIError( - "Gemini only supports user and model roles when" - " providing it with Chat messages." - ) - return message.role - elif isinstance(message, HumanMessage): - return "user" - elif isinstance(message, AIMessage): - return "model" - else: - # TODO: Gemini doesn't seem to have a concept of system messages yet. - raise ChatGoogleGenerativeAIError( - f"Message of '{message.type}' type not supported by Gemini." - " Please only provide it with Human or AI (user/assistant) messages." 
- ) - - def _is_openai_parts_format(part: dict) -> bool: return "type" in part @@ -266,13 +247,14 @@ def _url_to_pil(image_source: str) -> Image: def _convert_to_parts( - content: Sequence[Union[str, dict]], + raw_content: Union[str, Sequence[Union[str, dict]]], ) -> List[genai.types.PartType]: """Converts a list of LangChain messages into a google parts.""" parts = [] + content = [raw_content] if isinstance(raw_content, str) else raw_content for part in content: if isinstance(part, str): - parts.append(genai.types.PartDict(text=part, inline_data=None)) + parts.append(genai.types.PartDict(text=part)) elif isinstance(part, Mapping): # OpenAI Format if _is_openai_parts_format(part): @@ -304,27 +286,49 @@ def _convert_to_parts( return parts -def _messages_to_genai_contents( - input_messages: Sequence[BaseMessage], +def _parse_chat_history( + input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False ) -> List[genai.types.ContentDict]: - """Converts a list of messages into a Gemini API google content dicts.""" - messages: List[genai.types.MessageDict] = [] + + raw_system_message: Optional[SystemMessage] = None for i, message in enumerate(input_messages): - role = _get_role(message) - if isinstance(message.content, str): - parts = [message.content] + if ( + i == 0 + and isinstance(message, SystemMessage) + and not convert_system_message_to_human + ): + raise ValueError( + """SystemMessages are not yet supported! + +To automatically convert the leading SystemMessage to a HumanMessage, +set `convert_system_message_to_human` to True. 
Example: + +llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True) +""" + ) + elif i == 0 and isinstance(message, SystemMessage): + raw_system_message = message + continue + elif isinstance(message, AIMessage): + role = "model" + elif isinstance(message, HumanMessage): + role = "user" else: - parts = _convert_to_parts(message.content) - messages.append({"role": role, "parts": parts}) - if i > 0: - # Cannot have multiple messages from the same role in a row. - if role == messages[-2]["role"]: - raise ChatGoogleGenerativeAIError( - "Cannot have multiple messages from the same role in a row." - " Consider merging them into a single message with multiple" - f" parts.\nReceived: {messages}" + raise ValueError( + f"Unexpected message with type {type(message)} at the position {i}." + ) + + parts = _convert_to_parts(message.content) + if raw_system_message: + if role == "model": + raise ValueError( + "SystemMessage should be followed by a HumanMessage and " + "not by AIMessage." ) + parts = _convert_to_parts(raw_system_message.content) + parts + raw_system_message = None + messages.append({"role": role, "parts": parts}) return messages @@ -457,8 +461,11 @@ class ChatGoogleGenerativeAI(BaseChatModel): n: int = Field(default=1, alias="candidate_count") """Number of chat completions to generate for each prompt. Note that the API may not return the full n completions if duplicates are generated.""" - - _generative_model: Any #: :meta private: + convert_system_message_to_human: bool = False + """Whether to merge any leading SystemMessage into the following HumanMessage. 
+ + Gemini does not support system messages; any unsupported messages will + raise an error.""" class Config: allow_population_by_field_name = True @@ -499,7 +506,7 @@ def validate_environment(cls, values: Dict) -> Dict: if values.get("top_k") is not None and values["top_k"] <= 0: raise ValueError("top_k must be positive") model = values["model"] - values["_generative_model"] = genai.GenerativeModel(model_name=model) + values["client"] = genai.GenerativeModel(model_name=model) return values @property @@ -512,18 +519,9 @@ def _identifying_params(self) -> Dict[str, Any]: "n": self.n, } - @property - def _generation_method(self) -> Callable: - return self._generative_model.generate_content - - @property - def _async_generation_method(self) -> Callable: - return self._generative_model.generate_content_async - def _prepare_params( - self, messages: Sequence[BaseMessage], stop: Optional[List[str]], **kwargs: Any + self, stop: Optional[List[str]], **kwargs: Any ) -> Dict[str, Any]: - contents = _messages_to_genai_contents(messages) gen_config = { k: v for k, v in { @@ -538,7 +536,7 @@ def _prepare_params( } if "generation_config" in kwargs: gen_config = {**gen_config, **kwargs.pop("generation_config")} - params = {"generation_config": gen_config, "contents": contents, **kwargs} + params = {"generation_config": gen_config, **kwargs} return params def _generate( @@ -548,10 +546,11 @@ def _generate( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - params = self._prepare_params(messages, stop, **kwargs) + params, chat, message = self._prepare_chat(messages, stop=stop) response: genai.types.GenerateContentResponse = _chat_with_retry( + content=message, **params, - generation_method=self._generation_method, + generation_method=chat.send_message, ) return _response_to_result(response) @@ -562,10 +561,11 @@ async def _agenerate( run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: - params = 
self._prepare_params(messages, stop, **kwargs) + params, chat, message = self._prepare_chat(messages, stop=stop) response: genai.types.GenerateContentResponse = await _achat_with_retry( + content=message, **params, - generation_method=self._async_generation_method, + generation_method=chat.send_message_async, ) return _response_to_result(response) @@ -576,10 +576,11 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: - params = self._prepare_params(messages, stop, **kwargs) + params, chat, message = self._prepare_chat(messages, stop=stop) response: genai.types.GenerateContentResponse = _chat_with_retry( + content=message, **params, - generation_method=self._generation_method, + generation_method=chat.send_message, stream=True, ) for chunk in response: @@ -602,10 +603,11 @@ async def _astream( run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: - params = self._prepare_params(messages, stop, **kwargs) + params, chat, message = self._prepare_chat(messages, stop=stop) async for chunk in await _achat_with_retry( + content=message, **params, - generation_method=self._async_generation_method, + generation_method=chat.send_message_async, stream=True, ): _chat_result = _response_to_result( @@ -619,3 +621,18 @@ async def _astream( yield gen if run_manager: await run_manager.on_llm_new_token(gen.text) + + def _prepare_chat( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + **kwargs: Any, + ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]: + params = self._prepare_params(stop, **kwargs) + history = _parse_chat_history( + messages, + convert_system_message_to_human=self.convert_system_message_to_human, + ) + message = history.pop() + chat = self.client.start_chat(history=history) + return params, chat, message diff --git a/libs/partners/google-genai/tests/integration_tests/test_chat_models.py 
b/libs/partners/google-genai/tests/integration_tests/test_chat_models.py index 9436c5e252952..e3eef827a9fe3 100644 --- a/libs/partners/google-genai/tests/integration_tests/test_chat_models.py +++ b/libs/partners/google-genai/tests/integration_tests/test_chat_models.py @@ -1,6 +1,6 @@ """Test ChatGoogleGenerativeAI chat model.""" import pytest -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_google_genai.chat_models import ( ChatGoogleGenerativeAI, @@ -147,3 +147,40 @@ def test_chat_google_genai_invoke_multimodal_invalid_model() -> None: llm = ChatGoogleGenerativeAI(model=_MODEL) with pytest.raises(ChatGoogleGenerativeAIError): llm.invoke(messages) + + +def test_chat_google_genai_single_call_with_history() -> None: + model = ChatGoogleGenerativeAI(model=_MODEL) + text_question1, text_answer1 = "How much is 2+2?", "4" + text_question2 = "How much is 3+3?" + message1 = HumanMessage(content=text_question1) + message2 = AIMessage(content=text_answer1) + message3 = HumanMessage(content=text_question2) + response = model([message1, message2, message3]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) + + +def test_chat_google_genai_system_message_error() -> None: + model = ChatGoogleGenerativeAI(model=_MODEL) + text_question1, text_answer1 = "How much is 2+2?", "4" + text_question2 = "How much is 3+3?" 
+ system_message = SystemMessage(content="You're supposed to answer math questions.") + message1 = HumanMessage(content=text_question1) + message2 = AIMessage(content=text_answer1) + message3 = HumanMessage(content=text_question2) + with pytest.raises(ValueError): + model([system_message, message1, message2, message3]) + + +def test_chat_google_genai_system_message() -> None: + model = ChatGoogleGenerativeAI(model=_MODEL, convert_system_message_to_human=True) + text_question1, text_answer1 = "How much is 2+2?", "4" + text_question2 = "How much is 3+3?" + system_message = SystemMessage(content="You're supposed to answer math questions.") + message1 = HumanMessage(content=text_question1) + message2 = AIMessage(content=text_answer1) + message3 = HumanMessage(content=text_question2) + response = model([system_message, message1, message2, message3]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) diff --git a/libs/partners/google-genai/tests/unit_tests/test_chat_models.py b/libs/partners/google-genai/tests/unit_tests/test_chat_models.py index 651d64508f963..93d5bb3231361 100644 --- a/libs/partners/google-genai/tests/unit_tests/test_chat_models.py +++ b/libs/partners/google-genai/tests/unit_tests/test_chat_models.py @@ -1,8 +1,12 @@ """Test chat model integration.""" +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.pydantic_v1 import SecretStr from pytest import CaptureFixture -from langchain_google_genai.chat_models import ChatGoogleGenerativeAI +from langchain_google_genai.chat_models import ( + ChatGoogleGenerativeAI, + _parse_chat_history, +) def test_integration_initialization() -> None: @@ -36,3 +40,21 @@ def test_api_key_masked_when_passed_via_constructor(capsys: CaptureFixture) -> N captured = capsys.readouterr() assert captured.out == "**********" + + +def test_parse_history() -> None: + system_input = "You're supposed to answer math questions." 
+ text_question1, text_answer1 = "How much is 2+2?", "4" + text_question2 = "How much is 3+3?" + system_message = SystemMessage(content=system_input) + message1 = HumanMessage(content=text_question1) + message2 = AIMessage(content=text_answer1) + message3 = HumanMessage(content=text_question2) + messages = [system_message, message1, message2, message3] + history = _parse_chat_history(messages, convert_system_message_to_human=True) + assert len(history) == 3 + assert history[0] == { + "role": "user", + "parts": [{"text": system_input}, {"text": text_question1}], + } + assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]}