Commit

getting started on protobuf stuff
Moishe committed Mar 20, 2024
1 parent 397bd2d commit fd68dd7
Showing 9 changed files with 187 additions and 57 deletions.
75 changes: 31 additions & 44 deletions examples/foundational/websocket-server/index.html
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="//cdn.jsdelivr.net/npm/[email protected]/dist/protobuf.min.js"></script>
<title>WebSocket Audio Stream</title>
</head>

<h1>WebSocket Audio Stream</h1>
<script>
let audioContext;
let microphoneStream;

async function startAudio() {
    if (!navigator.mediaDevices || !navigator.mediaDevices.getUserMedia) {
        alert('getUserMedia is not supported in your browser.');
        return;
    }

    // Get the microphone stream.
    microphoneStream = await navigator.mediaDevices.getUserMedia({ audio: true });

    // Create an AudioContext at the 16 kHz rate the server-side wave dump assumes.
    const context = new AudioContext({ sampleRate: 16000 });

    // Create a ScriptProcessorNode.
    const scriptProcessor = context.createScriptProcessor(8192, 1, 1);

    // Connect the microphone stream to the ScriptProcessorNode.
    const source = context.createMediaStreamSource(microphoneStream);
    source.connect(scriptProcessor);

    // Connect the ScriptProcessorNode to the destination.
    scriptProcessor.connect(context.destination);

    scriptProcessor.onaudioprocess = (event) => {
        const rawLeftChannelData = event.inputBuffer.getChannelData(0);
        // Convert each sample from the -1.0..1.0 float range to a 16-bit
        // signed integer.
        const scaledToInt = new Int16Array(rawLeftChannelData.length);
        for (let i = 0; i < rawLeftChannelData.length; i++) {
            scaledToInt[i] = Math.max(-32768, Math.min(32767, Math.round(rawLeftChannelData[i] * 32767)));
        }
        // Send the raw little-endian PCM bytes; a plain Array would be
        // stringified by WebSocket.send().
        ws.send(scaledToInt.buffer);
    };

    initWebSocket();
}

function stopAudio() {
    // getUserMedia() returns a MediaStream, which has no disconnect();
    // stop its tracks to release the microphone.
    microphoneStream.getTracks().forEach((track) => track.stop());
    ws.close();
    ws = undefined;
}

function initWebSocket() {
    // ...
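Since the page now streams raw 16-bit PCM instead of MediaRecorder blobs, a quick server-side sanity check can decode a few samples from a received message. This is a hypothetical helper, not part of the commit; it assumes the same little-endian signed 16-bit mono framing the transport's wave dump uses.

import struct

def peek_samples(message: bytes, count: int = 4) -> list[int]:
    """Decode the first few int16 samples from a raw PCM WebSocket message."""
    n = min(count, len(message) // 2)
    return list(struct.unpack(f"<{n}h", message[: n * 2]))

# e.g. inside _websocket_handler: print(peek_samples(message))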
26 changes: 18 additions & 8 deletions examples/foundational/websocket-server/sample.py
import asyncio
import aiohttp
import logging
import os
import wave

from dailyai.pipeline.frame_processor import FrameProcessor
from dailyai.pipeline.frames import AudioFrame, TextFrame
from dailyai.pipeline.pipeline import Pipeline
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
from dailyai.services.websocket_transport_service import WebsocketTransportService
from dailyai.services.whisper_ai_services import WhisperSTTService

logging.basicConfig(format="%(levelno)s %(asctime)s %(message)s")
logger = logging.getLogger("dailyai")
logger.setLevel(logging.DEBUG)


async def main():
    async with aiohttp.ClientSession() as session:
        transport = WebsocketTransportService(
            mic_enabled=True, speaker_enabled=True)

        tts = ElevenLabsTTSService(
            # ...
            voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
        )

        class WriteToWav(FrameProcessor):
            async def process_frame(self, frame):
                if isinstance(frame, AudioFrame):
                    # Debug helper: reopens (and overwrites) output.wav for
                    # every frame it sees.
                    with wave.open("output.wav", "wb") as f:
                        f.setnchannels(1)
                        f.setsampwidth(2)
                        f.setframerate(16000)
                        f.writeframes(frame.data)
                yield frame

        pipeline = Pipeline([WriteToWav(), WhisperSTTService(), tts])

        @transport.on_connection
        async def queue_frame():
            ...


if __name__ == "__main__":
    asyncio.run(main())
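The WriteToWav helper above overwrites output.wav on every AudioFrame, so the file only ever holds the most recent chunk. A variant that accumulates audio instead might look like this (a sketch, not part of this commit; it reuses sample.py's imports and assumes FrameProcessor subclasses may define their own __init__):

class WriteToWavAccumulating(FrameProcessor):
    """Hypothetical debug processor: keep all audio seen so far and
    rewrite output.wav with the full buffer each time."""

    def __init__(self):
        super().__init__()
        self._audio = bytearray()

    async def process_frame(self, frame):
        if isinstance(frame, AudioFrame):
            self._audio.extend(frame.data)
            with wave.open("output.wav", "wb") as f:
                f.setnchannels(1)
                f.setsampwidth(2)
                f.setframerate(16000)
                f.writeframes(bytes(self._audio))
        yield frame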
Binary file added output.wav
25 changes: 25 additions & 0 deletions src/dailyai/pipeline/frames.proto
syntax = "proto3";

package dailyai_proto;

message TextFrame {
  string text = 1;
}

message AudioFrame {
  bytes audio = 1;
}

message TranscriptionFrame {
  string text = 1;
  string participant_id = 2;
  string timestamp = 3;
}

message Frame {
  oneof frame {
    TextFrame text = 1;
    AudioFrame audio = 2;
    TranscriptionFrame transcription = 3;
  }
}
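As a quick sanity check of the generated bindings, a round trip through the Frame envelope (a minimal sketch; it assumes frames_pb2 was generated from this file with protoc --python_out):

import dailyai.pipeline.protobufs.frames_pb2 as frame_protos

# Wrap an audio payload in the Frame envelope and serialize it.
frame = frame_protos.Frame()
frame.audio.audio = b"\x00\x01" * 8
wire_bytes = frame.SerializeToString()

# Parse it back and dispatch on the oneof, as the pipeline classes do.
parsed = frame_protos.Frame.FromString(wire_bytes)
assert parsed.WhichOneof("frame") == "audio"
assert parsed.audio.audio == frame.audio.audio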
75 changes: 73 additions & 2 deletions src/dailyai/pipeline/frames.py
from typing import Any, List

from dailyai.services.openai_llm_context import OpenAILLMContext
import dailyai.pipeline.protobufs.frames_pb2 as frame_protos


class Frame:
    def __str__(self):
        return f"{self.__class__.__name__}"

    def to_proto(self):
        raise NotImplementedError

    @staticmethod
    def from_proto(proto):
        raise NotImplementedError

class ControlFrame(Frame):
    # Control frames should contain no instance data, so
    ...
@dataclass()
class AudioFrame(Frame):
    """A chunk of audio. Will be played by the transport if the transport's mic
    has been enabled.

    >>> str(AudioFrame(data=b'1234567890'))
    'AudioFrame, size: 10 B'

    >>> AudioFrame.from_proto(AudioFrame(data=b'1234567890').to_proto())
    AudioFrame(data=b'1234567890')
    """
    data: bytes

    def __str__(self):
        return f"{self.__class__.__name__}, size: {len(self.data)} B"

    def to_proto(self) -> frame_protos.Frame:
        frame = frame_protos.Frame()
        frame.audio.audio = self.data
        return frame

    @staticmethod
    def from_proto(proto: frame_protos.Frame):
        if proto.WhichOneof("frame") != "audio":
            raise ValueError("Proto does not contain an audio frame")
        return AudioFrame(data=proto.audio.audio)


@dataclass()
class ImageFrame(Frame):
    ...

@dataclass()
class TextFrame(Frame):
    """A chunk of text. Emitted by LLM services, consumed by TTS services, can
    be used to send text through pipelines.

    >>> str(TextFrame.from_proto(TextFrame(text='hello world').to_proto()))
    'TextFrame: "hello world"'
    """
    text: str

    def __str__(self):
        return f'{self.__class__.__name__}: "{self.text}"'

    def to_proto(self) -> frame_protos.Frame:
        proto_frame = frame_protos.Frame()
        proto_frame.text.text = self.text
        return proto_frame

    @staticmethod
    def from_proto(proto: frame_protos.Frame):
        # Note: the argument is the Frame envelope, not a bare TextFrame proto.
        return TextFrame(text=proto.text.text)

@dataclass()
class TranscriptionQueueFrame(TextFrame):
    participantId: str
    timestamp: str

    def to_proto(self) -> frame_protos.Frame:
        frame = frame_protos.Frame()
        frame.transcription.text = self.text
        frame.transcription.participant_id = self.participantId
        frame.transcription.timestamp = self.timestamp
        return frame

    @staticmethod
    def from_proto(proto: frame_protos.Frame):
        return TranscriptionQueueFrame(
            text=proto.transcription.text,
            participantId=proto.transcription.participant_id,
            timestamp=proto.transcription.timestamp,
        )


@dataclass()
class LLMMessagesQueueFrame(Frame):
    ...


class LLMFunctionCallFrame(Frame):
    """Emitted when the LLM has received an entire function call completion."""
    function_name: str
    arguments: str


if __name__ == "__main__":
    audio_frame = AudioFrame(data=b'1234567890')
    print(audio_frame)
    print(audio_frame.to_proto().SerializeToString())
    print(AudioFrame.from_proto(audio_frame.to_proto()))

    text_frame = TextFrame(text="Hello there!")
    print(text_frame)
    print(text_frame.to_proto().SerializeToString())
    print(TextFrame.from_proto(text_frame.to_proto()))

    transcription_frame = TranscriptionQueueFrame(
        text="Hello there!", participantId="123", timestamp="2021-01-01")
    print(transcription_frame)
    print(transcription_frame.to_proto().SerializeToString())
    print(TranscriptionQueueFrame.from_proto(transcription_frame.to_proto()))
32 changes: 32 additions & 0 deletions src/dailyai/pipeline/protobufs/frames_pb2.py

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions src/dailyai/services/local_stt_service.py
    async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
        """Processes a frame of audio data, either buffering or transcribing it."""
        if not isinstance(frame, AudioFrame):
            # Pass non-audio frames through untouched.
            yield frame
            return

        data = frame.data
9 changes: 6 additions & 3 deletions src/dailyai/services/websocket_transport_service.py
import asyncio
from concurrent.futures import Future
import queue
import wave
import websockets

from dailyai.services.base_transport_service import BaseTransportService
    def read_audio_frames(self, desired_frame_count):
        audio_bytes = bytes()
        while not self._audio_queue.empty() and len(audio_bytes) < desired_frame_count:
            audio_bytes += self._audio_queue.get()

        return audio_bytes
    async def _websocket_handler(self, websocket: websockets.WebSocketServerProtocol):
        self._websocket = websocket
        async for message in websocket:
            # Debug: dump the latest message to disk. Note this overwrites
            # output.wav on every message.
            with wave.open("output.wav", "wb") as f:
                f.setnchannels(1)
                f.setsampwidth(2)
                f.setframerate(16000)
                f.writeframes(message)
            self._audio_queue.put(message)

    async def _start_server(self) -> None:
        try:
            ...  # server setup elided in this diff
        except Exception as e:
            print(f"Error in start server: {e}")

    async def _close_server(self):
        if self._websocket:
            await self._websocket.close()

1 change: 1 addition & 0 deletions src/dailyai/services/whisper_ai_services.py
    async def run_stt(self, audio: BinaryIO) -> str:
        # ...
        res: str = ""
        for segment in segments:
            res += f"{segment.text} "
            print("Transcription: ", segment.text)
        return res
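With the transport now dumping output.wav, the STT service can be exercised directly against that file. A sketch, assuming a default-constructed WhisperSTTService (as used in sample.py) and that run_stt accepts any binary file handle:

import asyncio

from dailyai.services.whisper_ai_services import WhisperSTTService

async def transcribe_dump():
    stt = WhisperSTTService()
    with open("output.wav", "rb") as audio:
        print(await stt.run_stt(audio))

asyncio.run(transcribe_dump())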
