Skip to content

Commit

Permalink
Merge pull request #34 from daily-co/rename-frames
Browse files Browse the repository at this point in the history
Remove Queue in frame names
  • Loading branch information
Moishe authored Mar 6, 2024
2 parents b955671 + 62fd371 commit d3e76c4
Show file tree
Hide file tree
Showing 25 changed files with 239 additions and 238 deletions.
60 changes: 30 additions & 30 deletions src/dailyai/pipeline/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from dailyai.pipeline.frame_processor import FrameProcessor

from dailyai.pipeline.frames import (
ControlQueueFrame,
EndParallelPipeQueueFrame,
EndStreamQueueFrame,
ControlFrame,
EndPipeFrame,
EndFrame,
LLMMessagesQueueFrame,
LLMResponseEndQueueFrame,
QueueFrame,
TextQueueFrame,
LLMResponseEndFrame,
Frame,
TextFrame,
TranscriptionQueueFrame,
)
from dailyai.pipeline.pipeline import Pipeline
Expand All @@ -38,10 +38,10 @@ def __init__(
self.pass_through = pass_through

async def process_frame(
self, frame: QueueFrame
) -> AsyncGenerator[QueueFrame, None]:
self, frame: Frame
) -> AsyncGenerator[Frame, None]:
# We don't do anything with non-text frames, pass it along to next in the pipeline.
if not isinstance(frame, TextQueueFrame):
if not isinstance(frame, TextFrame):
yield frame
return

Expand Down Expand Up @@ -71,7 +71,7 @@ async def process_frame(
self.messages.append({"role": self.role, "content": frame.text})
yield LLMMessagesQueueFrame(self.messages)

async def finalize(self) -> AsyncGenerator[QueueFrame, None]:
async def finalize(self) -> AsyncGenerator[Frame, None]:
# Send any dangling words that weren't finished with punctuation.
if self.complete_sentences and self.sentence:
self.messages.append({"role": self.role, "content": self.sentence})
Expand Down Expand Up @@ -106,18 +106,18 @@ def __init__(self):
self.aggregation = ""

async def process_frame(
self, frame: QueueFrame
) -> AsyncGenerator[QueueFrame, None]:
if isinstance(frame, TextQueueFrame):
self, frame: Frame
) -> AsyncGenerator[Frame, None]:
if isinstance(frame, TextFrame):
m = re.search("(.*[?.!])(.*)", frame.text)
if m:
yield TextQueueFrame(self.aggregation + m.group(1))
yield TextFrame(self.aggregation + m.group(1))
self.aggregation = m.group(2)
else:
self.aggregation += frame.text
elif isinstance(frame, EndStreamQueueFrame):
elif isinstance(frame, EndFrame):
if self.aggregation:
yield TextQueueFrame(self.aggregation)
yield TextFrame(self.aggregation)
yield frame
else:
yield frame
Expand All @@ -128,12 +128,12 @@ def __init__(self):
self.aggregation = ""

async def process_frame(
self, frame: QueueFrame
) -> AsyncGenerator[QueueFrame, None]:
if isinstance(frame, TextQueueFrame):
self, frame: Frame
) -> AsyncGenerator[Frame, None]:
if isinstance(frame, TextFrame):
self.aggregation += frame.text
elif isinstance(frame, LLMResponseEndQueueFrame):
yield TextQueueFrame(self.aggregation)
elif isinstance(frame, LLMResponseEndFrame):
yield TextFrame(self.aggregation)
self.aggregation = ""
else:
yield frame
Expand All @@ -143,20 +143,20 @@ class StatelessTextTransformer(FrameProcessor):
def __init__(self, transform_fn):
self.transform_fn = transform_fn

async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
if isinstance(frame, TextQueueFrame):
async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
if isinstance(frame, TextFrame):
result = self.transform_fn(frame.text)
if isinstance(result, Coroutine):
result = await result

yield TextQueueFrame(result)
yield TextFrame(result)
else:
yield frame

