diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py index 910a533c7..089b38b10 100644 --- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py +++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py @@ -231,45 +231,45 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part: raise ValueError(msg) def _message_to_part(self, message: ChatMessage) -> Part: - if message.role == ChatRole.ASSISTANT and message.name: + if message.is_from(ChatRole.ASSISTANT) and message.name: p = Part() p.function_call.name = message.name p.function_call.args = {} for k, v in json.loads(message.text).items(): p.function_call.args[k] = v return p - elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}: + elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): p = Part() p.text = message.text return p - elif message.role == ChatRole.FUNCTION: + elif message.is_from(ChatRole.FUNCTION): p = Part() p.function_response.name = message.name p.function_response.response = message.text return p - elif message.role == ChatRole.USER: + elif message.is_from(ChatRole.USER): return self._convert_part(message.text) def _message_to_content(self, message: ChatMessage) -> Content: - if message.role == ChatRole.ASSISTANT and message.name: + if message.is_from(ChatRole.ASSISTANT) and message.name: part = Part() part.function_call.name = message.name part.function_call.args = {} for k, v in json.loads(message.text).items(): part.function_call.args[k] = v - elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}: + elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT): part = Part() part.text = message.text - elif message.role == ChatRole.FUNCTION: + elif message.is_from(ChatRole.FUNCTION): part = Part() part.function_response.name = message.name part.function_response.response = message.text - elif message.role == ChatRole.USER: + elif message.is_from(ChatRole.USER): part = self._convert_part(message.text) else: msg = f"Unsupported message role {message.role}" raise ValueError(msg) - role = "user" if message.role in [ChatRole.USER, ChatRole.FUNCTION] else "model" + role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model" return Content(parts=[part], role=role) @component.output_types(replies=List[ChatMessage])