llm user and assistant aggregator renames
aconchillo committed Apr 9, 2024
1 parent d7b2e67 commit 97b923e
Showing 4 changed files with 21 additions and 21 deletions.
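
In short, the conversation aggregators gain an explicit LLM prefix and a user/assistant role, and the underlying base classes are renamed to match (see src/dailyai/pipeline/aggregators.py below). A minimal import sketch of the rename map, based only on the diffs in this commit:

# Rename map introduced by this commit (from the hunks below):
#   BasicResponseAggregator      -> ResponseAggregator
#   UserTranscriptionAggregator  -> UserResponseAggregator
#   ResponseAggregator           -> LLMResponseAggregator
#   LLMResponseAggregator        -> LLMAssistantResponseAggregator
#   UserResponseAggregator       -> LLMUserResponseAggregator
from dailyai.pipeline.aggregators import (
    LLMAssistantResponseAggregator,  # formerly LLMResponseAggregator
    LLMUserResponseAggregator,       # formerly UserResponseAggregator
)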
8 changes: 4 additions & 4 deletions examples/foundational/07-interruptible.py
@@ -3,8 +3,8 @@
 import logging
 import os
 from dailyai.pipeline.aggregators import (
-    LLMResponseAggregator,
-    UserResponseAggregator,
+    LLMAssistantResponseAggregator,
+    LLMUserResponseAggregator,
 )

 from dailyai.pipeline.pipeline import Pipeline
@@ -63,8 +63,8 @@ async def run_conversation():

         await transport.run_interruptible_pipeline(
             pipeline,
-            post_processor=LLMResponseAggregator(messages),
-            pre_processor=UserResponseAggregator(messages),
+            post_processor=LLMAssistantResponseAggregator(messages),
+            pre_processor=LLMUserResponseAggregator(messages),
         )

         transport.transcription_settings["extra"]["punctuate"] = False
8 changes: 4 additions & 4 deletions examples/starter-apps/chatbot.py
@@ -6,8 +6,8 @@
 from typing import AsyncGenerator

 from dailyai.pipeline.aggregators import (
-    LLMResponseAggregator,
-    UserResponseAggregator,
+    LLMAssistantResponseAggregator,
+    LLMUserResponseAggregator,
 )
 from dailyai.pipeline.frames import (
     ImageFrame,
@@ -135,8 +135,8 @@ async def run_conversation():

         await transport.run_interruptible_pipeline(
             pipeline,
-            post_processor=LLMResponseAggregator(messages),
-            pre_processor=UserResponseAggregator(messages),
+            post_processor=LLMAssistantResponseAggregator(messages),
+            pre_processor=LLMUserResponseAggregator(messages),
         )

         transport.transcription_settings["extra"]["endpointing"] = True
8 changes: 4 additions & 4 deletions examples/starter-apps/storybot.py
@@ -19,8 +19,8 @@
 from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
 from dailyai.pipeline.aggregators import (
     LLMAssistantContextAggregator,
-    UserResponseAggregator,
-    LLMResponseAggregator,
+    LLMAssistantResponseAggregator,
+    LLMUserResponseAggregator,
 )
 from dailyai.pipeline.frames import (
     EndPipeFrame,
@@ -209,8 +209,8 @@ async def main(room_url: str, token):
             key_id=os.getenv("FAL_KEY_ID"),
             key_secret=os.getenv("FAL_KEY_SECRET"),
         )
-        lra = LLMResponseAggregator(messages)
-        ura = UserResponseAggregator(messages)
+        lra = LLMAssistantResponseAggregator(messages)
+        ura = LLMUserResponseAggregator(messages)
         sp = StoryProcessor(messages, story)
         sig = StoryImageGenerator(story, llm, img)

18 changes: 9 additions & 9 deletions src/dailyai/pipeline/aggregators.py
@@ -21,7 +21,7 @@
 from typing import AsyncGenerator, Coroutine, List


-class BasicResponseAggregator(FrameProcessor):
+class ResponseAggregator(FrameProcessor):
     """This frame processor aggregates frames between a start and an end frame
     into complete text frame sentences.
@@ -37,10 +37,10 @@ class BasicResponseAggregator(FrameProcessor):
     ...         if isinstance(frame, TextFrame):
     ...             print(frame.text)
-    >>> aggregator = BasicResponseAggregator(start_frame = UserStartedSpeakingFrame,
-    ...                                      end_frame=UserStoppedSpeakingFrame,
-    ...                                      accumulator_frame=TranscriptionFrame,
-    ...                                      pass_through=False)
+    >>> aggregator = ResponseAggregator(start_frame = UserStartedSpeakingFrame,
+    ...                                 end_frame=UserStoppedSpeakingFrame,
+    ...                                 accumulator_frame=TranscriptionFrame,
+    ...                                 pass_through=False)
     >>> asyncio.run(print_frames(aggregator, UserStartedSpeakingFrame()))
     >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("Hello,", 1, 1)))
     >>> asyncio.run(print_frames(aggregator, TranscriptionFrame("world.", 1, 2)))
@@ -84,7 +84,7 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
             yield frame


-class UserTranscriptionAggregator(BasicResponseAggregator):
+class UserResponseAggregator(ResponseAggregator):
     def __init__(self):
         super().__init__(
             start_frame=UserStartedSpeakingFrame,
@@ -94,7 +94,7 @@ def __init__(self):
         )


-class ResponseAggregator(FrameProcessor):
+class LLMResponseAggregator(FrameProcessor):

     def __init__(
         self,
@@ -139,7 +139,7 @@ async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
             yield frame


-class LLMResponseAggregator(ResponseAggregator):
+class LLMAssistantResponseAggregator(LLMResponseAggregator):
     def __init__(self, messages: list[dict]):
         super().__init__(
             messages=messages,
@@ -150,7 +150,7 @@ def __init__(self, messages: list[dict]):
         )


-class UserResponseAggregator(ResponseAggregator):
+class LLMUserResponseAggregator(LLMResponseAggregator):
     def __init__(self, messages: list[dict]):
         super().__init__(
             messages=messages,
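For downstream code the change is a pure rename: both aggregators are still constructed around one shared messages list and wired in as the pre- and post-processors of an interruptible pipeline. A hedged usage sketch modelled on the 07-interruptible.py hunk above; the wrapper function name and the system prompt are illustrative, and transport and pipeline are assumed to be set up as in that example:

from dailyai.pipeline.aggregators import (
    LLMAssistantResponseAggregator,
    LLMUserResponseAggregator,
)

async def run_pipeline(transport, pipeline):
    # One shared context: the user aggregator appends what the user said,
    # the assistant aggregator appends what the LLM answered.
    messages = [{"role": "system", "content": "You are a helpful assistant."}]  # illustrative prompt

    await transport.run_interruptible_pipeline(
        pipeline,
        pre_processor=LLMUserResponseAggregator(messages),
        post_processor=LLMAssistantResponseAggregator(messages),
    )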