class ParallelPipeline(FrameProcessor):
def __init__(self, pipeline_definitions: List[List[FrameProcessor]]):
self.sources = [asyncio.Queue() for _ in pipeline_definitions]
self.sink: asyncio.Queue[QueueFrame] = asyncio.Queue()
self.sink: asyncio.Queue[Frame] = asyncio.Queue()
self.pipelines: list[Pipeline] = [
Pipeline(
pipeline_definition,
Expand All @@ -166,10 +166,10 @@ def __init__(self, pipeline_definitions: List[List[FrameProcessor]]):
for source, pipeline_definition in zip(self.sources, pipeline_definitions)
]

async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
for source in self.sources:
await source.put(frame)
await source.put(EndParallelPipeQueueFrame())
await source.put(EndPipeFrame())

await asyncio.gather(*[pipeline.run_pipeline() for pipeline in self.pipelines])

Expand All @@ -186,17 +186,17 @@ async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, N
seen_ids.add(id(frame))

# Skip passing along EndParallelPipeQueueFrame, because we use them for our own flow control.
if not isinstance(frame, EndParallelPipeQueueFrame):
if not isinstance(frame, EndPipeFrame):
yield frame

class GatedAggregator(FrameProcessor):
def __init__(self, gate_open_fn, gate_close_fn, start_open):
self.gate_open_fn = gate_open_fn
self.gate_close_fn = gate_close_fn
self.gate_open = start_open
self.accumulator: List[QueueFrame] = []
self.accumulator: List[Frame] = []

async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
if self.gate_open:
if self.gate_close_fn(frame):
self.gate_open = False
Expand Down
12 changes: 6 additions & 6 deletions src/dailyai/pipeline/frame_processor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import AsyncGenerator

from dailyai.pipeline.frames import ControlQueueFrame, QueueFrame
from dailyai.pipeline.frames import ControlFrame, Frame

"""
This is the base class for all frame processors. Frame processors consume a frame
Expand All @@ -20,16 +20,16 @@
class FrameProcessor:
@abstractmethod
async def process_frame(
self, frame: QueueFrame
) -> AsyncGenerator[QueueFrame, None]:
if isinstance(frame, ControlQueueFrame):
self, frame: Frame
) -> AsyncGenerator[Frame, None]:
if isinstance(frame, ControlFrame):
yield frame

@abstractmethod
async def finalize(self) -> AsyncGenerator[QueueFrame, None]:
async def finalize(self) -> AsyncGenerator[Frame, None]:
# This is a trick for the interpreter (and linter) to know that this is a generator.
if False:
yield QueueFrame()
yield Frame()

@abstractmethod
async def interrupted(self) -> None:
Expand Down
41 changes: 21 additions & 20 deletions src/dailyai/pipeline/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,72 +2,73 @@
from typing import Any


class QueueFrame:
def __eq__(self, other):
return isinstance(other, self.__class__)


class ControlQueueFrame(QueueFrame):
class Frame:
pass

class ControlFrame(Frame):
# Control frames should contain no instance data, so
# equality is based solely on the class.
def __eq__(self, other):
return type(other) == self.__class__


class StartStreamQueueFrame(ControlQueueFrame):
class StartFrame(ControlFrame):
pass


class EndStreamQueueFrame(ControlQueueFrame):
class EndFrame(ControlFrame):
pass

class EndParallelPipeQueueFrame(ControlQueueFrame):
class EndPipeFrame(ControlFrame):
pass


class LLMResponseStartQueueFrame(QueueFrame):
class LLMResponseStartFrame(ControlFrame):
pass


class LLMResponseEndQueueFrame(QueueFrame):
class LLMResponseEndFrame(ControlFrame):
pass


@dataclass()
class AudioQueueFrame(QueueFrame):
class AudioFrame(Frame):
data: bytes


@dataclass()
class ImageQueueFrame(QueueFrame):
class ImageFrame(Frame):
url: str | None
image: bytes


@dataclass()
class SpriteQueueFrame(QueueFrame):
class SpriteFrame(Frame):
images: list[bytes]


@dataclass()
class TextQueueFrame(QueueFrame):
class TextFrame(Frame):
text: str


@dataclass()
class TranscriptionQueueFrame(TextQueueFrame):
class TranscriptionQueueFrame(TextFrame):
participantId: str
timestamp: str


@dataclass()
class LLMMessagesQueueFrame(QueueFrame):
class LLMMessagesQueueFrame(Frame):
messages: list[dict[str, str]] # TODO: define this more concretely!


class AppMessageQueueFrame(QueueFrame):
class AppMessageQueueFrame(Frame):
message: Any
participantId: str

class UserStartedSpeakingFrame(QueueFrame):
class UserStartedSpeakingFrame(Frame):
pass

class UserStoppedSpeakingFrame(QueueFrame):
class UserStoppedSpeakingFrame(Frame):
pass
18 changes: 9 additions & 9 deletions src/dailyai/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import AsyncGenerator, List
from dailyai.pipeline.frame_processor import FrameProcessor

from dailyai.pipeline.frames import EndParallelPipeQueueFrame, EndStreamQueueFrame, QueueFrame
from dailyai.pipeline.frames import EndPipeFrame, EndFrame, Frame

"""
This class manages a pipe of FrameProcessors, and runs them in sequence. The "source"
Expand All @@ -17,19 +17,19 @@ def __init__(
self,
processors: List[FrameProcessor],
source: asyncio.Queue | None = None,
sink: asyncio.Queue[QueueFrame] | None = None,
sink: asyncio.Queue[Frame] | None = None,
):
self.processors = processors
self.source: asyncio.Queue[QueueFrame] | None = source
self.sink: asyncio.Queue[QueueFrame] | None = sink
self.source: asyncio.Queue[Frame] | None = source
self.sink: asyncio.Queue[Frame] | None = sink

def set_source(self, source: asyncio.Queue[QueueFrame]):
def set_source(self, source: asyncio.Queue[Frame]):
self.source = source

def set_sink(self, sink: asyncio.Queue[QueueFrame]):
def set_sink(self, sink: asyncio.Queue[Frame]):
self.sink = sink

async def get_next_source_frame(self) -> AsyncGenerator[QueueFrame, None]:
async def get_next_source_frame(self) -> AsyncGenerator[Frame, None]:
if self.source is None:
raise ValueError("Source queue not set")
yield await self.source.get()
Expand All @@ -52,9 +52,9 @@ async def run_pipeline(self):
async for frame in frame_generator:
await self.sink.put(frame)
if isinstance(
frame, EndStreamQueueFrame
frame, EndFrame
) or isinstance(
frame, EndParallelPipeQueueFrame
frame, EndPipeFrame
):
return
except asyncio.CancelledError:
Expand Down
Loading

0 comments on commit d3e76c4

Please sign in to comment.