Skip to content

Commit

Permalink
fixes after review
Browse files Browse the repository at this point in the history
  • Loading branch information
lkuligin committed Dec 18, 2023
1 parent 95ad94b commit da2f4d1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 28 deletions.
46 changes: 20 additions & 26 deletions libs/partners/google-genai/langchain_google_genai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _create_retry_decorator() -> Callable[[Any], Any]:
)


def _chat_with_retry(*args: Any, generation_method: Callable, **kwargs: Any) -> Any:
def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
"""
Executes a chat generation method with retry logic using tenacity.
Expand All @@ -126,9 +126,9 @@ def _chat_with_retry(*args: Any, generation_method: Callable, **kwargs: Any) ->
from google.api_core.exceptions import InvalidArgument # type: ignore

@retry_decorator
def _chat_with_retry(*args: Any, **kwargs: Any) -> Any:
def _chat_with_retry(**kwargs: Any) -> Any:
try:
return generation_method(*args, **kwargs)
return generation_method(**kwargs)
except InvalidArgument as e:
# Do not retry for these errors.
raise ChatGoogleGenerativeAIError(
Expand All @@ -137,12 +137,10 @@ def _chat_with_retry(*args: Any, **kwargs: Any) -> Any:
except Exception as e:
raise e

return _chat_with_retry(*args, **kwargs)
return _chat_with_retry(**kwargs)


async def _achat_with_retry(
*args: Any, generation_method: Callable, **kwargs: Any
) -> Any:
async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any:
"""
Executes a chat generation method with retry logic using tenacity.
Expand All @@ -161,9 +159,9 @@ async def _achat_with_retry(
from google.api_core.exceptions import InvalidArgument # type: ignore

@retry_decorator
async def _achat_with_retry(*args: Any, **kwargs: Any) -> Any:
async def _achat_with_retry(**kwargs: Any) -> Any:
try:
return await generation_method(*args, **kwargs)
return await generation_method(**kwargs)
except InvalidArgument as e:
# Do not retry for these errors.
raise ChatGoogleGenerativeAIError(
Expand All @@ -172,7 +170,7 @@ async def _achat_with_retry(*args: Any, **kwargs: Any) -> Any:
except Exception as e:
raise e

return await _achat_with_retry(*args, **kwargs)
return await _achat_with_retry(**kwargs)


def _is_openai_parts_format(part: dict) -> bool:
Expand Down Expand Up @@ -249,10 +247,11 @@ def _url_to_pil(image_source: str) -> Image:


def _convert_to_parts(
content: Sequence[Union[str, dict]],
raw_content: Union[str, Sequence[Union[str, dict]]],
) -> List[genai.types.PartType]:
"""Converts a list of LangChain messages into a google parts."""
parts = []
content = [raw_content] if isinstance(raw_content, str) else raw_content
for part in content:
if isinstance(part, str):
parts.append(genai.types.PartDict(text=part))
Expand Down Expand Up @@ -292,24 +291,24 @@ def _parse_chat_history(
) -> List[genai.types.ContentDict]:
messages: List[genai.types.MessageDict] = []

raw_system_message: Optional[str] = None
raw_system_message: Optional[SystemMessage] = None
for i, message in enumerate(input_messages):
if (
i == 0
and isinstance(message, SystemMessage)
and not convert_system_message_to_human
):
raise ValueError("""SystemMessages are not yet supported!
raise ValueError(
"""SystemMessages are not yet supported!
To automatically convert the leading SystemMessage to a HumanMessage,
set `convert_system_message_to_human` to True. Example:
llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human=True)
""")
elif i == 0 and isinstance(message, SystemMessage):
raw_system_message = (
message.content if isinstance(message.content, str) else None
"""
)
elif i == 0 and isinstance(message, SystemMessage):
raw_system_message = message
continue
elif isinstance(message, AIMessage):
role = "model"
Expand All @@ -320,20 +319,14 @@ def _parse_chat_history(
f"Unexpected message with type {type(message)} at the position {i}."
)

raw_content = message.content
if isinstance(raw_content, str):
raw_content = [raw_content]
parts = _convert_to_parts(raw_content)
parts = _convert_to_parts(message.content)
if raw_system_message:
if role == "model":
raise ValueError(
"SystemMessage should be followed by a HumanMessage and "
"not by AIMessage."
)
if "text" in parts[0]:
parts[0]["text"] = f"{raw_system_message}\n{parts[0]['text']}"
else:
parts = [{"text": raw_system_message}] + parts
parts = _convert_to_parts(raw_system_message.content) + parts
raw_system_message = None
messages.append({"role": role, "parts": parts})
return messages
Expand Down Expand Up @@ -471,7 +464,8 @@ class ChatGoogleGenerativeAI(BaseChatModel):
convert_system_message_to_human: bool = False
"""Whether to merge any leading SystemMessage into the following HumanMessage.
Gemini does not support system messages; any unsupported messages will raise an error."""
Gemini does not support system messages; any unsupported messages will
raise an error."""

class Config:
allow_population_by_field_name = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,6 @@ def test_parse_history() -> None:
assert len(history) == 3
assert history[0] == {
"role": "user",
"parts": [{"text": f"{system_input}\n{text_question1}"}],
"parts": [{"text": system_input}, {"text": text_question1}],
}
assert history[1] == {"role": "model", "parts": [{"text": f"{text_answer1}"}]}
assert history[1] == {"role": "model", "parts": [{"text": text_answer1}]}

0 comments on commit da2f4d1

Please sign in to comment.