Skip to content

Commit

Permalink
avoid directly accessing role
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Dec 10, 2024
1 parent 324a7fa commit 04b1bc2
Showing 1 changed file with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -231,45 +231,45 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
raise ValueError(msg)

def _message_to_part(self, message: ChatMessage) -> Part:
if message.role == ChatRole.ASSISTANT and message.name:
if message.is_from(ChatRole.ASSISTANT) and message.name:
p = Part()
p.function_call.name = message.name
p.function_call.args = {}
for k, v in json.loads(message.text).items():
p.function_call.args[k] = v
return p
elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}:
elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT):
p = Part()
p.text = message.text
return p
elif message.role == ChatRole.FUNCTION:
elif message.is_from(ChatRole.FUNCTION):
p = Part()
p.function_response.name = message.name
p.function_response.response = message.text
return p
elif message.role == ChatRole.USER:
elif message.is_from(ChatRole.USER):
return self._convert_part(message.text)

def _message_to_content(self, message: ChatMessage) -> Content:
if message.role == ChatRole.ASSISTANT and message.name:
if message.is_from(ChatRole.ASSISTANT) and message.name:
part = Part()
part.function_call.name = message.name
part.function_call.args = {}
for k, v in json.loads(message.text).items():
part.function_call.args[k] = v
elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}:
elif message.is_from(ChatRole.SYSTEM) or message.is_from(ChatRole.ASSISTANT):
part = Part()
part.text = message.text
elif message.role == ChatRole.FUNCTION:
elif message.is_from(ChatRole.FUNCTION):
part = Part()
part.function_response.name = message.name
part.function_response.response = message.text
elif message.role == ChatRole.USER:
elif message.is_from(ChatRole.USER):
part = self._convert_part(message.text)
else:
msg = f"Unsupported message role {message.role}"
raise ValueError(msg)
role = "user" if message.role in [ChatRole.USER, ChatRole.FUNCTION] else "model"
role = "user" if message.is_from(ChatRole.USER) or message.is_from(ChatRole.FUNCTION) else "model"
return Content(parts=[part], role=role)

@component.output_types(replies=List[ChatMessage])
Expand Down

0 comments on commit 04b1bc2

Please sign in to comment.