Skip to content
This repository has been archived by the owner on Aug 12, 2024. It is now read-only.

Commit

Permalink
aligned anthropic with openai; minor changes to openai
Browse files Browse the repository at this point in the history
  • Loading branch information
nsosio committed Jun 10, 2024
1 parent 4d04998 commit da4ca42
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
25 changes: 22 additions & 3 deletions prem_utils/connectors/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,33 @@ def preprocess_messages(self, messages):

def _get_content(self, response, tools=False):
if tools:
return {"content": response.content[0].text, "role": "assistant", "tools": []}
return {"content": response.content[0].text, "role": "assistant", "tool_calls": []}
else:
tool_messages = filter(lambda x: "input" in x.__dir__() and "name" in x.__dir__(), response.content)
return {
"content": "",
"role": "assistant",
"tools": [{"input": tool_message.input, "name": tool_message.name} for tool_message in tool_messages],
"tool_calls": [
{"function_arguments": tool_message.input, "name": tool_message.name, "type": "function"}
for tool_message in tool_messages
],
}

def _parse_tools(self, tools: list[dict[str, any]]):
if not tools:
return []
parsed_tools = []

for tool in tools:
transformed_tool = {
"name": tool["function"]["name"],
"description": tool["function"]["description"],
"input_schema": tool["function"]["parameters"],
}
parsed_tools.append(transformed_tool)

return parsed_tools

async def chat_completion(
self,
model: str,
Expand All @@ -107,8 +125,9 @@ async def chat_completion(
stream: bool = False,
temperature: float = 1,
top_p: float = 1,
tools=[],
tools=None,
):
tools = self._parse_tools(tools)
if tools != [] and stream:
raise errors.PremProviderError(
"Cannot use tools with stream=True",
Expand Down
10 changes: 8 additions & 2 deletions prem_utils/connectors/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,15 @@ async def chat_completion(
temperature: float = 1,
top_p: float = 1,
tools=None,
tool_choice=None,
):
if tools is not None and stream:
raise errors.PremProviderError(
"Cannot use tools with stream=True",
provider="openai",
model=model,
provider_message="Cannot use tools with stream=True",
)

if self.prompt_template is not None:
messages = self.apply_prompt_template(messages)

Expand All @@ -103,7 +110,6 @@ async def chat_completion(
top_p=top_p,
logprobs=log_probs,
logit_bias=logit_bias,
tool_choice=tool_choice,
tools=tools,
)
try:
Expand Down

0 comments on commit da4ca42

Please sign in to comment.