Websocket transport #81

Merged · 3 commits · Mar 25, 2024
18 changes: 15 additions & 3 deletions .github/workflows/lint.yaml
```diff
@@ -22,11 +22,23 @@ jobs:
     steps:
       - name: Checkout repo
         uses: actions/checkout@v4
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Setup virtual environment
+        run: |
+          python -m venv .venv
+      - name: Install basic Python dependencies
+        run: |
+          source .venv/bin/activate
+          python -m pip install --upgrade pip
+          pip install -r requirements.txt
       - name: autopep8
         id: autopep8
-        uses: peter-evans/autopep8@v2
-        with:
-          args: --exit-code -r -d -a -a src/
+        run: |
+          source .venv/bin/activate
+          autopep8 --max-line-length 100 --exit-code -r -d --exclude "*_pb2.py" -a -a src/
```
**Contributor:** Hahahaha. OK.... but what will we do in 2040?

**Contributor (author):** by then my vision will be bad enough that I'll lobby we go back to 80 columns 😂

**Contributor:** 😂

**Contributor:** Oh... we should update the README with this setting.

```diff
       - name: Fail if autopep8 requires changes
         if: steps.autopep8.outputs.exit-code == 2
         run: exit 1
```
5 changes: 3 additions & 2 deletions README.md
````diff
@@ -127,7 +127,7 @@ You can use [use-package](https://github.com/jwiegley/use-package) to install [p
   :defer t
   :hook ((python-mode . py-autopep8-mode))
   :config
-  (setq py-autopep8-options '("-a" "-a")))
+  (setq py-autopep8-options '("-a" "-a" "--max-line-length=100")))
 ```

 `autopep8` was installed in the `venv` environment described before, so you should be able to use [pyvenv-auto](https://github.com/ryotaro612/pyvenv-auto) to automatically load that environment inside Emacs.
````
```diff
@@ -152,6 +152,7 @@ Install the
   },
   "autopep8.args": [
     "-a",
-    "-a"
+    "-a",
+    "--max-line-length=100"
```
**Contributor:** Does the `=` work? If so, use that on Emacs as well.

````diff
   ],
 ```
````
3 changes: 2 additions & 1 deletion examples/foundational/10-wake-word.py
```diff
@@ -172,7 +172,8 @@ async def handle_transcriptions():
             isa.run(
                 tma_out.run(
                     llm.run(
-                        tma_in.run(ncf.run(tf.run(transport.get_receive_frames())))
+                        tma_in.run(
+                            ncf.run(tf.run(transport.get_receive_frames())))
                     )
                 )
             ),
```
25 changes: 25 additions & 0 deletions examples/foundational/websocket-server/frames.proto
```proto
syntax = "proto3";

package dailyai_proto;

message TextFrame {
  string text = 1;
}

message AudioFrame {
  bytes audio = 1;
}

message TranscriptionFrame {
  string text = 1;
  string participant_id = 2;
  string timestamp = 3;
}

message Frame {
  oneof frame {
    TextFrame text = 1;
    AudioFrame audio = 2;
    TranscriptionFrame transcription = 3;
  }
}
```
134 changes: 134 additions & 0 deletions examples/foundational/websocket-server/index.html
```html
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <script src="//cdn.jsdelivr.net/npm/[email protected]/dist/protobuf.min.js"></script>
    <title>WebSocket Audio Stream</title>
</head>

<body>
    <h1>WebSocket Audio Stream</h1>
    <button id="startAudioBtn">Start Audio</button>
    <button id="stopAudioBtn">Stop Audio</button>
    <script>
        const SAMPLE_RATE = 16000;
        const BUFFER_SIZE = 8192;
        const MIN_AUDIO_SIZE = 6400;

        let audioContext;
        let microphoneStream;
        let scriptProcessor;
        let source;
        let frame;
        let audioChunks = [];
        let isPlaying = false;
        let ws;

        const proto = protobuf.load("frames.proto", (err, root) => {
            if (err) throw err;
            frame = root.lookupType("dailyai_proto.Frame");
        });

        function initWebSocket() {
            ws = new WebSocket('ws://localhost:8765');

            ws.addEventListener('open', () => console.log('WebSocket connection established.'));
            ws.addEventListener('message', handleWebSocketMessage);
            ws.addEventListener('close', (event) => console.log("WebSocket connection closed.", event.code, event.reason));
            ws.addEventListener('error', (event) => console.error('WebSocket error:', event));
        }

        async function handleWebSocketMessage(event) {
            const arrayBuffer = await event.data.arrayBuffer();
            enqueueAudioFromProto(arrayBuffer);
        }

        function enqueueAudioFromProto(arrayBuffer) {
            const parsedFrame = frame.decode(new Uint8Array(arrayBuffer));
            if (!parsedFrame?.audio) return false;

            const frameCount = parsedFrame.audio.data.length / 2;
            const audioOutBuffer = audioContext.createBuffer(1, frameCount, SAMPLE_RATE);
            const nowBuffering = audioOutBuffer.getChannelData(0);
            const view = new Int16Array(parsedFrame.audio.data.buffer);

            for (let i = 0; i < frameCount; i++) {
                const word = view[i];
                nowBuffering[i] = ((word + 32768) % 65536 - 32768) / 32768.0;
```
**Contributor:** It seems `nowBuffering` is not being used?

**Contributor:** Nm, you are modifying the contents.
```html
            }

            audioChunks.push(audioOutBuffer);
            if (!isPlaying) playNextChunk();
        }

        function playNextChunk() {
            if (audioChunks.length === 0) {
                isPlaying = false;
                return;
            }

            isPlaying = true;
            const audioOutBuffer = audioChunks.shift();
            const source = audioContext.createBufferSource();
            source.buffer = audioOutBuffer;
            source.connect(audioContext.destination);
            source.onended = playNextChunk;
            source.start();
        }

        function startAudio() {
            if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
                alert('getUserMedia is not supported in your browser.');
                return;
            }

            navigator.mediaDevices.getUserMedia({ audio: true })
                .then((stream) => {
                    microphoneStream = stream;
                    audioContext = new (window.AudioContext || window.webkitAudioContext)();
                    scriptProcessor = audioContext.createScriptProcessor(BUFFER_SIZE, 1, 1);
                    source = audioContext.createMediaStreamSource(stream);
                    source.connect(scriptProcessor);
                    scriptProcessor.connect(audioContext.destination);

                    const audioBuffer = [];
                    const skipRatio = Math.floor(audioContext.sampleRate / (SAMPLE_RATE * 2));

                    scriptProcessor.onaudioprocess = (event) => {
                        const rawLeftChannelData = event.inputBuffer.getChannelData(0);
                        for (let i = 0; i < rawLeftChannelData.length; i += skipRatio) {
                            const normalized = ((rawLeftChannelData[i] * 32768.0) + 32768) % 65536 - 32768;
                            const swappedBytes = ((normalized & 0xff) << 8) | ((normalized >> 8) & 0xff);
                            audioBuffer.push(swappedBytes);
                        }

                        if (audioBuffer.length >= MIN_AUDIO_SIZE) {
                            const audioFrame = frame.create({ audio: { audio: audioBuffer.slice(0, MIN_AUDIO_SIZE) } });
                            const encodedFrame = new Uint8Array(frame.encode(audioFrame).finish());
                            ws.send(encodedFrame);
                            audioBuffer.splice(0, MIN_AUDIO_SIZE);
                        }
                    };

                    initWebSocket();
                })
                .catch((error) => console.error('Error accessing microphone:', error));
        }

        function stopAudio() {
            if (ws) {
                ws.close();
                scriptProcessor.disconnect();
                source.disconnect();
                ws = undefined;
            }
        }

        document.getElementById('startAudioBtn').addEventListener('click', startAudio);
        document.getElementById('stopAudioBtn').addEventListener('click', stopAudio);
    </script>
</body>

</html>
```
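
The int16 → float conversion in `enqueueAudioFromProto` (and the byte swap in `startAudio`) can be mirrored server-side. A minimal sketch, assuming the browser runs on a little-endian machine so the swapped samples arrive big-endian; `numpy` and the function name are assumptions, not code from this PR:

```python
# A sketch of decoding the audio payload the page above sends. Assumes the
# browser is little-endian, so the byte swap in startAudio() yields
# big-endian int16 samples. Illustrative only; not part of this PR.
import numpy as np


def decode_audio(payload: bytes) -> np.ndarray:
    # ">i2" = big-endian signed 16-bit integers.
    samples = np.frombuffer(payload, dtype=">i2")
    # Normalize to [-1.0, 1.0), mirroring enqueueAudioFromProto() above.
    return samples.astype(np.float32) / 32768.0


# Two big-endian samples: 0x0001 -> ~3e-5, 0xff00 -> -0.0078125.
print(decode_audio(b"\x00\x01\xff\x00"))
```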
50 changes: 50 additions & 0 deletions examples/foundational/websocket-server/sample.py
```python
import asyncio
import aiohttp
import logging
import os
from dailyai.pipeline.frame_processor import FrameProcessor
from dailyai.pipeline.frames import TextFrame, TranscriptionQueueFrame
from dailyai.pipeline.pipeline import Pipeline
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
from dailyai.services.websocket_transport_service import WebsocketTransport
from dailyai.services.whisper_ai_services import WhisperSTTService

logging.basicConfig(format="%(levelno)s %(asctime)s %(message)s")
logger = logging.getLogger("dailyai")
logger.setLevel(logging.DEBUG)


class WhisperTranscriber(FrameProcessor):
    async def process_frame(self, frame):
        if isinstance(frame, TranscriptionQueueFrame):
            print(f"Transcribed: {frame.text}")
        else:
            yield frame


async def main():
    async with aiohttp.ClientSession() as session:
        transport = WebsocketTransport(
            mic_enabled=True,
            speaker_enabled=True,
        )
        tts = ElevenLabsTTSService(
            aiohttp_session=session,
            api_key=os.getenv("ELEVENLABS_API_KEY"),
            voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
        )

        pipeline = Pipeline([
            WhisperSTTService(),
            WhisperTranscriber(),
            tts,
        ])

        @transport.on_connection
        async def queue_frame():
            await pipeline.queue_frames([TextFrame("Hello there!")])

        await transport.run(pipeline)

if __name__ == "__main__":
    asyncio.run(main())
```
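
A quick way to exercise this server is a small `websockets` client (the dependency this PR adds to `pyproject.toml`). A sketch, assuming the server is listening on `ws://localhost:8765` and that `frames_pb2` was generated from `frames.proto` with `protoc --python_out=.`; the module name and payload size are assumptions:

```python
# A hypothetical smoke-test client for sample.py. Assumes the server is
# listening on ws://localhost:8765 and frames_pb2 was generated from
# frames.proto with protoc; neither is guaranteed by this PR.
import asyncio

import websockets

import frames_pb2


async def run_client():
    async with websockets.connect("ws://localhost:8765") as ws:
        # Send one second of 16 kHz, 16-bit mono silence as one AudioFrame.
        frame = frames_pb2.Frame()
        frame.audio.data = b"\x00\x00" * 16000
        await ws.send(frame.SerializeToString())

        # Print the type of each frame the server pushes back.
        async for message in ws:
            incoming = frames_pb2.Frame()
            incoming.ParseFromString(message)
            print("Received:", incoming.WhichOneof("frame"))


asyncio.run(run_client())
```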
3 changes: 2 additions & 1 deletion pyproject.toml
```diff
@@ -35,7 +35,8 @@ dependencies = [
     "torch",
     "torchaudio",
     "pyaudio",
-    "typing-extensions"
+    "typing-extensions",
+    "websockets"
 ]

 [project.urls]
```
2 changes: 0 additions & 2 deletions src/dailyai/pipeline/frame_processor.py
```diff
@@ -23,8 +23,6 @@ async def process_frame(
         self, frame: Frame
     ) -> AsyncGenerator[Frame, None]:
         """Process a single frame and yield 0 or more frames."""
-        if isinstance(frame, ControlFrame):
-            yield frame
         yield frame

     @abstractmethod
```
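
For context on why these two lines were deleted: in the old default implementation a `ControlFrame` matched both the `isinstance` branch and the unconditional `yield`, so it was emitted twice. A sketch of the old behavior:

```python
# Sketch of the bug the deletion fixes: in the old default process_frame(),
# a ControlFrame hit both the isinstance() branch and the unconditional
# yield, so downstream processors saw every ControlFrame twice.
from dailyai.pipeline.frames import ControlFrame


class OldDefaultProcessor:
    async def process_frame(self, frame):
        if isinstance(frame, ControlFrame):
            yield frame  # first copy
        yield frame      # second copy for ControlFrames, first for others
```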
25 changes: 25 additions & 0 deletions src/dailyai/pipeline/frames.proto
```proto
syntax = "proto3";

package dailyai_proto;

message TextFrame {
  string text = 1;
}

message AudioFrame {
  bytes data = 1;
}

message TranscriptionFrame {
  string text = 1;
  string participantId = 2;
  string timestamp = 3;
}

message Frame {
  oneof frame {
    TextFrame text = 1;
    AudioFrame audio = 2;
    TranscriptionFrame transcription = 3;
  }
}
```
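
As a quick illustration of the wire format, a round-trip through the generated Python bindings. A sketch only, assuming the module was produced from this file with `protoc --python_out=.`; the PR imports it as `dailyai.pipeline.protobufs.frames_pb2`:

```python
# Round-trip sketch for the Frame message above. Assumes frames_pb2 was
# generated with `protoc --python_out=. frames.proto`; this PR imports it
# as dailyai.pipeline.protobufs.frames_pb2.
import dailyai.pipeline.protobufs.frames_pb2 as frame_protos

# Serialize a transcription frame...
frame = frame_protos.Frame()
frame.transcription.text = "hello"
frame.transcription.participantId = "user-1"
frame.transcription.timestamp = "2024-03-25T00:00:00Z"
wire_bytes = frame.SerializeToString()

# ...and parse it back, dispatching on the oneof.
parsed = frame_protos.Frame()
parsed.ParseFromString(wire_bytes)
assert parsed.WhichOneof("frame") == "transcription"
print(parsed.transcription.text)  # -> hello
```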
17 changes: 17 additions & 0 deletions src/dailyai/pipeline/frames.py
```diff
@@ -2,6 +2,7 @@
 from typing import Any, List

 from dailyai.services.openai_llm_context import OpenAILLMContext
+import dailyai.pipeline.protobufs.frames_pb2 as frame_protos


 class Frame:
@@ -107,6 +108,22 @@ class TranscriptionQueueFrame(TextFrame):
     participantId: str
     timestamp: str

+    def __str__(self):
+        return f"{self.__class__.__name__}, text: '{self.text}' participantId: {self.participantId}, timestamp: {self.timestamp}"
+
+
+class TTSStartFrame(ControlFrame):
+    """Used to indicate the beginning of a TTS response. Following AudioFrames
+    are part of the TTS response until a TTSEndFrame. These frames can be used
+    for aggregating audio frames in a transport to optimize the size of frames
+    sent to the session, without needing to control this in the TTS service."""
+    pass
+
+
+class TTSEndFrame(ControlFrame):
+    """Indicates the end of a TTS response."""
+    pass
+

 @dataclass()
 class LLMMessagesQueueFrame(Frame):
```
9 changes: 6 additions & 3 deletions src/dailyai/pipeline/pipeline.py
```diff
@@ -24,7 +24,7 @@ def __init__(
         queues. If this pipeline is run by a transport, its sink and source queues
         will be overridden.
         """
-        self.processors: List[FrameProcessor] = processors
+        self._processors: List[FrameProcessor] = processors

         self.source: asyncio.Queue[Frame] = source or asyncio.Queue()
         self.sink: asyncio.Queue[Frame] = sink or asyncio.Queue()
@@ -40,6 +40,9 @@ def set_sink(self, sink: asyncio.Queue[Frame]):
         has processed a frame, its output will be placed on this queue."""
         self.sink = sink

+    def add_processor(self, processor: FrameProcessor):
+        self._processors.append(processor)
+
     async def get_next_source_frame(self) -> AsyncGenerator[Frame, None]:
         """Convenience function to get the next frame from the source queue. This
         lets us consistently have an AsyncGenerator yield frames, from either the
@@ -80,7 +83,7 @@ async def run_pipeline(self):
             while True:
                 initial_frame = await self.source.get()
                 async for frame in self._run_pipeline_recursively(
-                    initial_frame, self.processors
+                    initial_frame, self._processors
                 ):
                     await self.sink.put(frame)

@@ -91,7 +94,7 @@ async def run_pipeline(self):
         except asyncio.CancelledError:
             # this means there's been an interruption, do any cleanup necessary
             # here.
-            for processor in self.processors:
+            for processor in self._processors:
                 await processor.interrupted()
             pass
```
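
Since the processor list is now private (`_processors`), the new `add_processor` helper is the way to extend a pipeline after construction. A minimal sketch; `UppercaseProcessor` is illustrative, not part of this PR:

```python
# Sketch of extending a pipeline after construction via add_processor().
# UppercaseProcessor is a hypothetical example, not part of this PR.
from dailyai.pipeline.frame_processor import FrameProcessor
from dailyai.pipeline.frames import TextFrame
from dailyai.pipeline.pipeline import Pipeline


class UppercaseProcessor(FrameProcessor):
    async def process_frame(self, frame):
        if isinstance(frame, TextFrame):
            frame.text = frame.text.upper()
        yield frame


pipeline = Pipeline([UppercaseProcessor()])
# Appends to the now-private _processors list without touching it directly.
pipeline.add_processor(UppercaseProcessor())
```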