Add system_instructions for Gemini models #125

Merged · 1 commit · Apr 10, 2024
Add system_instructions for Gemini models
eliasecchig committed Apr 10, 2024
commit 2557f90060218664c85eab15fb3c9a6132bc6e6c
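
For context, a minimal sketch of what this PR enables: a leading SystemMessage now flows through to Gemini as a native system instruction instead of raising. The model name and messages below are illustrative assumptions, not taken from this diff.

```python
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_vertexai import ChatVertexAI

# Previously this raised "SystemMessages are not yet supported!" unless
# convert_system_message_to_human=True was set; after this PR the leading
# SystemMessage is passed through as a native Gemini system instruction.
llm = ChatVertexAI(model_name="gemini-1.0-pro-002")  # assumed model name
response = llm.invoke(
    [
        SystemMessage(content="You are a concise assistant."),
        HumanMessage(content="What does this PR change?"),
    ]
)
print(response.content)
```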
102 changes: 76 additions & 26 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
@@ -118,7 +118,7 @@ def _parse_chat_history_gemini(
     history: List[BaseMessage],
     project: Optional[str] = None,
     convert_system_message_to_human: Optional[bool] = False,
-) -> List[Content]:
+) -> tuple[Content | None, list[Content]]:
     def _convert_to_prompt(part: Union[str, Dict]) -> Part:
         if isinstance(part, str):
             return Part.from_text(part)
@@ -142,24 +142,35 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
         return [_convert_to_prompt(part) for part in raw_content]
 
     vertex_messages = []
-    raw_system_message = None
+    convert_system_message_to_human_content = None
+    system_instruction = None
     for i, message in enumerate(history):
         if (
             i == 0
             and isinstance(message, SystemMessage)
             and not convert_system_message_to_human
         ):
-            raise ValueError(
-                """SystemMessages are not yet supported!
-
-To automatically convert the leading SystemMessage to a HumanMessage,
-set `convert_system_message_to_human` to True. Example:
-
-llm = ChatVertexAI(model_name="gemini-pro", convert_system_message_to_human=True)
-"""
-            )
+            if system_instruction is not None:
+                raise ValueError(
+                    "Detected more than one SystemMessage in the list of messages."
+                    "Gemini APIs support the insertion of only SystemMessage."
+                )
+            else:
+                system_instruction = Content(
+                    role="user", parts=_convert_to_parts(message)
+                )
+            continue
+        elif (
+            i == 0
+            and isinstance(message, SystemMessage)
+            and convert_system_message_to_human
+        ):
+            logger.warning(
+                "gemini models released from April 2024 support SystemMessages"
+                "natively. For best performances, when working with these models,"
+                "set convert_system_message_to_human to False"
+            )
-        elif i == 0 and isinstance(message, SystemMessage):
-            raw_system_message = message
+            convert_system_message_to_human_content = message
             continue
         elif isinstance(message, AIMessage):
             raw_function_call = message.additional_kwargs.get("function_call")
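
Reviewer note: this hunk changes the parser's contract. `_parse_chat_history_gemini` now returns a `(system_instruction, vertex_messages)` tuple, with `system_instruction` left as `None` when the history carries no leading SystemMessage. A minimal sketch of how a caller unpacks it (illustrative only; the real call sites appear in `_generate` and friends below):

```python
# Illustrative unpacking under the new contract; `messages` is a stand-in.
system_instruction, history_gemini = _parse_chat_history_gemini(
    messages,
    project=None,
    convert_system_message_to_human=False,
)
if system_instruction is None:
    pass  # no leading SystemMessage: behavior matches the pre-PR code path
```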
@@ -193,18 +204,18 @@ def _convert_to_parts(message: BaseMessage) -> List[Part]:
                 f"Unexpected message with type {type(message)} at the position {i}."
             )
 
-        if raw_system_message:
+        if convert_system_message_to_human_content:
             if role == "model":
                 raise ValueError(
                     "SystemMessage should be followed by a HumanMessage and "
                     "not by AIMessage."
                 )
-            parts = _convert_to_parts(raw_system_message) + parts
-            raw_system_message = None
+            parts = _convert_to_parts(convert_system_message_to_human_content) + parts
+            convert_system_message_to_human_content = None
 
         vertex_message = Content(role=role, parts=parts)
         vertex_messages.append(vertex_message)
-    return vertex_messages
+    return system_instruction, vertex_messages
 
 
 def _parse_examples(examples: List[BaseMessage]) -> List[InputOutputTextPair]:
@@ -247,6 +258,21 @@ def _get_question(messages: List[BaseMessage]) -> HumanMessage:
     return question
 
 
+def _get_client_with_sys_instruction(
+    client: GenerativeModel,
+    system_instruction: Content,
+    model_name: str,
+    safety_settings: Optional[Dict] = None,
+):
+    if client._system_instruction != system_instruction:
+        client = GenerativeModel(
+            model_name=model_name,
+            safety_settings=safety_settings,
+            system_instruction=system_instruction,
+        )
+    return client
+
+
 def _parse_response_candidate(response_candidate: "Candidate") -> AIMessage:
     try:
         content = response_candidate.text
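
The new helper above caches by comparison: it rebuilds the `GenerativeModel` only when the requested system instruction differs from the one on the current client, read via the SDK's private `_system_instruction` attribute (an implementation detail of the vertexai SDK, so the comparison is somewhat fragile). A rough standalone sketch of the intended behavior, under assumed import path and model name:

```python
# Standalone sketch of the caching behavior; import path and model name
# are assumptions, and _system_instruction is a private SDK attribute.
from vertexai.generative_models import Content, GenerativeModel, Part

model_name = "gemini-1.0-pro-002"  # assumed model name
client = GenerativeModel(model_name=model_name)
instruction = Content(role="user", parts=[Part.from_text("Be brief.")])

# First call: the instruction differs from the client's (None), so a new
# GenerativeModel is built with system_instruction set.
client = _get_client_with_sys_instruction(
    client=client,
    system_instruction=instruction,
    model_name=model_name,
)

# Second call with an unchanged instruction: no rebuild, same object back
# (assuming the SDK stores the instruction unchanged on the client).
assert client is _get_client_with_sys_instruction(
    client=client,
    system_instruction=instruction,
    model_name=model_name,
)
```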
Expand Down Expand Up @@ -275,10 +301,9 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
"Underlying model name."
examples: Optional[List[BaseMessage]] = None
convert_system_message_to_human: bool = False
"""Whether to merge any leading SystemMessage into the following HumanMessage.

Gemini does not support system messages; any unsupported messages will
raise an error."""
"""[Deprecated] Since new Gemini models support setting a System Message,
setting this parameter to True is discouraged.
"""

@classmethod
def is_lc_serializable(self) -> bool:
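
With native system instructions wired in, `convert_system_message_to_human` survives only for backward compatibility: setting it now routes through the `logger.warning` branch added earlier and merges the SystemMessage into the first HumanMessage as before. A sketch of the discouraged path (model name assumed):

```python
# Discouraged after this PR: triggers the deprecation warning above and
# folds the leading SystemMessage into the first HumanMessage.
llm = ChatVertexAI(
    model_name="gemini-1.0-pro-002",  # assumed model name
    convert_system_message_to_human=True,
)
```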
@@ -357,12 +382,18 @@ def _generate(
             msg_params["candidate_count"] = params.pop("candidate_count")
 
         if self._is_gemini_model:
-            history_gemini = _parse_chat_history_gemini(
+            system_instruction, history_gemini = _parse_chat_history_gemini(
                 messages,
                 project=self.project,
                 convert_system_message_to_human=self.convert_system_message_to_human,
             )
             message = history_gemini.pop()
+            self.client = _get_client_with_sys_instruction(
+                client=self.client,
+                system_instruction=system_instruction,
+                model_name=self.model_name,
+                safety_settings=safety_settings,
+            )
             with telemetry.tool_context_manager(self._user_agent):
                 chat = self.client.start_chat(history=history_gemini)
 
@@ -441,12 +472,19 @@ async def _agenerate(
             msg_params["candidate_count"] = params.pop("candidate_count")
 
         if self._is_gemini_model:
-            history_gemini = _parse_chat_history_gemini(
+            system_instruction, history_gemini = _parse_chat_history_gemini(
                 messages,
                 project=self.project,
                 convert_system_message_to_human=self.convert_system_message_to_human,
             )
             message = history_gemini.pop()
+
+            self.client = _get_client_with_sys_instruction(
+                client=self.client,
+                system_instruction=system_instruction,
+                model_name=self.model_name,
+                safety_settings=safety_settings,
+            )
             with telemetry.tool_context_manager(self._user_agent):
                 chat = self.client.start_chat(history=history_gemini)
             # set param to `functions` until core tool/function calling implemented
@@ -501,18 +539,24 @@ def _stream(
     ) -> Iterator[ChatGenerationChunk]:
         params = self._prepare_params(stop=stop, stream=True, **kwargs)
         if self._is_gemini_model:
-            history_gemini = _parse_chat_history_gemini(
+            safety_settings = params.pop("safety_settings", None)
+            system_instruction, history_gemini = _parse_chat_history_gemini(
                 messages,
                 project=self.project,
                 convert_system_message_to_human=self.convert_system_message_to_human,
             )
             message = history_gemini.pop()
+            self.client = _get_client_with_sys_instruction(
+                client=self.client,
+                system_instruction=system_instruction,
+                model_name=self.model_name,
+                safety_settings=safety_settings,
+            )
             with telemetry.tool_context_manager(self._user_agent):
                 chat = self.client.start_chat(history=history_gemini)
             # set param to `functions` until core tool/function calling implemented
             raw_tools = params.pop("functions") if "functions" in params else None
             tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
-            safety_settings = params.pop("safety_settings", None)
             with telemetry.tool_context_manager(self._user_agent):
                 responses = chat.send_message(
                     message,
if not self._is_gemini_model:
raise NotImplementedError()
params = self._prepare_params(stop=stop, stream=True, **kwargs)
history_gemini = _parse_chat_history_gemini(
safety_settings = params.pop("safety_settings", None)
system_instruction, history_gemini = _parse_chat_history_gemini(
messages,
project=self.project,
convert_system_message_to_human=self.convert_system_message_to_human,
)
message = history_gemini.pop()
self.client = _get_client_with_sys_instruction(
client=self.client,
system_instruction=system_instruction,
model_name=self.model_name,
safety_settings=safety_settings,
)
with telemetry.tool_context_manager(self._user_agent):
chat = self.client.start_chat(history=history_gemini)
raw_tools = params.pop("functions") if "functions" in params else None
tools = _format_tools_to_vertex_tool(raw_tools) if raw_tools else None
safety_settings = params.pop("safety_settings", None)
with telemetry.tool_context_manager(self._user_agent):
async for chunk in await chat.send_message_async(
message,
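
Taken together, the four entry points (`_generate`, `_agenerate`, `_stream`, `_astream`) now share one pattern: parse the system instruction out of the history, swap the cached client if the instruction changed, then start the chat. The streaming path therefore honors system instructions too; a sketch under the same assumptions as the earlier examples:

```python
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_google_vertexai import ChatVertexAI

llm = ChatVertexAI(model_name="gemini-1.0-pro-002")  # assumed model name
for chunk in llm.stream(
    [
        SystemMessage(content="Answer in one short sentence."),
        HumanMessage(content="Why move system messages out of the history?"),
    ]
):
    print(chunk.content, end="")
```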
38 changes: 21 additions & 17 deletions libs/vertexai/poetry.lock

Some generated files are not rendered by default.