Skip to content

Commit

Permalink
fix: function call error
Browse files Browse the repository at this point in the history
  • Loading branch information
whybeyoung committed May 21, 2024
1 parent beaf8f6 commit f1b48fc
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "spark-ai-python"
version = "0.3.27"
version = "0.3.30"
description = "a sdk for iflytek's spark LLM."
authors = ["whybeyoung <[email protected]>", "mingduan <[email protected]>"]
license = "MIT"
Expand Down
4 changes: 2 additions & 2 deletions sparkai/core/messages/function.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, List, Literal

from sparkai.core.messages import AIMessageChunk
from sparkai.core.messages.base import (
BaseMessage,
BaseMessageChunk,
Expand Down Expand Up @@ -96,10 +97,9 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
return self.__class__(
name=self.name,
content=merge_content(self.content, other.content),
function_call=other.function_call, # function call no need chunk now
function_call=self._merge_kwargs_dict(self.function_call,other.function_call), # function call no need chunk now
additional_kwargs=self._merge_kwargs_dict(
self.additional_kwargs, other.additional_kwargs
),
)

return super().__add__(other)
8 changes: 5 additions & 3 deletions sparkai/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ async def _astream(
function_definition = []
if "function_definition" in kwargs:
function_definition = kwargs['function_definition']
default_chunk_class = FunctionCallMessageChunk

llm_output = {}
if "llm_output" in kwargs:
Expand Down Expand Up @@ -305,6 +306,7 @@ def _stream(
function_definition = []
if "function_definition" in kwargs:
function_definition = kwargs['function_definition']
default_chunk_class = FunctionCallMessageChunk

llm_output = {}
if "llm_output" in kwargs:
Expand Down Expand Up @@ -354,14 +356,14 @@ def _generate(
function_definition = []
if "function_definition" in kwargs:
function_definition = kwargs['function_definition']
converted_messages = convert_message_to_dict(messages)

if self.streaming:
stream_iter = self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter, llm_output)

converted_messages = convert_message_to_dict(messages)
self.client.arun(
converted_messages,
self.spark_user_id,
Expand Down Expand Up @@ -611,8 +613,8 @@ def on_message(self, ws: Any, message: str) -> None:
self.blocking_message["function_call"] = function_call

if status == 2:
if not ws.streaming:
self.queue.put({"data": self.blocking_message})
#if ws.streaming:
self.queue.put({"data": self.blocking_message})
usage_data = (
data.get("payload", {}).get("usage", {}).get("text", {})
if data
Expand Down

0 comments on commit f1b48fc

Please sign in to comment.