Skip to content

Commit

Permalink
Merge pull request #117 from daily-co/llm-use-aggregator-pass-through…
Browse files Browse the repository at this point in the history
…-fix

aggregators: fix LLMUserResponseAggregator pass-through
  • Loading branch information
aconchillo authored Apr 11, 2024
2 parents db05a9b + 7336866 commit 1e83a40
Show file tree
Hide file tree
Showing 14 changed files with 92 additions and 51 deletions.
14 changes: 5 additions & 9 deletions examples/foundational/06-listen-and-respond.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from dailyai.services.open_ai_services import OpenAILLMService
from dailyai.services.ai_services import FrameLogger
from dailyai.pipeline.aggregators import (
LLMAssistantContextAggregator,
LLMUserContextAggregator,
LLMAssistantResponseAggregator,
LLMUserResponseAggregator,
)
from runner import configure

Expand Down Expand Up @@ -55,11 +55,9 @@ async def main(room_url: str, token):
},
]

tma_in = LLMUserContextAggregator(
messages, transport._my_participant_id)
tma_out = LLMAssistantContextAggregator(
messages, transport._my_participant_id
)
tma_in = LLMUserResponseAggregator(messages)
tma_out = LLMAssistantResponseAggregator(messages)

pipeline = Pipeline(
processors=[
fl,
Expand All @@ -78,8 +76,6 @@ async def on_first_other_participant_joined(transport, participant):
{"role": "system", "content": "Please introduce yourself to the user."})
await pipeline.queue_frames([LLMMessagesFrame(messages)])

transport.transcription_settings["extra"]["endpointing"] = True
transport.transcription_settings["extra"]["punctuate"] = True
await transport.run(pipeline)


Expand Down
11 changes: 5 additions & 6 deletions examples/foundational/06a-image-sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,12 @@ async def main(room_url: str, token):
token,
"Respond bot",
5,
camera_enabled=True,
camera_width=1024,
camera_height=1024,
mic_enabled=True,
mic_sample_rate=16000,
)
transport._camera_enabled = True
transport._camera_width = 1024
transport._camera_height = 1024
transport._mic_enabled = True
transport._mic_sample_rate = 16000
transport.transcription_settings["extra"]["punctuate"] = True

tts = ElevenLabsTTSService(
aiohttp_session=session,
Expand Down
1 change: 0 additions & 1 deletion examples/foundational/07-interruptible.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ async def run_conversation():
pre_processor=LLMUserResponseAggregator(messages),
)

transport.transcription_settings["extra"]["punctuate"] = False
await asyncio.gather(transport.run(), run_conversation())


Expand Down
6 changes: 0 additions & 6 deletions examples/foundational/10-wake-word.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,6 @@ async def main(room_url: str, token):
camera_width=720,
camera_height=1280,
)
transport._mic_enabled = True
transport._mic_sample_rate = 16000
transport._camera_enabled = True
transport._camera_width = 720
transport._camera_height = 1280
transport.transcription_settings["extra"]["punctuate"] = True

llm = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
Expand Down
1 change: 0 additions & 1 deletion examples/foundational/11-sound-effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ async def main(room_url: str, token):
mic_sample_rate=16000,
camera_enabled=False,
)
transport.transcription_settings["extra"]["punctuate"] = True

llm = OpenAILLMService(
api_key=os.getenv("OPENAI_API_KEY"),
Expand Down
2 changes: 0 additions & 2 deletions examples/image-gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ async def on_participant_joined(transport, participant):
async for audio in audio_generator:
transport.output_queue.put(Frame(FrameType.AUDIO_FRAME, audio))

transport.transcription_settings["extra"]["punctuate"] = False
transport.transcription_settings["extra"]["endpointing"] = False
await asyncio.gather(transport.run(), handle_transcriptions())


Expand Down
2 changes: 0 additions & 2 deletions examples/internal/11a-dial-out.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,6 @@ async def on_call_state_updated(transport, state):
transport.start_recording()
transport.dialout(phone)

transport.transcription_settings["extra"]["punctuate"] = True

await asyncio.gather(transport.run(), handle_transcriptions())


Expand Down
2 changes: 0 additions & 2 deletions examples/starter-apps/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,6 @@ async def run_conversation():
pre_processor=LLMUserResponseAggregator(messages),
)

transport.transcription_settings["extra"]["endpointing"] = True
transport.transcription_settings["extra"]["punctuate"] = True
await asyncio.gather(transport.run(), run_conversation())


Expand Down
2 changes: 0 additions & 2 deletions examples/starter-apps/patient-intake.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,6 @@ async def handle_intake():
pre_processor=OpenAIUserContextAggregator(context),
)

transport.transcription_settings["extra"]["endpointing"] = True
transport.transcription_settings["extra"]["punctuate"] = True
try:
await asyncio.gather(transport.run(), handle_intake())
except (asyncio.CancelledError, KeyboardInterrupt):
Expand Down
2 changes: 0 additions & 2 deletions examples/starter-apps/storybot.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,6 @@ async def storytime():
pipeline,
)

transport.transcription_settings["extra"]["endpointing"] = True
transport.transcription_settings["extra"]["punctuate"] = True
try:
await asyncio.gather(transport.run(), storytime())
except (asyncio.CancelledError, KeyboardInterrupt):
Expand Down
2 changes: 0 additions & 2 deletions examples/starter-apps/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@ async def main(room_url: str, token):
ts = TranslationSubtitles("spanish")
pipeline = Pipeline([sa, tp, llm, lfra, ts, tts])

