From 9360b7a92b9a07d7e725fdeddbad691d5134226c Mon Sep 17 00:00:00 2001
From: vipyne
Date: Mon, 16 Dec 2024 23:15:44 -0600
Subject: [PATCH] services(nim): fix function call

---
 .../foundational/14j-function-calling-nim.py  |   9 +-
 src/pipecat/services/nim.py                   | 102 +++++++++++++++++-
 2 files changed, 106 insertions(+), 5 deletions(-)

diff --git a/examples/foundational/14j-function-calling-nim.py b/examples/foundational/14j-function-calling-nim.py
index 68bb3b52e..9ae9bb687 100644
--- a/examples/foundational/14j-function-calling-nim.py
+++ b/examples/foundational/14j-function-calling-nim.py
@@ -65,8 +65,9 @@ async def main():
     )
 
     llm = NimLLMService(
-        api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct"
+        api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.3-70b-instruct"
     )
+
     # Register a function_name of None to get all functions
     # sent to the same callback with an additional function_name parameter.
     llm.register_function(None, fetch_weather_from_api, start_callback=start_fetch_weather)
@@ -76,18 +77,18 @@ async def main():
             type="function",
             function={
                 "name": "get_current_weather",
-                "description": "Get the current weather",
+                "description": "Returns the current weather at a location, if one is specified, and defaults to the user's location.",
                 "parameters": {
                     "type": "object",
                     "properties": {
                         "location": {
                             "type": "string",
-                            "description": "The city and state, e.g. San Francisco, CA",
+                            "description": "The location to find the weather of, or if not provided, it's the default location.",
                         },
                         "format": {
                             "type": "string",
                             "enum": ["celsius", "fahrenheit"],
-                            "description": "The temperature unit to use. Infer this from the users location.",
+                            "description": "Whether to use SI or USCS units (celsius or fahrenheit).",
                         },
                     },
                     "required": ["location", "format"],
diff --git a/src/pipecat/services/nim.py b/src/pipecat/services/nim.py
index 2b57a5047..c53d399b7 100644
--- a/src/pipecat/services/nim.py
+++ b/src/pipecat/services/nim.py
@@ -4,11 +4,100 @@
 # SPDX-License-Identifier: BSD 2-Clause License
 #
 
+import json
+
+
+from dataclasses import dataclass
+
 from loguru import logger
 
 from pipecat.metrics.metrics import LLMTokenUsage
-from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
+from pipecat.processors.aggregators.openai_llm_context import (
+    OpenAILLMContext,
+    OpenAILLMContextFrame,
+)
 from pipecat.services.openai import OpenAILLMService
+from pipecat.services.openai import (
+    OpenAIAssistantContextAggregator,
+    OpenAIUserContextAggregator,
+)
+
+
+class NimAssistantContextAggregator(OpenAIAssistantContextAggregator):
+    async def _push_aggregation(self):
+        if not (
+            self._aggregation or self._function_call_result or self._pending_image_frame_message
+        ):
+            return
+
+        run_llm = False
+
+        aggregation = self._aggregation
+        self._reset()
+
+        try:
+            if self._function_call_result:
+                frame = self._function_call_result
+                self._function_call_result = None
+                if frame.result:
+                    self._context.add_message(
+                        {
+                            "role": "assistant",
+                            # "content": "",  # empty content here will break nim
+                            "tool_calls": [
+                                {
+                                    "id": frame.tool_call_id,
+                                    "function": {
+                                        "name": frame.function_name,
+                                        "arguments": json.dumps(frame.arguments),
+                                    },
+                                    "type": "function",
+                                }
+                            ],
+                        }
+                    )
+                    self._context.add_message(
+                        {
+                            "role": "tool",
+                            "content": json.dumps(frame.result),
+                            "tool_call_id": frame.tool_call_id,
+                        }
+                    )
+                    # Only run the LLM if there are no more function calls in progress.
+                    run_llm = not bool(self._function_calls_in_progress)
+            else:
+                self._context.add_message({"role": "assistant", "content": aggregation})
+
+            if self._pending_image_frame_message:
+                frame = self._pending_image_frame_message
+                self._pending_image_frame_message = None
+                self._context.add_image_frame_message(
+                    format=frame.user_image_raw_frame.format,
+                    size=frame.user_image_raw_frame.size,
+                    image=frame.user_image_raw_frame.image,
+                    text=frame.text,
+                )
+                run_llm = True
+
+            if run_llm:
+                await self._user_context_aggregator.push_context_frame()
+
+            frame = OpenAILLMContextFrame(self._context)
+            await self.push_frame(frame)
+
+        except Exception as e:
+            logger.exception(f"Error processing frame: {e}")
+
+
+@dataclass
+class NimContextAggregatorPair:
+    _user: "OpenAIUserContextAggregator"
+    _assistant: "NimAssistantContextAggregator"
+
+    def user(self) -> "OpenAIUserContextAggregator":
+        return self._user
+
+    def assistant(self) -> "NimAssistantContextAggregator":
+        return self._assistant
 
 
 class NimLLMService(OpenAILLMService):
     """A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API.
@@ -95,3 +184,14 @@ async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
         # Update completion tokens count if it has increased
         if tokens.completion_tokens > self._completion_tokens:
             self._completion_tokens = tokens.completion_tokens
+
+
+    @staticmethod
+    def create_context_aggregator(
+        context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
+    ) -> NimContextAggregatorPair:
+        user = OpenAIUserContextAggregator(context)
+        assistant = NimAssistantContextAggregator(
+            user, expect_stripped_words=assistant_expect_stripped_words
+        )
+        return NimContextAggregatorPair(_user=user, _assistant=assistant)
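Usage sketch (not part of the patch): how the new create_context_aggregator()
might be wired into a pipeline, modeled on the foundational example above. The
transport, messages, and tools objects are assumptions borrowed from
examples/foundational/14j-function-calling-nim.py; only the aggregator pair
itself comes from this patch.

    import os

    from pipecat.pipeline.pipeline import Pipeline
    from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
    from pipecat.services.nim import NimLLMService

    llm = NimLLMService(
        api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.3-70b-instruct"
    )

    # messages and tools are assumed to be defined as in the example file.
    context = OpenAILLMContext(messages, tools)

    # The returned pair wraps NimAssistantContextAggregator, which writes
    # tool-call results back into the context without the empty "content"
    # field that breaks NIM.
    context_aggregator = llm.create_context_aggregator(context)

    pipeline = Pipeline(
        [
            transport.input(),               # transport is assumed (e.g. Daily)
            context_aggregator.user(),       # aggregate user input into the context
            llm,                             # NIM inference
            transport.output(),
            context_aggregator.assistant(),  # aggregate assistant/tool messages
        ]
    )

After a tool result is added, the assistant aggregator only re-runs the LLM
once no further function calls are in progress, matching the run_llm logic in
the patch above.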