From 402a62fda46bf66b45b554d858c7ac87f0e5285d Mon Sep 17 00:00:00 2001 From: nsosio Date: Mon, 10 Jun 2024 13:28:48 +0200 Subject: [PATCH] added cohere --- prem_utils/connectors/cohere.py | 55 +++++++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/prem_utils/connectors/cohere.py b/prem_utils/connectors/cohere.py index d926f5b..7307cd7 100644 --- a/prem_utils/connectors/cohere.py +++ b/prem_utils/connectors/cohere.py @@ -62,6 +62,32 @@ def parse_chunk(self, chunk): ], } + def _parse_tools(self, tools): + if tools is None: + return None + parsed_tools = [] + + for tool in tools: + parameters = tool["function"]["parameters"] + parameter_definitions = {} + + for param, details in parameters["properties"].items(): + parameter_definitions[param] = { + "description": details["description"], + "type": details["type"], + "required": param in parameters.get("required", []), + } + + transformed_tool = { + "name": tool["function"]["name"], + "description": tool["function"]["description"], + "parameter_definitions": parameter_definitions, + } + + parsed_tools.append(transformed_tool) + + return parsed_tools + async def chat_completion( self, model: str, @@ -76,7 +102,16 @@ async def chat_completion( stream: bool = False, temperature: float = 1, top_p: float = 1, + tools=None, ): + if tools is not None and stream: + raise errors.PremProviderError( + "Cannot use tools with stream=True", + provider="cohere", + model=model, + provider_message="Cannot use tools with stream=True", + ) + tools = self._parse_tools(tools) chat_history, message = self.preprocess_messages(messages) try: if stream: @@ -98,6 +133,7 @@ async def chat_completion( p=top_p, temperature=temperature, stream=stream, + tools=tools, ) except (CohereAPIError, CohereConnectionError) as error: @@ -107,9 +143,24 @@ async def chat_completion( plain_response = { "choices": [ { - "finish_reason": "stop", + "finish_reason": "stop" if not response.tool_calls else "tools", "index": 0, - "message": {"content": response.text, "role": "assistant"}, + "message": { + "content": response.text, + "role": "assistant", + "tool_calls": [ + { + "id": str(uuid4()), + "function": { + "arguments": tool_call.parameters, + "name": tool_call.name, + }, + } + for tool_call in response.tool_calls + ] + if response.tool_calls + else None, + }, } ], "created": connector_utils.default_chatcompletion_response_created(),