Skip to content
This repository has been archived by the owner on Aug 12, 2024. It is now read-only.

Commit

Permalink
added cohere
Browse files Browse the repository at this point in the history
  • Loading branch information
nsosio committed Jun 10, 2024
1 parent 48ccd72 commit 402a62f
Showing 1 changed file with 53 additions and 2 deletions.
55 changes: 53 additions & 2 deletions prem_utils/connectors/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,32 @@ def parse_chunk(self, chunk):
],
}

def _parse_tools(self, tools):
if tools is None:
return None
parsed_tools = []

for tool in tools:
parameters = tool["function"]["parameters"]
parameter_definitions = {}

for param, details in parameters["properties"].items():
parameter_definitions[param] = {
"description": details["description"],
"type": details["type"],
"required": param in parameters.get("required", []),
}

transformed_tool = {
"name": tool["function"]["name"],
"description": tool["function"]["description"],
"parameter_definitions": parameter_definitions,
}

parsed_tools.append(transformed_tool)

return parsed_tools

async def chat_completion(
self,
model: str,
Expand All @@ -76,7 +102,16 @@ async def chat_completion(
stream: bool = False,
temperature: float = 1,
top_p: float = 1,
tools=None,
):
if tools is not None and stream:
raise errors.PremProviderError(
"Cannot use tools with stream=True",
provider="cohere",
model=model,
provider_message="Cannot use tools with stream=True",
)
tools = self._parse_tools(tools)
chat_history, message = self.preprocess_messages(messages)
try:
if stream:
Expand All @@ -98,6 +133,7 @@ async def chat_completion(
p=top_p,
temperature=temperature,
stream=stream,
tools=tools,
)

except (CohereAPIError, CohereConnectionError) as error:
Expand All @@ -107,9 +143,24 @@ async def chat_completion(
plain_response = {
"choices": [
{
"finish_reason": "stop",
"finish_reason": "stop" if not response.tool_calls else "tools",
"index": 0,
"message": {"content": response.text, "role": "assistant"},
"message": {
"content": response.text,
"role": "assistant",
"tool_calls": [
{
"id": str(uuid4()),
"function": {
"arguments": tool_call.parameters,
"name": tool_call.name,
},
}
for tool_call in response.tool_calls
]
if response.tool_calls
else None,
},
}
],
"created": connector_utils.default_chatcompletion_response_created(),
Expand Down

0 comments on commit 402a62f

Please sign in to comment.