Skip to content
This repository has been archived by the owner on Aug 12, 2024. It is now read-only.

Commit

Permalink
fixes in function calling args (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
nsosio authored Jun 10, 2024
1 parent 7c62b2d commit b98134a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
9 changes: 8 additions & 1 deletion prem_utils/connectors/groq.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import logging

from groq import AsyncGroq, Groq
Expand Down Expand Up @@ -63,6 +64,12 @@ def parse_chunk(self, chunk):
],
}

def _get_arguments(self, arguments):
try:
return json.loads(arguments)
except json.JSONDecodeError:
return None

async def chat_completion(
self,
model: str,
Expand Down Expand Up @@ -144,7 +151,7 @@ async def chat_completion(
{
"id": tool_call.id,
"function": {
"arguments": tool_call.function.arguments,
"arguments": self._get_arguments(tool_call.function.arguments),
"name": tool_call.function.name,
},
"type": tool_call.type,
Expand Down
9 changes: 8 additions & 1 deletion prem_utils/connectors/mistral.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from collections.abc import Sequence

from mistralai.async_client import MistralAsyncClient
Expand Down Expand Up @@ -46,6 +47,12 @@ def build_messages(self, messages):
chat_messages.append(chat_message)
return chat_messages

def _get_arguments(self, arguments):
try:
return json.loads(arguments)
except json.JSONDecodeError:
return None

async def chat_completion(
self,
model: str,
Expand Down Expand Up @@ -100,7 +107,7 @@ async def chat_completion(
{
"id": tool_call.id,
"function": {
"arguments": tool_call.function.arguments,
"arguments": self._get_arguments(tool_call.function.arguments),
"name": tool_call.function.name,
},
"type": tool_call.type.value if tool_call.type else None,
Expand Down
8 changes: 7 additions & 1 deletion prem_utils/connectors/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def parse_chunk(self, chunk):
],
}

def _get_arguments(self, arguments):
try:
return json.loads(arguments)
except json.JSONDecodeError:
return None

async def chat_completion(
self,
model: str,
Expand Down Expand Up @@ -148,7 +154,7 @@ async def chat_completion(
{
"id": tool_call.id,
"function": {
"arguments": tool_call.function.arguments,
"arguments": self._get_arguments(tool_call.function.arguments),
"name": tool_call.function.name,
},
"type": tool_call.type,
Expand Down

0 comments on commit b98134a

Please sign in to comment.