Skip to content

Commit

Permalink
Update switch between local and remote and reverse
Browse files Browse the repository at this point in the history
  • Loading branch information
chainyo committed Oct 5, 2023
1 parent 147e295 commit 38029c3
Showing 1 changed file with 141 additions and 108 deletions.
249 changes: 141 additions & 108 deletions src/wordcab_transcribe/services/asr_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,51 @@ class TranscriptionTask(BaseModel):


@dataclass
class ServiceHandler:
"""Services handler model."""
class LocalServiceRegistry:
"""Registry for local services."""

diarization: Union[DiarizeService, None] = None
post_processing: PostProcessingService = PostProcessingService()
transcription: Union[TranscribeService, None] = None
vad: VadService = VadService()


@dataclass
class RemoteServiceConfig:
"""Remote service config."""

url_handler: Union[URLService, None] = None
use_remote: bool = False

def get_urls(self) -> List[str]:
"""Get the list of URLs."""
return self.url_handler.get_urls()

def get_queue_size(self) -> int:
"""Get the queue size."""
return self.url_handler.get_queue_size()

async def add_url(self, url: str) -> None:
"""Add a URL to the list of URLs."""
await self.url_handler.add_url(url)

async def next_url(self) -> str:
"""Get the next URL."""
return await self.url_handler.next_url()

async def remove_url(self, url: str) -> None:
"""Remove a URL from the list of URLs."""
await self.url_handler.remove_url(url)


@dataclass
class RemoteServiceRegistry:
"""Registry for remote services."""

diarization: RemoteServiceConfig = RemoteServiceConfig()
transcription: RemoteServiceConfig = RemoteServiceConfig()


class ASRService(ABC):
"""Base ASR Service module that handle all AI interactions and batch processing."""

