
Commit

merge
ccurme committed Apr 10, 2024
2 parents d98d412 + 8cac88a commit c7cfdd2
Showing 19 changed files with 785 additions and 501 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1,5 +1,5 @@
 # 🦜️🔗 LangChain Google
 
-This repository contains two packages with Google integrations with Langhchain:
+This repository contains two packages with Google integrations with LangChain:
 - [langchain-google-genai](https://pypi.org/project/langchain-google-genai/) implements integrations of Google [Generative AI](https://ai.google.dev/) models.
 - [langchain-google-vertexai](https://pypi.org/project/langchain-google-vertexai/) implements integrations of Google Cloud [Generative AI on Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/overview)
42 changes: 14 additions & 28 deletions libs/genai/langchain_google_genai/chat_models.py
@@ -4,6 +4,7 @@
 import json
 import logging
 import os
+import warnings
 from io import BytesIO
 from typing import (
     Any,
@@ -300,27 +301,16 @@ def _convert_to_parts(
 
 def _parse_chat_history(
     input_messages: Sequence[BaseMessage], convert_system_message_to_human: bool = False
-) -> List[genai.types.ContentDict]:
+) -> Tuple[Optional[genai.types.ContentDict], List[genai.types.ContentDict]]:
     messages: List[genai.types.MessageDict] = []
 
-    raw_system_message: Optional[SystemMessage] = None
-    for i, message in enumerate(input_messages):
-        if (
-            i == 0
-            and isinstance(message, SystemMessage)
-            and not convert_system_message_to_human
-        ):
-            raise ValueError(
-                """SystemMessages are not yet supported!
-
-To automatically convert the leading SystemMessage to a HumanMessage,
-set `convert_system_message_to_human` to True. Example:
-
-llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
-"""
-            )
-        elif i == 0 and isinstance(message, SystemMessage):
-            raw_system_message = message
-            continue
+    if convert_system_message_to_human:
+        warnings.warn("Convert_system_message_to_human will be deprecated!")
+
+    system_instruction: Optional[genai.types.ContentDict] = None
+    for i, message in enumerate(input_messages):
+        if i == 0 and isinstance(message, SystemMessage):
+            system_instruction = _convert_to_parts(message.content)
+            continue
         elif isinstance(message, AIMessage):
             role = "model"
@@ -365,16 +355,8 @@ def _parse_chat_history(
                 f"Unexpected message with type {type(message)} at the position {i}."
             )
 
-        if raw_system_message:
-            if role == "model":
-                raise ValueError(
-                    "SystemMessage should be followed by a HumanMessage and "
-                    "not by AIMessage."
-                )
-            parts = _convert_to_parts(raw_system_message.content) + parts
-            raw_system_message = None
         messages.append({"role": role, "parts": parts})
-    return messages
+    return system_instruction, messages
 
 
 def _parse_response_candidate(
@@ -659,11 +641,15 @@ def _prepare_chat(
             )
 
         params = self._prepare_params(stop, **kwargs)
-        history = _parse_chat_history(
+        system_instruction, history = _parse_chat_history(
             messages,
             convert_system_message_to_human=self.convert_system_message_to_human,
         )
         message = history.pop()
+        if self.client._system_instruction != system_instruction:
+            self.client = genai.GenerativeModel(
+                model_name=self.model, system_instruction=system_instruction
+            )
         chat = client.start_chat(history=history)
         return params, chat, message

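Taken together, these hunks make a leading SystemMessage flow into Gemini's native system_instruction instead of raising a ValueError. A minimal usage sketch of the new behavior, assuming a configured Google API key (the model name and prompts are illustrative):

from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_genai import ChatGoogleGenerativeAI

# The leading SystemMessage is forwarded as Gemini's system_instruction;
# no error is raised and convert_system_message_to_human is not needed.
llm = ChatGoogleGenerativeAI(model="gemini-pro")
response = llm.invoke(
    [
        SystemMessage(content="Answer in exactly one sentence."),
        HumanMessage(content="What is LangChain?"),
    ]
)
print(response.content)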
462 changes: 255 additions & 207 deletions libs/genai/poetry.lock

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions libs/genai/pyproject.toml
@@ -12,8 +12,8 @@ license = "MIT"
 
 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
-langchain-core = "^0.1"
-google-generativeai = "^0.4.1"
+langchain-core = ">=0.1.27,<0.2"
+google-generativeai = "^0.5.0"
 pillow = { version = "^10.1.0", optional = true }
 
 [tool.poetry.extras]
@@ -30,6 +30,7 @@ syrupy = "^4.0.2"
 pytest-watcher = "^0.3.4"
 pytest-asyncio = "^0.21.1"
 numpy = "^1.26.2"
+langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core"}
 
 [tool.poetry.group.codespell]
 optional = true
@@ -56,6 +57,7 @@ types-requests = "^2.28.11.5"
 types-google-cloud-ndb = "^2.2.0.1"
 types-pillow = "^10.1.0.2"
 types-protobuf = "^4.24.0.20240302"
+langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core"}
 
 [tool.poetry.group.dev]
 optional = true
@@ -65,6 +67,7 @@ pillow = "^10.1.0"
 types-requests = "^2.31.0.10"
 types-pillow = "^10.1.0.2"
 types-google-cloud-ndb = "^2.2.0.1"
+langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core"}
 
 [tool.ruff]
 select = [
22 changes: 10 additions & 12 deletions libs/genai/tests/integration_tests/test_chat_models.py
@@ -112,6 +112,16 @@ def test_chat_google_genai_invoke_multimodal() -> None:
         assert len(chunk.content.strip()) > 0
 
 
+def test_system_message() -> None:
+    messages = [
+        SystemMessage(content="Be a helpful assistant."),
+        HumanMessage(content="Hi, how are you?"),
+    ]
+    llm = ChatGoogleGenerativeAI(model="models/gemini-1.0-pro-latest")
+    answer = llm.invoke(messages)
+    assert isinstance(answer.content, str)
+
+
 def test_chat_google_genai_invoke_multimodal_too_many_messages() -> None:
     # Only supports 1 turn...
     messages: list = [
@@ -168,18 +178,6 @@ def test_chat_google_genai_single_call_with_history() -> None:
     assert isinstance(response.content, str)
 
 
-def test_chat_google_genai_system_message_error() -> None:
-    model = ChatGoogleGenerativeAI(model=_MODEL)
-    text_question1, text_answer1 = "How much is 2+2?", "4"
-    text_question2 = "How much is 3+3?"
-    system_message = SystemMessage(content="You're supposed to answer math questions.")
-    message1 = HumanMessage(content=text_question1)
-    message2 = AIMessage(content=text_answer1)
-    message3 = HumanMessage(content=text_question2)
-    with pytest.raises(ValueError):
-        model([system_message, message1, message2, message3])
-
-
 def test_chat_google_genai_system_message() -> None:
     model = ChatGoogleGenerativeAI(model=_MODEL, convert_system_message_to_human=True)
     text_question1, text_answer1 = "How much is 2+2?", "4"
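Because the convert_system_message_to_human path is downgraded from a hard error to a warning, the old error test above is removed. A hypothetical companion test, not part of this commit, could assert the new warning instead; like its neighbors it assumes a configured API key:

import warnings

def test_convert_system_message_warns() -> None:
    # Sketch: the flag is kept for compatibility but now emits a warning.
    llm = ChatGoogleGenerativeAI(model=_MODEL, convert_system_message_to_human=True)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        llm.invoke(
            [
                SystemMessage(content="Be brief."),
                HumanMessage(content="Hello!"),
            ]
        )
    assert any("deprecated" in str(w.message) for w in caught)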
2 changes: 1 addition & 1 deletion libs/genai/tests/integration_tests/test_llms.py
@@ -75,7 +75,7 @@ def test_safety_settings_gemini() -> None:
     llm = GoogleGenerativeAI(temperature=0, model="gemini-pro")
     output = llm.generate(prompts=["how to make a bomb?"])
     assert isinstance(output, LLMResult)
-    assert len(output.generations[0]) == 0
+    assert len(output.generations[0]) > 0
 
     # safety filters
     safety_settings = {
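The hunk is truncated at the explicit safety-settings dictionary. For orientation, here is a hedged sketch of how such settings are typically constructed with the google-generativeai enums; the specific category/threshold pair is illustrative, not necessarily what the test uses:

from google.generativeai.types import HarmBlockThreshold, HarmCategory
from langchain_google_genai import GoogleGenerativeAI

# Illustrative mapping; choose categories and thresholds deliberately.
safety_settings = {
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
}
llm = GoogleGenerativeAI(model="gemini-pro", safety_settings=safety_settings)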
7 changes: 5 additions & 2 deletions libs/genai/tests/unit_tests/test_chat_models.py
@@ -61,13 +61,16 @@ def test_parse_history() -> None:
     message2 = AIMessage(content=text_answer1)
     message3 = HumanMessage(content=text_question2)
     messages = [system_message, message1, message2, message3]
-    history = _parse_chat_history(messages, convert_system_message_to_human=True)
+    system_instruction, history = _parse_chat_history(
+        messages, convert_system_message_to_human=True
+    )
     assert len(history) == 3
     assert history[0] == {
         "role": "user",
-        "parts": [{"text": system_input}, {"text": text_question1}],
+        "parts": [{"text": text_question1}],
     }
     assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]}
+    assert system_instruction == [{"text": system_input}]
 
 
 @pytest.mark.parametrize("content", ['["a"]', '{"a":"b"}', "function output"])
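For readers tracking the new contract: _parse_chat_history now returns a (system_instruction, history) pair, and the convert flag can be omitted entirely. A self-contained sketch, with illustrative message text:

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_google_genai.chat_models import _parse_chat_history

system_instruction, history = _parse_chat_history(
    [
        SystemMessage(content="You are a math tutor."),
        HumanMessage(content="How much is 2+2?"),
        AIMessage(content="4"),
    ]
)
# The leading system message is split out; the rest becomes the chat history.
assert system_instruction == [{"text": "You are a math tutor."}]
assert [m["role"] for m in history] == ["user", "model"]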