use InterimTranscriptionFrame in LLMUserResponseAggregator
aconchillo committed Apr 11, 2024
1 parent e288aa0 commit eb3bf11
Showing 3 changed files with 74 additions and 23 deletions.
77 changes: 56 additions & 21 deletions src/dailyai/pipeline/aggregators.py
@@ -9,6 +9,7 @@
     EndPipeFrame,
     Frame,
     ImageFrame,
+    InterimTranscriptionFrame,
     LLMMessagesFrame,
     LLMResponseEndFrame,
     LLMResponseStartFrame,
@@ -107,8 +108,8 @@ def __init__(
         start_frame,
         end_frame,
         accumulator_frame,
+        interim_accumulator_frame=None,
         pass_through=True,
-        end_frame_threshold=0.75,
     ):
         self.aggregation = ""
         self.aggregating = False
@@ -117,42 +118,75 @@ def __init__(
         self._start_frame = start_frame
         self._end_frame = end_frame
         self._accumulator_frame = accumulator_frame
+        self._interim_accumulator_frame = interim_accumulator_frame
         self._pass_through = pass_through
-        self._end_frame_threshold = end_frame_threshold
-        self._last_end_frame_time = 0
 
+        self._seen_start_frame = False
+        self._seen_end_frame = False
+        self._seen_interim_results = False
+
+    # Use cases implemented:
+    #
+    # S: Start, E: End, T: Transcription, I: Interim, X: Text
+    #
+    #    S E -> None
+    #    S T E -> X
+    #    S I T E -> X
+    #    S I E T -> X
+    #    S I E I T -> X
+    #
+    # The following case would not be supported:
+    #
+    #    S I E T1 I T2 -> X
+    #
+    # and T2 would be dropped.
     async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
         if not self.messages:
             return
 
+        send_aggregation = False
+
         if isinstance(frame, self._start_frame):
+            self._seen_start_frame = True
             self.aggregating = True
         elif isinstance(frame, self._end_frame):
-            self.aggregating = False
-            # Sometimes VAD triggers quickly on and off. If we don't get any transcription,
-            # it creates empty LLM message queue frames
-            if len(self.aggregation) > 0:
-                self.messages.append({"role": self._role, "content": self.aggregation})
-                self.aggregation = ""
-                yield self._end_frame()
-                yield LLMMessagesFrame(self.messages)
-                self._last_end_frame_time = time.time()
+            self._seen_end_frame = True
+
+            # We might have received the end frame but we might still be
+            # aggregating (i.e. we have seen interim results but not the final
+            # text).
+            self.aggregating = self._seen_interim_results
+
+            # Send the aggregation if we are not aggregating anymore (i.e. no
+            # more interim results received).
+            send_aggregation = not self.aggregating
         elif isinstance(frame, self._accumulator_frame):
-            # Also accept transcription frames received for a short period after
-            # the last end frame was received. It might be that transcription
-            # frames are a bit delayed.
-            diff_time = time.time() - self._last_end_frame_time
             if self.aggregating:
                 self.aggregation += f" {frame.text}"
-            elif diff_time <= self._end_frame_threshold:
-                self.messages.append({"role": self._role, "content": frame.text})
-                yield self._end_frame()
-                yield LLMMessagesFrame(self.messages)
+                # We have received a complete sentence, so if we have seen the
+                # end frame and we were still aggregating, it means we should
+                # send the aggregation.
+                send_aggregation = self._seen_end_frame
 
             if self._pass_through:
                 yield frame
+
+            # We just got our final result, so let's reset interim results.
+            self._seen_interim_results = False
+        elif self._interim_accumulator_frame and isinstance(frame, self._interim_accumulator_frame):
+            self._seen_interim_results = True
         else:
             yield frame
 
+        if send_aggregation and len(self.aggregation) > 0:
+            self.messages.append({"role": self._role, "content": self.aggregation})
+            yield self._end_frame()
+            yield LLMMessagesFrame(self.messages)
+            # Reset
+            self.aggregation = ""
+            self._seen_start_frame = False
+            self._seen_end_frame = False
+            self._seen_interim_results = False
+
 
 class LLMAssistantResponseAggregator(LLMResponseAggregator):
     def __init__(self, messages: list[dict]):
@@ -173,6 +207,7 @@ def __init__(self, messages: list[dict]):
             start_frame=UserStartedSpeakingFrame,
             end_frame=UserStoppedSpeakingFrame,
             accumulator_frame=TranscriptionFrame,
+            interim_accumulator_frame=InterimTranscriptionFrame,
             pass_through=False,
         )
 
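For illustration, here is a minimal sketch (not part of this diff) of the "S I E T" case listed in the comment block above: the user stops speaking before the final transcription arrives, so the aggregation is only sent once the final TranscriptionFrame lands. The frame constructors and the sample text, participant id, and timestamp values are assumptions based on the dataclasses shown in this commit.

import asyncio

from dailyai.pipeline.aggregators import LLMUserResponseAggregator
from dailyai.pipeline.frames import (
    InterimTranscriptionFrame,
    TranscriptionFrame,
    UserStartedSpeakingFrame,
    UserStoppedSpeakingFrame,
)


async def main():
    # Aggregates user speech into the shared LLM message list with role "user".
    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    aggregator = LLMUserResponseAggregator(messages)

    # S I E T: start, interim result, stop, then the final transcription.
    frames = [
        UserStartedSpeakingFrame(),
        InterimTranscriptionFrame("Hello th", "participant-1", "1712793600"),  # sample values
        UserStoppedSpeakingFrame(),
        TranscriptionFrame("Hello there.", "participant-1", "1712793601"),
    ]
    for frame in frames:
        async for out in aggregator.process_frame(frame):
            # Nothing is emitted until the final TranscriptionFrame; then we get
            # a UserStoppedSpeakingFrame followed by an LLMMessagesFrame.
            print(out)


asyncio.run(main())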
11 changes: 11 additions & 0 deletions src/dailyai/pipeline/frames.py
@@ -164,6 +164,17 @@ def __str__(self):
         return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}"
 
 
+@dataclass()
+class InterimTranscriptionFrame(TextFrame):
+    """A text frame with interim transcription-specific data. Will be placed in
+    the transport's receive queue when a participant speaks."""
+    participantId: str
+    timestamp: str
+
+    def __str__(self):
+        return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}"
+
+
 class TTSStartFrame(ControlFrame):
     """Used to indicate the beginning of a TTS response. Following AudioFrames
     are part of the TTS response until an TTEndFrame. These frames can be used
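A quick illustration (again, not part of the diff): InterimTranscriptionFrame subclasses TextFrame and carries the same text, participantId, and timestamp fields as TranscriptionFrame, so only the frame's type distinguishes a partial result from a final one. The sample values below are made up.

from dailyai.pipeline.frames import InterimTranscriptionFrame, TextFrame, TranscriptionFrame

interim = InterimTranscriptionFrame("hello wor", "participant-1", "1712793600")
final = TranscriptionFrame("hello world", "participant-1", "1712793601")

# Both are TextFrames, so type checks against TextFrame still match either one.
print(isinstance(interim, TextFrame), isinstance(final, TextFrame))  # True True
print(interim)  # InterimTranscriptionFrame, text: 'hello wor' participantId: participant-1, ...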
9 changes: 7 additions & 2 deletions src/dailyai/transports/daily_transport.py
@@ -10,6 +10,7 @@
 from typing import Any
 
 from dailyai.pipeline.frames import (
+    InterimTranscriptionFrame,
     ReceivedAppMessageFrame,
     TranscriptionFrame,
     UserImageFrame,
@@ -368,8 +369,12 @@ def on_transcription_message(self, message: dict):
         elif "session_id" in message:
             participantId = message["session_id"]
         if self._my_participant_id and participantId != self._my_participant_id:
-            frame = TranscriptionFrame(
-                message["text"], participantId, message["timestamp"])
+            is_final = message["rawResponse"]["is_final"]
+            if is_final:
+                frame = TranscriptionFrame(message["text"], participantId, message["timestamp"])
+            else:
+                frame = InterimTranscriptionFrame(
+                    message["text"], participantId, message["timestamp"])
             asyncio.run_coroutine_threadsafe(
                 self.receive_queue.put(frame), self._loop)
 
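The routing rule above can be summarized in a small standalone sketch. The helper below is hypothetical (it does not exist in the transport), and the message shape is an assumption based only on the fields the handler reads: "text", "participantId"/"session_id", "timestamp", and rawResponse["is_final"].

from dailyai.pipeline.frames import InterimTranscriptionFrame, TranscriptionFrame


def frame_for_transcription_message(message: dict):
    # Hypothetical helper mirroring on_transcription_message's routing decision.
    participant_id = message.get("participantId") or message.get("session_id")
    if message["rawResponse"]["is_final"]:
        # Final results are what LLMUserResponseAggregator aggregates into messages.
        return TranscriptionFrame(message["text"], participant_id, message["timestamp"])
    # Partial results only signal that more text is still on the way.
    return InterimTranscriptionFrame(message["text"], participant_id, message["timestamp"])


message = {
    "session_id": "participant-1",
    "text": "hello wor",
    "timestamp": "1712793600",
    "rawResponse": {"is_final": False},
}
print(frame_for_transcription_message(message))  # an InterimTranscriptionFrame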