Expand Down Expand Up @@ -184,8 +220,8 @@ def __init__(
self,
whisper_model: str,
compute_type: str,
window_lengths: List[int],
shift_lengths: List[int],
window_lengths: List[float],
shift_lengths: List[float],
multiscale_weights: List[float],
extra_languages: Union[List[str], None],
extra_languages_model_paths: Union[List[str], None],
Expand All @@ -201,9 +237,9 @@ def __init__(
The path to the whisper model.
compute_type (str):
The compute type to use for inference.
window_lengths (List[int]):
window_lengths (List[float]):
The window lengths to use for diarization.
shift_lengths (List[int]):
shift_lengths (List[float]):
The shift lengths to use for diarization.
multiscale_weights (List[float]):
The multiscale weights to use for diarization.
Expand All @@ -222,7 +258,18 @@ def __init__(
"""
super().__init__()

self.services: ServiceHandler = ServiceHandler()
self.whisper_model: str = whisper_model
self.compute_type: str = compute_type
self.window_lengths: List[float] = window_lengths
self.shift_lengths: List[float] = shift_lengths
self.multiscale_weights: List[float] = multiscale_weights
self.extra_languages: Union[List[str], None] = extra_languages
self.extra_languages_model_paths: Union[List[str], None] = (
extra_languages_model_paths
)

self.local_services: LocalServiceRegistry = LocalServiceRegistry()
self.remote_services: RemoteServiceRegistry = RemoteServiceRegistry()
self.dual_channel_transcribe_options: dict = {
"beam_size": 5,
"patience": 1,
Expand All @@ -237,48 +284,67 @@ def __init__(
"You provided URLs for remote transcription server, no local model will"
" be used."
)
self.use_remote_transcription = True
self.transcription_url_handler = URLService(
remote_urls=transcribe_server_urls
self.remote_services.transcription = RemoteServiceConfig(
use_remote=True,
url_handler=URLService(remote_urls=transcribe_server_urls),
)
else:
logger.info(
"You did not provide URLs for remote transcription server, local model"
" will be used."
)
self.use_remote_transcription = False
self.services.transcription = TranscribeService(
model_path=whisper_model,
compute_type=compute_type,
device=self.device,
device_index=self.device_index,
extra_languages=extra_languages,
extra_languages_model_paths=extra_languages_model_paths,
)
self.create_transcription_local_service()

if diarize_server_urls is not None:
logger.info(
"You provided URLs for remote diarization server, no local model will"
" be used."
)
self.use_remote_diarization = True
self.diarization_url_handler = URLService(remote_urls=diarize_server_urls)
self.remote_services.diarization = RemoteServiceConfig(
use_remote=True,
url_handler=URLService(remote_urls=diarize_server_urls),
)
else:
logger.info(
"You did not provide URLs for remote diarization server, local model"
" will be used."
)
self.use_remote_diarization = False
self.services.diarization = DiarizeService(
device=self.device,
device_index=self.device_index,
window_lengths=window_lengths,
shift_lengths=shift_lengths,
multiscale_weights=multiscale_weights,
)
self.create_diarization_local_service()

self.debug_mode = debug_mode

def create_transcription_local_service(self) -> None:
"""Create a local transcription service."""
self.local_services.transcription = TranscribeService(
model_path=self.whisper_model,
compute_type=self.compute_type,
device=self.device,
device_index=self.device_index,
extra_languages=self.extra_languages,
extra_languages_model_paths=self.extra_languages_model_paths,
)

def create_diarization_local_service(self) -> None:
"""Create a local diarization service."""
self.local_services.diarization = DiarizeService(
device=self.device,
device_index=self.device_index,
window_lengths=self.window_lengths,
shift_lengths=self.shift_lengths,
multiscale_weights=self.multiscale_weights,
)

def create_local_service(
self, task: Literal["transcription", "diarization"]
) -> None:
"""Create a local service."""
if task == "transcription":
self.create_transcription_local_service()
elif task == "diarization":
self.create_diarization_local_service()
else:
raise NotImplementedError("No task specified.")

async def inference_warmup(self) -> None:
"""Warmup the GPU by loading the models."""
sample_path = Path(__file__).parent.parent / "assets/warmup_sample.wav"
Expand Down Expand Up @@ -394,16 +460,16 @@ async def process_input( # noqa: C901
)

gpu_index = None
if self.use_remote_transcription:
_url = await self.transcription_url_handler.next_url()
if self.remote_services.transcription.use_remote is True:
_url = await self.remote_services.transcription.next_url()
transcription_execution = RemoteExecution(url=_url)
else:
gpu_index = await self.gpu_handler.get_device()
transcription_execution = LocalExecution(index=gpu_index)

if diarization and multi_channel is False:
if self.use_remote_diarization:
_url = await self.diarization_url_handler.next_url()
if self.remote_services.diarization.use_remote is True:
_url = await self.remote_services.diarization.next_url()
diarization_execution = RemoteExecution(url=_url)
else:
if gpu_index is None:
Expand Down Expand Up @@ -498,7 +564,7 @@ async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None:
try:
if isinstance(task.transcription.execution, LocalExecution):
out = await time_and_tell_async(
lambda: self.services.transcription(
lambda: self.local_services.transcription(
task.audio,
model_index=task.transcription.execution.index,
suppress_blank=False,
Expand Down Expand Up @@ -565,12 +631,12 @@ async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None:
try:
if isinstance(task.diarization.execution, LocalExecution):
out = await time_and_tell_async(
lambda: self.services.diarization(
lambda: self.local_services.diarization(
waveform=task.audio,
audio_duration=task.duration,
oracle_num_speakers=task.diarization.num_speakers,
model_index=task.diarization.execution.index,
vad_service=self.services.vad,
vad_service=self.local_services.vad,
),
func_name="diarization",
debug_mode=debug_mode,
Expand Down Expand Up @@ -630,7 +696,7 @@ def process_post_processing(self, task: ASRTask) -> None:

if task.multi_channel:
utterances, process_time = time_and_tell(
self.services.post_processing.multi_channel_speaker_mapping(
self.local_services.post_processing.multi_channel_speaker_mapping(
task.transcription.result
),
func_name="multi_channel_speaker_mapping",
Expand All @@ -650,7 +716,7 @@ def process_post_processing(self, task: ASRTask) -> None:

if task.diarization.execution is not None:
utterances, process_time = time_and_tell(
self.services.post_processing.single_channel_speaker_mapping(
self.local_services.post_processing.single_channel_speaker_mapping(
transcript_segments=formatted_segments,
speaker_timestamps=task.diarization.result,
word_timestamps=task.word_timestamps,
Expand All @@ -663,7 +729,7 @@ def process_post_processing(self, task: ASRTask) -> None:
utterances = formatted_segments

final_utterances, process_time = time_and_tell(
self.services.post_processing.final_processing_before_returning(
self.local_services.post_processing.final_processing_before_returning(
utterances=utterances,
offset_start=task.offset_start,
timestamps_format=task.timestamps_format,
Expand Down Expand Up @@ -726,31 +792,20 @@ async def get_url(
self, task: Literal["transcription", "diarization"]
) -> Union[List[str], ProcessException]:
"""Get the list of remote URLs."""
logger.info(self.remote_services.transcription)
logger.info(self.remote_services.diarization)
try:
if task == "transcription":
# Case 1: We are not using remote transcription
if self.use_remote_transcription is False:
return ProcessException(
source=ExceptionSource.get_url,
message="You are not using remote transcription.",
)
# Case 2: We are using remote transcription
else:
return self.transcription_url_handler.get_urls()

elif task == "diarization":
# Case 1: We are not using remote diarization
if self.use_remote_diarization is False:
return ProcessException(
source=ExceptionSource.get_url,
message="You are not using remote diarization.",
)
# Case 2: We are using remote diarization
else:
return self.diarization_url_handler.get_urls()

selected_task = getattr(self.remote_services, task)
logger.info(selected_task)
# Case 1: We are not using remote task
if selected_task.use_remote is False:
return ProcessException(
source=ExceptionSource.get_url,
message=f"You are not using remote {task}.",
)
# Case 2: We are using remote task
else:
raise ValueError(f"{task} is not a valid task.")
return selected_task.get_urls()

except Exception as e:
return ProcessException(
Expand All @@ -761,30 +816,21 @@ async def get_url(
async def add_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException]:
"""Add a remote URL to the list of URLs."""
try:
if data.task == "transcription":
# Case 1: We are not using remote transcription yet
if self.use_remote_transcription is False:
self.use_remote_transcription = True
self.transcription_url_handler = URLService(
remote_urls=[str(data.url)]
)
# Case 2: We are already using remote transcription
else:
await self.transcription_url_handler.add_url(data.url)

elif data.task == "diarization":
# Case 1: We are not using remote diarization yet
if self.use_remote_diarization is False:
self.use_remote_diarization = True
self.diarization_url_handler = URLService(
remote_urls=[str(data.url)]
)
# Case 2: We are already using remote diarization
else:
await self.diarization_url_handler.add_url(data.url)

selected_task = getattr(self.remote_services, data.task)
# Case 1: We are not using remote task yet
if selected_task.use_remote is False:
setattr(
self.remote_services,
data.task,
RemoteServiceConfig(
use_remote=True,
url_handler=URLService(remote_urls=[str(data.url)]),
),
)
setattr(self.local_services, data.task, None)
# Case 2: We are already using remote task
else:
raise ValueError(f"{data.task} is not a valid task.")
await selected_task.add_url(str(data.url))

except Exception as e:
return ProcessException(
Expand All @@ -797,30 +843,17 @@ async def add_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException]:
async def remove_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException]:
"""Remove a remote URL from the list of URLs."""
try:
if data.task == "transcription":
# Case 1: We are not using remote transcription
if self.use_remote_transcription is False:
raise ValueError("You are not using remote transcription.")
# Case 2: We are using remote transcription
else:
await self.transcription_url_handler.remove_url(str(data.url))
if self.transcription_url_handler.get_queue_size() == 0:
# TODO: Add a way to switch back to local transcription
pass

elif data.task == "diarization":
# Case 1: We are not using remote diarization
if self.use_remote_diarization is False:
raise ValueError("You are not using remote diarization.")
# Case 2: We are using remote diarization
else:
await self.diarization_url_handler.remove_url(str(data.url))
if self.diarization_url_handler.get_queue_size() == 0:
# TODO: Add a way to switch back to local diarization
pass

selected_task = getattr(self.remote_services, data.task)
# Case 1: We are not using remote task
if selected_task.use_remote is False:
raise ValueError(f"You are not using remote {data.task}.")
# Case 2: We are using remote task
else:
raise ValueError(f"{data.task} is not a valid task.")
await selected_task.remove_url(str(data.url))
if selected_task.get_queue_size() == 0:
# No more remote URLs, switch to local service
self.create_local_service(task=data.task)
setattr(self.remote_services, data.task, RemoteServiceConfig())

return data

Expand Down

0 comments on commit 38029c3

Please sign in to comment.