transport.transcription_settings["extra"]["endpointing"] = True
transport.transcription_settings["extra"]["punctuate"] = True
await transport.run(pipeline)


Expand Down
72 changes: 60 additions & 12 deletions src/dailyai/pipeline/aggregators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import re
import time

from dailyai.pipeline.frame_processor import FrameProcessor

Expand All @@ -8,6 +9,7 @@
EndPipeFrame,
Frame,
ImageFrame,
InterimTranscriptionFrame,
LLMMessagesFrame,
LLMResponseEndFrame,
LLMResponseStartFrame,
Expand Down Expand Up @@ -106,6 +108,7 @@ def __init__(
start_frame,
end_frame,
accumulator_frame,
interim_accumulator_frame=None,
pass_through=True,
):
self.aggregation = ""
Expand All @@ -115,31 +118,75 @@ def __init__(
self._start_frame = start_frame
self._end_frame = end_frame
self._accumulator_frame = accumulator_frame
self._interim_accumulator_frame = interim_accumulator_frame
self._pass_through = pass_through

self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False

# Use cases implemented:
#
# S: Start, E: End, T: Transcription, I: Interim, X: Text
#
# S E -> None
# S T E -> X
# S I T E -> X
# S I E T -> X
# S I E I T -> X
#
# The following case would not be supported:
#
# S I E T1 I T2 -> X
#
# and T2 would be dropped.
async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
if not self.messages:
return

send_aggregation = False

if isinstance(frame, self._start_frame):
self._seen_start_frame = True
self.aggregating = True
elif isinstance(frame, self._end_frame):
self.aggregating = False
# Sometimes VAD triggers quickly on and off. If we don't get any transcription,
# it creates empty LLM message queue frames
if len(self.aggregation) > 0:
self.messages.append(
{"role": self._role, "content": self.aggregation})
self.aggregation = ""
yield self._end_frame()
yield LLMMessagesFrame(self.messages)
elif isinstance(frame, self._accumulator_frame) and self.aggregating:
self.aggregation += f" {frame.text}"
self._seen_end_frame = True

# We might have received the end frame but we might still be
# aggregating (i.e. we have seen interim results but not the final
# text).
self.aggregating = self._seen_interim_results

# Send the aggregation if we are not aggregating anymore (i.e. no
# more interim results received).
send_aggregation = not self.aggregating
elif isinstance(frame, self._accumulator_frame):
if self.aggregating:
self.aggregation += f" {frame.text}"
# We have received a complete sentence, so if we have seen the
# end frame and we were still aggregating, it means we should
# send the aggregation.
send_aggregation = self._seen_end_frame

if self._pass_through:
yield frame

# We just got our final result, so let's reset interim results.
self._seen_interim_results = False
elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
self._seen_interim_results = True
else:
yield frame

if send_aggregation and len(self.aggregation) > 0:
self.messages.append({"role": self._role, "content": self.aggregation})
yield self._end_frame()
yield LLMMessagesFrame(self.messages)
# Reset
self.aggregation = ""
self._seen_start_frame = False
self._seen_end_frame = False
self._seen_interim_results = False


class LLMAssistantResponseAggregator(LLMResponseAggregator):
def __init__(self, messages: list[dict]):
Expand All @@ -160,6 +207,7 @@ def __init__(self, messages: list[dict]):
start_frame=UserStartedSpeakingFrame,
end_frame=UserStoppedSpeakingFrame,
accumulator_frame=TranscriptionFrame,
interim_accumulator_frame=InterimTranscriptionFrame,
pass_through=False,
)

Expand Down
11 changes: 11 additions & 0 deletions src/dailyai/pipeline/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,17 @@ def __str__(self):
return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}"


@dataclass()
class InterimTranscriptionFrame(TextFrame):
    """A text frame with interim transcription-specific data. Will be placed in
    the transport's receive queue when a participant speaks."""
    # Carries a non-final transcription result: the transport emits this frame
    # when the transcription service reports is_final == False, and a regular
    # TranscriptionFrame once the final text for the utterance arrives.
    participantId: str  # session/participant id of the speaker (not the local participant)
    timestamp: str  # timestamp string as provided by the transcription message

    def __str__(self):
        return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}"


class TTSStartFrame(ControlFrame):
"""Used to indicate the beginning of a TTS response. Following AudioFrames
are part of the TTS response until an TTEndFrame. These frames can be used
Expand Down
15 changes: 11 additions & 4 deletions src/dailyai/transports/daily_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Any

from dailyai.pipeline.frames import (
InterimTranscriptionFrame,
ReceivedAppMessageFrame,
TranscriptionFrame,
UserImageFrame,
Expand Down Expand Up @@ -88,9 +89,11 @@ def __init__(
"model": "2-conversationalai",
"profanity_filter": True,
"redact": False,
"endpointing": True,
"punctuate": True,
"includeRawResponse": True,
"extra": {
"endpointing": True,
"punctuate": False,
"interim_results": True,
},
}

Expand Down Expand Up @@ -368,8 +371,12 @@ def on_transcription_message(self, message: dict):
elif "session_id" in message:
participantId = message["session_id"]
if self._my_participant_id and participantId != self._my_participant_id:
frame = TranscriptionFrame(
message["text"], participantId, message["timestamp"])
is_final = message["rawResponse"]["is_final"]
if is_final:
frame = TranscriptionFrame(message["text"], participantId, message["timestamp"])
else:
frame = InterimTranscriptionFrame(
message["text"], participantId, message["timestamp"])
asyncio.run_coroutine_threadsafe(
self.receive_queue.put(frame), self._loop)

Expand Down

0 comments on commit 1e83a40

Please sign in to comment.