Skip to content

Commit

Permalink
pushing context frames from assistant aggregators
Browse files Browse the repository at this point in the history
  • Loading branch information
kwindla committed Sep 29, 2024
1 parent d9b16d4 commit 6e4c47c
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 60 deletions.
49 changes: 26 additions & 23 deletions src/pipecat/services/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,47 @@
#

import base64
import json
import io
import copy
from typing import Any, Dict, List, Optional
import io
import json
import re
from asyncio import CancelledError
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from loguru import logger
from PIL import Image
from asyncio import CancelledError
import re
from pydantic import BaseModel, Field

from pipecat.frames.frames import (
Frame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
LLMEnablePromptCachingFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMModelUpdateFrame,
StartInterruptionFrame,
TextFrame,
VisionImageRawFrame,
UserImageRequestFrame,
UserImageRawFrame,
LLMMessagesFrame,
LLMFullResponseStartFrame,
LLMFullResponseEndFrame,
FunctionCallResultFrame,
FunctionCallInProgressFrame,
StartInterruptionFrame,
UserImageRequestFrame,
VisionImageRawFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService
from pipecat.processors.aggregators.llm_response import (
LLMAssistantContextAggregator,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.aggregators.llm_response import (
LLMUserContextAggregator,
LLMAssistantContextAggregator,
)

from loguru import logger
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService

try:
from anthropic import AsyncAnthropic, NOT_GIVEN, NotGiven
from anthropic import NOT_GIVEN, AsyncAnthropic, NotGiven
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
Expand Down Expand Up @@ -565,7 +565,7 @@ async def _push_aggregation(self):
run_llm = False

aggregation = self._aggregation
self._aggregation = ""
self._reset()

try:
if self._function_call_result:
Expand Down Expand Up @@ -616,5 +616,8 @@ async def _push_aggregation(self):
if run_llm:
await self._user_context_aggregator.push_context_frame()

frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)

except Exception as e:
logger.error(f"Error processing frame: {e}")
38 changes: 22 additions & 16 deletions src/pipecat/services/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,39 @@
# SPDX-License-Identifier: BSD 2-Clause License
#

import aiohttp
import base64
import io
import json
import httpx

from dataclasses import dataclass

from typing import Any, AsyncGenerator, Dict, List, Literal, Optional

import aiohttp
import httpx
from loguru import logger
from PIL import Image
from pydantic import BaseModel, Field

from pipecat.frames.frames import (
ErrorFrame,
Frame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMModelUpdateFrame,
StartInterruptionFrame,
TextFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
TextFrame,
URLImageRawFrame,
VisionImageRawFrame,
FunctionCallResultFrame,
FunctionCallInProgressFrame,
StartInterruptionFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.llm_response import (
LLMUserContextAggregator,
LLMAssistantContextAggregator,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
Expand All @@ -44,12 +45,14 @@
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import ImageGenService, LLMService, TTSService

from PIL import Image

from loguru import logger

try:
from openai import AsyncOpenAI, AsyncStream, DefaultAsyncHttpxClient, BadRequestError, NOT_GIVEN
from openai import (
NOT_GIVEN,
AsyncOpenAI,
AsyncStream,
BadRequestError,
DefaultAsyncHttpxClient,
)
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
Expand Down Expand Up @@ -464,7 +467,7 @@ async def process_frame(self, frame, direction):
await self._push_aggregation()
else:
logger.warning(
f"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
)
self._function_call_in_progress = None
self._function_call_result = None
Expand All @@ -476,7 +479,7 @@ async def _push_aggregation(self):
run_llm = False

aggregation = self._aggregation
self._aggregation = ""
self._reset()

try:
if self._function_call_result:
Expand Down Expand Up @@ -512,5 +515,8 @@ async def _push_aggregation(self):
if run_llm:
await self._user_context_aggregator.push_context_frame()

frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)

except Exception as e:
logger.error(f"Error processing frame: {e}")
44 changes: 23 additions & 21 deletions src/pipecat/services/together.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,36 @@
import json
import re
import uuid
from pydantic import BaseModel, Field

from typing import Any, Dict, List, Optional
from dataclasses import dataclass
from asyncio import CancelledError
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from loguru import logger
from pydantic import BaseModel, Field

from pipecat.frames.frames import (
Frame,
FunctionCallInProgressFrame,
FunctionCallResultFrame,
LLMFullResponseEndFrame,
LLMFullResponseStartFrame,
LLMMessagesFrame,
LLMModelUpdateFrame,
StartInterruptionFrame,
TextFrame,
UserImageRequestFrame,
LLMMessagesFrame,
LLMFullResponseStartFrame,
LLMFullResponseEndFrame,
FunctionCallResultFrame,
FunctionCallInProgressFrame,
StartInterruptionFrame,
)
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService
from pipecat.processors.aggregators.llm_response import (
LLMAssistantContextAggregator,
LLMUserContextAggregator,
)
from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContext,
OpenAILLMContextFrame,
)
from pipecat.processors.aggregators.llm_response import (
LLMUserContextAggregator,
LLMAssistantContextAggregator,
)

from loguru import logger
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService

try:
from together import AsyncTogether
Expand Down Expand Up @@ -188,7 +187,7 @@ async def _process_context(self, context: OpenAILLMContext):
if chunk.choices[0].finish_reason == "eos" and accumulating_function_call:
await self._extract_function_call(context, function_call_accumulator)

except CancelledError as e:
except CancelledError:
# todo: implement token counting estimates for use when the user interrupts a long generation
# we do this in the anthropic.py service
raise
Expand Down Expand Up @@ -338,7 +337,7 @@ async def process_frame(self, frame, direction):
await self._push_aggregation()
else:
logger.warning(
f"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
)
self._function_call_in_progress = None
self._function_call_result = None
Expand All @@ -353,7 +352,7 @@ async def _push_aggregation(self):
run_llm = False

aggregation = self._aggregation
self._aggregation = ""
self._reset()

try:
if self._function_call_result:
Expand All @@ -373,5 +372,8 @@ async def _push_aggregation(self):
if run_llm:
await self._user_context_aggregator.push_messages_frame()

frame = OpenAILLMContextFrame(self._context)
await self.push_frame(frame)

except Exception as e:
logger.error(f"Error processing frame: {e}")

0 comments on commit 6e4c47c

Please sign in to comment.