Skip to content

Commit

Permalink
fix: starter pack live api async tool calls
Browse files Browse the repository at this point in the history
  • Loading branch information
eliasecchig committed Jan 31, 2025
1 parent f035a04 commit ddd3a61
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,11 @@
vector_store = get_vector_store(embedding=embedding, urls=URLS)
retriever = vector_store.as_retriever()


def retrieve_docs(query: str) -> Dict[str, str]:
async def retrieve_docs(query: str) -> Dict[str, str]:
"""
Retrieves pre-formatted documents about MLOps (Machine Learning Operations),
Gen AI lifecycle, and production deployment best practices.
You should always warn the user that this tool might take few seconds.
Args:
query: Search query string related to MLOps, Gen AI, or production deployment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,24 @@
import logging
from typing import Any, Callable, Dict, Literal, Optional, Union

from app.agent import MODEL_ID, genai_client, live_connect_config, tool_functions
import backoff
from fastapi import FastAPI, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from google.cloud import logging as google_cloud_logging
from google.genai import types
from google.genai.types import LiveServerToolCall

import backoff
from fastapi import FastAPI, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from websockets.exceptions import ConnectionClosedError

from app.agent import (
MODEL_ID,
genai_client,
live_connect_config,
tool_functions,
)


app = FastAPI()
app.add_middleware(
CORSMiddleware,
Expand Down Expand Up @@ -93,36 +101,42 @@ def _get_func(self, action_label: str) -> Optional[Callable]:
async def _handle_tool_call(
self, session: Any, tool_call: LiveServerToolCall
) -> None:
"""Process tool calls from Gemini and send back responses.
"""Process tool calls from Gemini and send back responses."""
# Create a task for handling the tool call
asyncio.create_task(self._process_tool_call(session, tool_call))

Args:
session: The Gemini session
tool_call: Tool call request from Gemini
"""
for fc in tool_call.function_calls:
logging.debug(f"Calling tool function: {fc.name} with args: {fc.args}")
response = self._get_func(fc.name)(**fc.args)
tool_response = types.LiveClientToolResponse(
function_responses=[
types.FunctionResponse(name=fc.name, id=fc.id, response=response)
]
)
logging.debug(f"Tool response: {tool_response}")
await session.send(tool_response)
async def _process_tool_call(
self, session: Any, tool_call: LiveServerToolCall
) -> None:
"""Process tool calls in a separate task."""
try:
for fc in tool_call.function_calls:
logging.debug(f"Calling tool function: {fc.name} with args: {fc.args}")
response = await self._get_func(fc.name)(**fc.args)
tool_response = types.LiveClientToolResponse(
function_responses=[
types.FunctionResponse(name=fc.name, id=fc.id, response=response)
]
)
logging.debug(f"Tool response: {tool_response}")
await session.send(tool_response)
except Exception as e:
logging.error(f"Error processing tool call: {str(e)}")

async def receive_from_gemini(self) -> None:
"""Listen for and process messages from Gemini.
Continuously receives messages from Gemini, forwards them to the client,
and handles any tool calls. Handles connection errors gracefully.
"""
while result := await self.session._ws.recv(decode=False):
await self.websocket.send_bytes(result)
message = types.LiveServerMessage.model_validate(json.loads(result))
if message.tool_call:
tool_call = LiveServerToolCall.model_validate(message.tool_call)
await self._handle_tool_call(self.session, tool_call)

"""Listen for and process messages from Gemini."""
try:
while result := await self.session._ws.recv(decode=False):
# Send the message to the client immediately
await self.websocket.send_bytes(result)

# Process any tool calls asynchronously
message = types.LiveServerMessage.model_validate(json.loads(result))
if message.tool_call:
tool_call = LiveServerToolCall.model_validate(message.tool_call)
await self._handle_tool_call(self.session, tool_call)
except Exception as e:
logging.error(f"Error receiving from Gemini: {str(e)}")

def get_connect_and_run_callable(websocket: WebSocket) -> Callable:
"""Create a callable that handles Gemini connection with retry logic.
Expand Down

0 comments on commit ddd3a61

Please sign in to comment.