Skip to content

Commit

Permalink
refactor: move transcription models for consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
rachwalk committed Jan 7, 2025
1 parent 2ee0ec4 commit d7cceea
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/rai_asr/rai_asr/asr_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from scipy.io import wavfile
from whisper.transcribe import transcribe

# WARN: This file is going to be removed in favour of rai_asr.models


class ASRModel:
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
Expand Down
13 changes: 11 additions & 2 deletions src/rai_asr/rai_asr/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from rai_asr.models.base import BaseVoiceDetectionModel
from rai_asr.models.base import BaseTranscriptionModel, BaseVoiceDetectionModel
from rai_asr.models.local_whisper import LocalWhisper
from rai_asr.models.open_ai_whisper import OpenAIWhisper
from rai_asr.models.open_wake_word import OpenWakeWord
from rai_asr.models.silero_vad import SileroVAD

__all__ = ["BaseVoiceDetectionModel", "SileroVAD", "OpenWakeWord"]
__all__ = [
"BaseVoiceDetectionModel",
"SileroVAD",
"OpenWakeWord",
"BaseTranscriptionModel",
"LocalWhisper",
"OpenAIWhisper",
]
15 changes: 15 additions & 0 deletions src/rai_asr/rai_asr/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from abc import ABC, abstractmethod
from typing import Any, Tuple

import numpy as np
from numpy._typing import NDArray


Expand All @@ -26,3 +27,17 @@ def detected(
self, audio_data: NDArray, input_parameters: dict[str, Any]
) -> Tuple[bool, dict[str, Any]]:
pass


class BaseTranscriptionModel(ABC):
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
self.model_name = model_name
self.sample_rate = sample_rate
self.language = language

@abstractmethod
def transcribe(self, data: NDArray[np.int16]) -> str:
pass

def __call__(self, data: NDArray[np.int16]) -> str:
return self.transcribe(data)
32 changes: 32 additions & 0 deletions src/rai_asr/rai_asr/models/local_whisper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import whisper
from numpy._typing import NDArray

from rai_asr.models.base import BaseTranscriptionModel


class LocalWhisper(BaseTranscriptionModel):
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
super().__init__(model_name, sample_rate, language)
self.whisper = whisper.load_model(self.model_name)

def transcribe(self, data: NDArray[np.int16]) -> str:
result = whisper.transcribe(self.whisper, data.astype(np.float32) / 32768.0)
transcription = result["text"]
# NOTE: this is only for type enforcement, doesn't need to work on runtime
assert isinstance(transcription, str)
return transcription
47 changes: 47 additions & 0 deletions src/rai_asr/rai_asr/models/open_ai_whisper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright (C) 2024 Robotec.AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
from functools import partial

import numpy as np
from numpy.typing import NDArray
from openai import OpenAI
from scipy.io import wavfile

from rai_asr.models.base import BaseTranscriptionModel


class OpenAIWhisper(BaseTranscriptionModel):
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
super().__init__(model_name, sample_rate, language)
api_key = os.getenv("OPENAI_API_KEY")
if api_key is None:
raise ValueError("OPENAI_API_KEY environment variable is not set.")
self.api_key = api_key
self.openai_client = OpenAI()
self.model = partial(
self.openai_client.audio.transcriptions.create,
model=self.model_name,
)

def transcribe(self, data: NDArray[np.int16]) -> str:
with io.BytesIO() as temp_wav_buffer:
wavfile.write(temp_wav_buffer, self.sample_rate, data)
temp_wav_buffer.seek(0)
temp_wav_buffer.name = "temp.wav"
response = self.model(file=temp_wav_buffer, language=self.language)
transcription = response.text
return transcription

0 comments on commit d7cceea

Please sign in to comment.