Websocket transport
Moishe committed Mar 25, 2024
1 parent 2c5628a commit 2bda4c3
Showing 19 changed files with 669 additions and 21 deletions.
18 changes: 15 additions & 3 deletions .github/workflows/lint.yaml
@@ -22,11 +22,23 @@ jobs:
steps:
- name: Checkout repo
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Setup virtual environment
run: |
python -m venv .venv
- name: Install basic Python dependencies
run: |
source .venv/bin/activate
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: autopep8
id: autopep8
uses: peter-evans/autopep8@v2
with:
args: --exit-code -r -d -a -a src/
run: |
source .venv/bin/activate
autopep8 --max-line-length 100 --exit-code -r -d --exclude "*_pb2.py" -a -a src/
- name: Fail if autopep8 requires changes
if: steps.autopep8.outputs.exit-code == 2
run: exit 1
5 changes: 3 additions & 2 deletions README.md
@@ -127,7 +127,7 @@ You can use [use-package](https://github.com/jwiegley/use-package) to install [p
:defer t
:hook ((python-mode . py-autopep8-mode))
:config
(setq py-autopep8-options '("-a" "-a")))
(setq py-autopep8-options '("-a" "-a" "--max-line-length=100")))
```

`autopep8` was installed in the `venv` environment described before, so you should be able to use [pyvenv-auto](https://github.com/ryotaro612/pyvenv-auto) to automatically load that environment inside Emacs.
@@ -152,6 +152,7 @@ Install the
},
"autopep8.args": [
"-a",
"-a"
"-a",
"--max-line-length=100"
],
```
3 changes: 2 additions & 1 deletion examples/foundational/10-wake-word.py
@@ -172,7 +172,8 @@ async def handle_transcriptions():
isa.run(
tma_out.run(
llm.run(
tma_in.run(ncf.run(tf.run(transport.get_receive_frames())))
tma_in.run(
ncf.run(tf.run(transport.get_receive_frames())))
)
)
),
25 changes: 25 additions & 0 deletions examples/foundational/websocket-server/frames.proto
@@ -0,0 +1,25 @@
syntax = "proto3";

package dailyai_proto;

message TextFrame {
string text = 1;
}

message AudioFrame {
bytes audio = 1;
}

message TranscriptionFrame {
string text = 1;
string participant_id = 2;
string timestamp = 3;
}

message Frame {
oneof frame {
TextFrame text = 1;
AudioFrame audio = 2;
TranscriptionFrame transcription = 3;
}
}
134 changes: 134 additions & 0 deletions examples/foundational/websocket-server/index.html
@@ -0,0 +1,134 @@
<!DOCTYPE html>
<html lang="en">

<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="//cdn.jsdelivr.net/npm/protobufjs/dist/protobuf.min.js"></script>
<title>WebSocket Audio Stream</title>
</head>

<body>
<h1>WebSocket Audio Stream</h1>
<button id="startAudioBtn">Start Audio</button>
<button id="stopAudioBtn">Stop Audio</button>
<script>
const SAMPLE_RATE = 16000;
const BUFFER_SIZE = 8192;
const MIN_AUDIO_SIZE = 6400;

let audioContext;
let microphoneStream;
let scriptProcessor;
let source;
let frame;
let audioChunks = [];
let isPlaying = false;
let ws;

const proto = protobuf.load("frames.proto", (err, root) => {
if (err) throw err;
frame = root.lookupType("dailyai_proto.Frame");
});

function initWebSocket() {
ws = new WebSocket('ws://localhost:8765');

ws.addEventListener('open', () => console.log('WebSocket connection established.'));
ws.addEventListener('message', handleWebSocketMessage);
ws.addEventListener('close', (event) => console.log("WebSocket connection closed.", event.code, event.reason));
ws.addEventListener('error', (event) => console.error('WebSocket error:', event));
}

async function handleWebSocketMessage(event) {
const arrayBuffer = await event.data.arrayBuffer();
enqueueAudioFromProto(arrayBuffer);
}

function enqueueAudioFromProto(arrayBuffer) {
const parsedFrame = frame.decode(new Uint8Array(arrayBuffer));
if (!parsedFrame?.audio) return false;

const frameCount = parsedFrame.audio.data.length / 2;
const audioOutBuffer = audioContext.createBuffer(1, frameCount, SAMPLE_RATE);
const nowBuffering = audioOutBuffer.getChannelData(0);
const view = new Int16Array(parsedFrame.audio.data.buffer);

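// Convert each signed 16-bit PCM sample to a float in [-1, 1) for the Web Audio buffer.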
for (let i = 0; i < frameCount; i++) {
const word = view[i];
nowBuffering[i] = ((word + 32768) % 65536 - 32768) / 32768.0;
}

audioChunks.push(audioOutBuffer);
if (!isPlaying) playNextChunk();
}

function playNextChunk() {
if (audioChunks.length === 0) {
isPlaying = false;
return;
}

isPlaying = true;
const audioOutBuffer = audioChunks.shift();
const source = audioContext.createBufferSource();
source.buffer = audioOutBuffer;
source.connect(audioContext.destination);
source.onended = playNextChunk;
source.start();
}

function startAudio() {
if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
alert('getUserMedia is not supported in your browser.');
return;
}

navigator.mediaDevices.getUserMedia({ audio: true })
.then((stream) => {
microphoneStream = stream;
audioContext = new (window.AudioContext || window.webkitAudioContext)();
scriptProcessor = audioContext.createScriptProcessor(BUFFER_SIZE, 1, 1);
source = audioContext.createMediaStreamSource(stream);
source.connect(scriptProcessor);
scriptProcessor.connect(audioContext.destination);

const audioBuffer = [];
const skipRatio = Math.floor(audioContext.sampleRate / (SAMPLE_RATE * 2));

scriptProcessor.onaudioprocess = (event) => {
const rawLeftChannelData = event.inputBuffer.getChannelData(0);
for (let i = 0; i < rawLeftChannelData.length; i += skipRatio) {
const normalized = ((rawLeftChannelData[i] * 32768.0) + 32768) % 65536 - 32768;
const swappedBytes = ((normalized & 0xff) << 8) | ((normalized >> 8) & 0xff);
audioBuffer.push(swappedBytes);
}

if (audioBuffer.length >= MIN_AUDIO_SIZE) {
const audioFrame = frame.create({ audio: { audio: audioBuffer.slice(0, MIN_AUDIO_SIZE) } });
const encodedFrame = new Uint8Array(frame.encode(audioFrame).finish());
ws.send(encodedFrame);
audioBuffer.splice(0, MIN_AUDIO_SIZE);
}
};

initWebSocket();
})
.catch((error) => console.error('Error accessing microphone:', error));
}

function stopAudio() {
if (ws) {
ws.close();
scriptProcessor.disconnect();
source.disconnect();
ws = undefined;
}
}

document.getElementById('startAudioBtn').addEventListener('click', startAudio);
document.getElementById('stopAudioBtn').addEventListener('click', stopAudio);
</script>
</body>

</html>
50 changes: 50 additions & 0 deletions examples/foundational/websocket-server/sample.py
@@ -0,0 +1,50 @@
import asyncio
import aiohttp
import logging
import os
from dailyai.pipeline.frame_processor import FrameProcessor
from dailyai.pipeline.frames import TextFrame, TranscriptionQueueFrame
from dailyai.pipeline.pipeline import Pipeline
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
from dailyai.services.websocket_transport_service import WebsocketTransport
from dailyai.services.whisper_ai_services import WhisperSTTService

logging.basicConfig(format="%(levelno)s %(asctime)s %(message)s")
logger = logging.getLogger("dailyai")
logger.setLevel(logging.DEBUG)


class WhisperTranscriber(FrameProcessor):
async def process_frame(self, frame):
if isinstance(frame, TranscriptionQueueFrame):
print(f"Transcribed: {frame.text}")
else:
yield frame


async def main():
async with aiohttp.ClientSession() as session:
transport = WebsocketTransport(
mic_enabled=True,
speaker_enabled=True,
)
tts = ElevenLabsTTSService(
aiohttp_session=session,
api_key=os.getenv("ELEVENLABS_API_KEY"),
voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
)

pipeline = Pipeline([
WhisperSTTService(),
WhisperTranscriber(),
tts,
])

@transport.on_connection
async def queue_frame():
await pipeline.queue_frames([TextFrame("Hello there!")])

await transport.run(pipeline)

if __name__ == "__main__":
asyncio.run(main())
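Presumably this sample is meant to be exercised together with the index.html client above, which connects to ws://localhost:8765; if the WebsocketTransport listens on a different host or port, the URL in initWebSocket() would need to be adjusted to match.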
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -35,7 +35,8 @@ dependencies = [
"torch",
"torchaudio",
"pyaudio",
"typing-extensions"
"typing-extensions",
"websockets"
]

[project.urls]
Expand Down
2 changes: 0 additions & 2 deletions src/dailyai/pipeline/frame_processor.py
@@ -23,8 +23,6 @@ async def process_frame(
self, frame: Frame
) -> AsyncGenerator[Frame, None]:
"""Process a single frame and yield 0 or more frames."""
if isinstance(frame, ControlFrame):
yield frame
yield frame

@abstractmethod
Expand Down
25 changes: 25 additions & 0 deletions src/dailyai/pipeline/frames.proto
@@ -0,0 +1,25 @@
syntax = "proto3";

package dailyai_proto;

message TextFrame {
string text = 1;
}

message AudioFrame {
bytes data = 1;
}

message TranscriptionFrame {
string text = 1;
string participantId = 2;
string timestamp = 3;
}

message Frame {
oneof frame {
TextFrame text = 1;
AudioFrame audio = 2;
TranscriptionFrame transcription = 3;
}
}
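Since frames.py (below) imports the generated module as dailyai.pipeline.protobufs.frames_pb2, here is a minimal sketch of how these messages might be serialized and parsed in Python. The protoc invocation and the exact framing used by the websocket transport are assumptions, not shown in this diff.

```python
# Assumes bindings were generated along the lines of:
#   protoc --python_out=src/dailyai/pipeline/protobufs frames.proto
import dailyai.pipeline.protobufs.frames_pb2 as frame_protos

# Wrap a TextFrame in the top-level Frame envelope and serialize it to bytes.
proto_frame = frame_protos.Frame()
proto_frame.text.text = "Hello there!"
payload = proto_frame.SerializeToString()

# On the receiving side, parse the bytes and dispatch on the oneof field.
received = frame_protos.Frame()
received.ParseFromString(payload)
kind = received.WhichOneof("frame")  # "text", "audio", or "transcription"
if kind == "text":
    print(received.text.text)
```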
17 changes: 17 additions & 0 deletions src/dailyai/pipeline/frames.py
@@ -2,6 +2,7 @@
from typing import Any, List

from dailyai.services.openai_llm_context import OpenAILLMContext
import dailyai.pipeline.protobufs.frames_pb2 as frame_protos


class Frame:
@@ -107,6 +108,22 @@ class TranscriptionQueueFrame(TextFrame):
participantId: str
timestamp: str

def __str__(self):
return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}"


class TTSStartFrame(ControlFrame):
"""Used to indicate the beginning of a TTS response. Following AudioFrames
are part of the TTS response until a TTSEndFrame. These frames can be used
for aggregating audio frames in a transport to optimize the size of frames
sent to the session, without needing to control this in the TTS service."""
pass


class TTSEndFrame(ControlFrame):
"""Indicates the end of a TTS response."""
pass


@dataclass()
class LLMMessagesQueueFrame(Frame):
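To illustrate the intent of these control frames, here is a minimal sketch (not part of this commit) of how a transport might coalesce the audio between the start and end markers; it assumes the existing AudioFrame carries its PCM bytes in a data attribute.

```python
from dailyai.pipeline.frames import AudioFrame, TTSEndFrame, TTSStartFrame


async def aggregate_tts_audio(frames):
    """Coalesce AudioFrames between TTSStartFrame and TTSEndFrame into one frame."""
    buffer = bytearray()
    aggregating = False
    async for frame in frames:
        if isinstance(frame, TTSStartFrame):
            aggregating = True
            buffer.clear()
        elif isinstance(frame, TTSEndFrame) and aggregating:
            aggregating = False
            yield AudioFrame(bytes(buffer))   # one large frame instead of many small ones
        elif isinstance(frame, AudioFrame) and aggregating:
            buffer.extend(frame.data)         # assumption: AudioFrame exposes .data bytes
        else:
            yield frame
```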
9 changes: 6 additions & 3 deletions src/dailyai/pipeline/pipeline.py
@@ -24,7 +24,7 @@ def __init__(
queues. If this pipeline is run by a transport, its sink and source queues
will be overridden.
"""
self.processors: List[FrameProcessor] = processors
self._processors: List[FrameProcessor] = processors

self.source: asyncio.Queue[Frame] = source or asyncio.Queue()
self.sink: asyncio.Queue[Frame] = sink or asyncio.Queue()
@@ -40,6 +40,9 @@ def set_sink(self, sink: asyncio.Queue[Frame]):
has processed a frame, its output will be placed on this queue."""
self.sink = sink

def add_processor(self, processor: FrameProcessor):
self._processors.append(processor)

async def get_next_source_frame(self) -> AsyncGenerator[Frame, None]:
"""Convenience function to get the next frame from the source queue. This
lets us consistently have an AsyncGenerator yield frames, from either the
@@ -80,7 +83,7 @@ async def run_pipeline(self):
while True:
initial_frame = await self.source.get()
async for frame in self._run_pipeline_recursively(
initial_frame, self.processors
initial_frame, self._processors
):
await self.sink.put(frame)

@@ -91,7 +94,7 @@ except asyncio.CancelledError:
except asyncio.CancelledError:
# this means there's been an interruption, do any cleanup necessary
# here.
for processor in self.processors:
for processor in self._processors:
await processor.interrupted()
pass

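For context, a minimal sketch (not part of this commit) of how the new add_processor hook might be used to extend a pipeline after construction, mirroring the WhisperTranscriber pattern from sample.py above; the logging processor here is hypothetical.

```python
from dailyai.pipeline.frame_processor import FrameProcessor
from dailyai.pipeline.pipeline import Pipeline


class LoggingProcessor(FrameProcessor):
    """Hypothetical pass-through processor that logs every frame it sees."""

    async def process_frame(self, frame):
        print(f"Saw frame: {frame}")
        yield frame


pipeline = Pipeline([LoggingProcessor()])      # processors given at construction time
pipeline.add_processor(LoggingProcessor())     # or appended later via the new hook
```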
