diff --git a/examples/foundational/14j-function-calling-nim.py b/examples/foundational/14j-function-calling-nim.py
index 68bb3b52e..4059bc386 100644
--- a/examples/foundational/14j-function-calling-nim.py
+++ b/examples/foundational/14j-function-calling-nim.py
@@ -65,7 +65,7 @@ async def main():
     )
 
     llm = NimLLMService(
-        api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct"
+        api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.3-70b-instruct"
     )
     # Register a function_name of None to get all functions
     # sent to the same callback with an additional function_name parameter.
@@ -76,18 +76,18 @@ async def main():
             type="function",
             function={
                 "name": "get_current_weather",
-                "description": "Get the current weather",
+                "description": "Returns the current weather at a location, if one is specified; otherwise defaults to the user's location.",
                 "parameters": {
                     "type": "object",
                     "properties": {
                         "location": {
                             "type": "string",
-                            "description": "The city and state, e.g. San Francisco, CA",
+                            "description": "The location to get the weather for; if not provided, the user's default location is used.",
                         },
                         "format": {
                             "type": "string",
                             "enum": ["celsius", "fahrenheit"],
-                            "description": "The temperature unit to use. Infer this from the users location.",
+                            "description": "Whether to use SI or USCS units (celsius or fahrenheit).",
                         },
                     },
                     "required": ["location", "format"],
diff --git a/src/pipecat/services/grok.py b/src/pipecat/services/grok.py
index 505dfcca5..a6a1b3a64 100644
--- a/src/pipecat/services/grok.py
+++ b/src/pipecat/services/grok.py
@@ -5,11 +5,102 @@
 #
 
+import json
+from dataclasses import dataclass
+
 from loguru import logger
 
 from pipecat.metrics.metrics import LLMTokenUsage
-from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
-from pipecat.services.openai import OpenAILLMService
+from pipecat.processors.aggregators.openai_llm_context import (
+    OpenAILLMContext,
+    OpenAILLMContextFrame,
+)
+from pipecat.services.openai import (
+    OpenAIAssistantContextAggregator,
+    OpenAILLMService,
+    OpenAIUserContextAggregator,
+)
+
+
+class GrokAssistantContextAggregator(OpenAIAssistantContextAggregator):
+    """Custom assistant context aggregator for Grok, which requires an empty content field on tool-call messages."""
+
+    async def _push_aggregation(self):
+        if not (
+            self._aggregation or self._function_call_result or self._pending_image_frame_message
+        ):
+            return
+
+        run_llm = False
+
+        aggregation = self._aggregation
+        self._reset()
+
+        try:
+            if self._function_call_result:
+                frame = self._function_call_result
+                self._function_call_result = None
+                if frame.result:
+                    # Grok requires an empty content field for function calls
+                    self._context.add_message(
+                        {
+                            "role": "assistant",
+                            "content": "",  # Required by Grok
+                            "tool_calls": [
+                                {
+                                    "id": frame.tool_call_id,
+                                    "function": {
+                                        "name": frame.function_name,
+                                        "arguments": json.dumps(frame.arguments),
+                                    },
+                                    "type": "function",
+                                }
+                            ],
+                        }
+                    )
+                    self._context.add_message(
+                        {
+                            "role": "tool",
+                            "content": json.dumps(frame.result),
+                            "tool_call_id": frame.tool_call_id,
+                        }
+                    )
+                    # Only run the LLM if there are no more function calls in progress.
+                    run_llm = not bool(self._function_calls_in_progress)
+            else:
+                self._context.add_message({"role": "assistant", "content": aggregation})
+
+            if self._pending_image_frame_message:
+                frame = self._pending_image_frame_message
+                self._pending_image_frame_message = None
+                self._context.add_image_frame_message(
+                    format=frame.user_image_raw_frame.format,
+                    size=frame.user_image_raw_frame.size,
+                    image=frame.user_image_raw_frame.image,
+                    text=frame.text,
+                )
+                run_llm = True
+
+            if run_llm:
+                await self._user_context_aggregator.push_context_frame()
+
+            frame = OpenAILLMContextFrame(self._context)
+            await self.push_frame(frame)
+
+        except Exception as e:
+            logger.error(f"Error processing frame: {e}")
+
+
+@dataclass
+class GrokContextAggregatorPair:
+    _user: "OpenAIUserContextAggregator"
+    _assistant: "GrokAssistantContextAggregator"
+
+    def user(self) -> "OpenAIUserContextAggregator":
+        return self._user
+
+    def assistant(self) -> "GrokAssistantContextAggregator":
+        return self._assistant
 
 
 class GrokLLMService(OpenAILLMService):
@@ -101,3 +192,13 @@ async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
         # Update completion tokens count if it has increased
         if tokens.completion_tokens > self._completion_tokens:
             self._completion_tokens = tokens.completion_tokens
+
+    @staticmethod
+    def create_context_aggregator(
+        context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
+    ) -> GrokContextAggregatorPair:
+        user = OpenAIUserContextAggregator(context)
+        assistant = GrokAssistantContextAggregator(
+            user, expect_stripped_words=assistant_expect_stripped_words
+        )
+        return GrokContextAggregatorPair(_user=user, _assistant=assistant)
diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py
index 43ad16536..85e1a95f0 100644
--- a/src/pipecat/services/openai.py
+++ b/src/pipecat/services/openai.py
@@ -559,7 +559,6 @@ async def _push_aggregation(self):
                     self._context.add_message(
                         {
                             "role": "assistant",
-                            "content": "",  # content field required for Grok function calling
                             "tool_calls": [
                                 {
                                     "id": frame.tool_call_id,
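
Reviewer note, not part of the patch: a minimal sketch of how the new
GrokLLMService.create_context_aggregator() entry point might be wired into a
pipeline. The GROK_API_KEY env var name and the bare three-stage pipeline are
assumptions for illustration; a real bot would place transport, STT, and TTS
processors around these stages.

import os

from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.grok import GrokLLMService

# Assumed env var name; use whatever holds your x.ai API key.
llm = GrokLLMService(api_key=os.getenv("GROK_API_KEY"))

context = OpenAILLMContext(
    messages=[{"role": "system", "content": "You are a helpful assistant."}]
)

# The pair's assistant side is the GrokAssistantContextAggregator added above,
# which writes the empty "content" field Grok requires on tool-call messages.
aggregators = GrokLLMService.create_context_aggregator(context)

pipeline = Pipeline(
    [
        aggregators.user(),       # aggregates user transcriptions into the context
        llm,                      # runs Grok completions and function calls
        aggregators.assistant(),  # aggregates assistant/tool results back into the context
    ]
)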