Merge pull request #11 from mobiusml/init_whisper

Basic ASR Model Integration with Faster-Whisper

movchan74 authored Nov 17, 2023
2 parents 42c7485 + df73c3c commit f148012

Showing 16 changed files with 4,663 additions and 2 deletions.
32 changes: 31 additions & 1 deletion aana/api/responses.py
@@ -1,6 +1,34 @@
from typing import Any, Optional
from fastapi.responses import JSONResponse
import orjson
from pydantic import BaseModel


def json_serializer_default(obj: Any) -> Any:
    """
    Default function for the JSON serializer to handle pydantic models.

    If the JSON serializer does not know how to serialize an object, it calls this default function.
    If the object is a pydantic model, we call its dict method to get a dictionary representation
    of the model that the JSON serializer can deal with.
    If the object is not a pydantic model, we raise a TypeError.

    Args:
        obj (Any): The object to serialize.

    Returns:
        Any: The serializable object.

    Raises:
        TypeError: If the object is not a pydantic model.
    """
    if isinstance(obj, BaseModel):
        return obj.dict()
    raise TypeError


class AanaJSONResponse(JSONResponse):
@@ -23,4 +51,6 @@ def render(self, content: Any) -> bytes:
        """
        Override the render method to use orjson.dumps instead of json.dumps.
        """
        return orjson.dumps(content, option=self.option)
        return orjson.dumps(
            content, option=self.option, default=json_serializer_default
        )
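To illustrate what the new default hook buys us, here is a small standalone sketch (the Point model is hypothetical, defined only for this example); it shows orjson falling back to json_serializer_default for pydantic models:

import orjson
from pydantic import BaseModel

from aana.api.responses import json_serializer_default


class Point(BaseModel):  # hypothetical model, for illustration only
    x: int
    y: int


# Without a default hook, orjson.dumps raises TypeError on pydantic models;
# with the hook, the model is converted to a plain dict first.
payload = {"point": Point(x=1, y=2), "ok": True}
print(orjson.dumps(payload, default=json_serializer_default))
# b'{"point":{"x":1,"y":2},"ok":true}'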
15 changes: 15 additions & 0 deletions aana/configs/deployments.py
@@ -1,5 +1,11 @@
from aana.deployments.hf_blip2_deployment import HFBlip2Config, HFBlip2Deployment
from aana.deployments.vllm_deployment import VLLMConfig, VLLMDeployment
from aana.deployments.whisper_deployment import (
    WhisperComputeType,
    WhisperConfig,
    WhisperDeployment,
    WhisperModelSize,
)
from aana.models.core.dtype import Dtype
from aana.models.pydantic.sampling_params import SamplingParams

@@ -45,4 +51,13 @@
            num_processing_threads=2,
        ).dict(),
    ),
    "whisper_deployment_medium": WhisperDeployment.options(
        num_replicas=1,
        max_concurrent_queries=1000,
        ray_actor_options={"num_gpus": 0.5},
        user_config=WhisperConfig(
            model_size=WhisperModelSize.MEDIUM,
            compute_type=WhisperComputeType.FLOAT16,
        ).dict(),
    ),
}
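As a side note, user_config here is simply the pydantic config serialized to a plain dict; a minimal sketch of the round trip the deployment performs when it receives that dict back in apply_config (see whisper_deployment.py below):

from aana.deployments.whisper_deployment import (
    WhisperComputeType,
    WhisperConfig,
    WhisperModelSize,
)

# What Ray Serve stores as the deployment's user_config:
user_config = WhisperConfig(
    model_size=WhisperModelSize.MEDIUM,
    compute_type=WhisperComputeType.FLOAT16,
).dict()

# Inside apply_config the dict is validated back into a WhisperConfig.
config_obj = WhisperConfig(**user_config)
assert config_obj.model_size == WhisperModelSize.MEDIUM
assert config_obj.compute_type == WhisperComputeType.FLOAT16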
12 changes: 12 additions & 0 deletions aana/configs/endpoints.py
@@ -47,4 +47,16 @@
            outputs=["timestamps", "duration"],
        )
    ],
    "whisper": [
        Endpoint(
            name="whisper_transcribe",
            path="/video/transcribe",
            summary="Transcribe a video using Whisper Medium",
            outputs=[
                "video_transcriptions_whisper_medium",
                "video_transcriptions_segments_whisper_medium",
                "video_transcriptions_info_whisper_medium",
            ],
        )
    ],
}
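For orientation, a hypothetical client call against this endpoint. The actual request schema is derived from the pipeline inputs (video_batch and whisper_params), so the payload below is an illustrative assumption rather than the documented API:

import json

import requests

# Assumed payload shape: one video URL in the batch plus whisper params.
body = {
    "videos": [{"url": "https://example.com/video.mp4"}],
    "whisper_params": {"language": "en"},
}
response = requests.post(
    "http://127.0.0.1:8000/video/transcribe",  # path from the endpoint config above
    data={"body": json.dumps(body)},  # assumed form field; check the API docs generated by the app
)
print(response.json()["video_transcriptions_whisper_medium"])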
58 changes: 58 additions & 0 deletions aana/configs/pipeline.py
Expand Up @@ -3,12 +3,18 @@
It is used to generate the pipeline and the API endpoints.
"""

from aana.models.pydantic.asr_output import (
    AsrSegmentsList,
    AsrTranscriptionInfoList,
    AsrTranscriptionList,
)
from aana.models.pydantic.captions import CaptionsList, VideoCaptionsList
from aana.models.pydantic.image_input import ImageInputList
from aana.models.pydantic.prompt import Prompt
from aana.models.pydantic.sampling_params import SamplingParams
from aana.models.pydantic.video_input import VideoInputList
from aana.models.pydantic.video_params import VideoParams
from aana.models.pydantic.whisper_params import WhisperParams

# container data model
# we don't enforce this data model for now but it's a good reference for writing paths and flatten_by
@@ -293,4 +299,56 @@
            }
        ],
    },
    {
        "name": "whisper_params",
        "type": "input",
        "inputs": [],
        "outputs": [
            {
                "name": "whisper_params",
                "key": "whisper_params",
                "path": "video_batch.whisper_params",
                "data_model": WhisperParams,
            }
        ],
    },
    {
        "name": "whisper_medium_transcribe_videos",
        "type": "ray_deployment",
        "deployment_name": "whisper_deployment_medium",
        "method": "transcribe_batch",
        "inputs": [
            {
                "name": "video_objects",
                "key": "media_batch",
                "path": "video_batch.videos.[*].video",
            },
            {
                "name": "whisper_params",
                "key": "params",
                "path": "video_batch.whisper_params",
                "data_model": WhisperParams,
            },
        ],
        "outputs": [
            {
                "name": "video_transcriptions_segments_whisper_medium",
                "key": "segments",
                "path": "video_batch.videos.[*].segments",
                "data_model": AsrSegmentsList,
            },
            {
                "name": "video_transcriptions_info_whisper_medium",
                "key": "transcription_info",
                "path": "video_batch.videos.[*].transcription_info",
                "data_model": AsrTranscriptionInfoList,
            },
            {
                "name": "video_transcriptions_whisper_medium",
                "key": "transcription",
                "path": "video_batch.videos.[*].transcription",
                "data_model": AsrTranscriptionList,
            },
        ],
    },
]
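Conceptually, the ray_deployment node above routes pipeline data into a deployment method call and maps the returned dictionary keys back to named outputs. A rough sketch of the equivalent direct call; the handle acquisition is simplified and assumed here, while the argument and key names come from the node definition:

from typing import List

from ray import serve

from aana.models.core.video import Video
from aana.models.pydantic.whisper_params import WhisperParams


async def transcribe_videos(videos: List[Video], params: WhisperParams):
    # Assumed: the deployment is already running under the name from deployments.py.
    handle = serve.get_deployment("whisper_deployment_medium").get_handle()

    # inputs: "media_batch" <- video_batch.videos.[*].video, "params" <- video_batch.whisper_params
    output = await handle.transcribe_batch.remote(media_batch=videos, params=params)

    # outputs: the returned dict keys map onto the named pipeline outputs
    return {
        "video_transcriptions_segments_whisper_medium": output["segments"],
        "video_transcriptions_info_whisper_medium": output["transcription_info"],
        "video_transcriptions_whisper_medium": output["transcription"],
    }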
217 changes: 217 additions & 0 deletions aana/deployments/whisper_deployment.py
@@ -0,0 +1,217 @@
from enum import Enum
from typing import Any, Dict, List, TypedDict, cast
from faster_whisper import WhisperModel
from pydantic import BaseModel, Field
from ray import serve
import torch

from aana.deployments.base_deployment import BaseDeployment
from aana.exceptions.general import InferenceException
from aana.models.core.video import Video
from aana.models.pydantic.asr_output import (
    AsrSegment,
    AsrTranscription,
    AsrTranscriptionInfo,
)
from aana.models.pydantic.whisper_params import WhisperParams


class WhisperComputeType(str, Enum):
    """
    The data type used by whisper models.

    See [cTranslate2 docs on quantization](https://opennmt.net/CTranslate2/quantization.html#quantize-on-model-conversion)
    for more information.

    Available types:
        - INT8
        - INT8_FLOAT32
        - INT8_FLOAT16
        - INT8_BFLOAT16
        - INT16
        - FLOAT16
        - BFLOAT16
        - FLOAT32
    """

    INT8 = "int8"
    INT8_FLOAT32 = "int8_float32"
    INT8_FLOAT16 = "int8_float16"
    INT8_BFLOAT16 = "int8_bfloat16"
    INT16 = "int16"
    FLOAT16 = "float16"
    BFLOAT16 = "bfloat16"
    FLOAT32 = "float32"


class WhisperModelSize(str, Enum):
    """
    The whisper model size.

    Available models:
        - TINY
        - TINY_EN
        - BASE
        - BASE_EN
        - SMALL
        - SMALL_EN
        - MEDIUM
        - MEDIUM_EN
        - LARGE_V1
        - LARGE_V2
        - LARGE
    """

    TINY = "tiny"
    TINY_EN = "tiny.en"
    BASE = "base"
    BASE_EN = "base.en"
    SMALL = "small"
    SMALL_EN = "small.en"
    MEDIUM = "medium"
    MEDIUM_EN = "medium.en"
    LARGE_V1 = "large-v1"
    LARGE_V2 = "large-v2"
    LARGE = "large"


class WhisperConfig(BaseModel):
    """
    The configuration for the whisper deployment from faster-whisper.
    """

    model_size: WhisperModelSize = Field(
        default=WhisperModelSize.BASE, description="The whisper model size."
    )
    compute_type: WhisperComputeType = Field(
        default=WhisperComputeType.FLOAT16, description="The compute type."
    )


class WhisperOutput(TypedDict):
    """
    The output of the whisper model.

    Attributes:
        segments (List[AsrSegment]): The ASR segments.
        transcription_info (AsrTranscriptionInfo): The ASR transcription info.
        transcription (AsrTranscription): The ASR transcription.
    """

    segments: List[AsrSegment]
    transcription_info: AsrTranscriptionInfo
    transcription: AsrTranscription


class WhisperBatchOutput(TypedDict):
    """
    The output of the whisper model for a batch of inputs.

    Attributes:
        segments (List[List[AsrSegment]]): The ASR segments for each media.
        transcription_info (List[AsrTranscriptionInfo]): The ASR transcription info for each media.
        transcription (List[AsrTranscription]): The ASR transcription for each media.
    """

    segments: List[List[AsrSegment]]
    transcription_info: List[AsrTranscriptionInfo]
    transcription: List[AsrTranscription]


@serve.deployment
class WhisperDeployment(BaseDeployment):
    """
    Deployment to serve Whisper models from faster-whisper.
    """

    async def apply_config(self, config: Dict[str, Any]):
        """
        Apply the configuration.

        The method is called when the deployment is created or updated.
        It loads the whisper model with faster-whisper.
        The configuration should conform to the WhisperConfig schema.
        """

        config_obj = WhisperConfig(**config)
        self.model_size = config_obj.model_size
        self.model_name = "whisper_" + self.model_size
        self.compute_type = config_obj.compute_type
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = WhisperModel(
            self.model_size, device=self.device, compute_type=self.compute_type
        )

    # TODO: add audio support
    async def transcribe(
        self, media: Video, params: WhisperParams = WhisperParams()
    ) -> WhisperOutput:
        """
        Transcribe the media with the whisper model.

        Args:
            media (Video): The media to transcribe.
            params (WhisperParams): The parameters for the whisper model.

        Returns:
            WhisperOutput: The transcription output as a dictionary:
                segments (List[AsrSegment]): The ASR segments.
                transcription_info (AsrTranscriptionInfo): The ASR transcription info.
                transcription (AsrTranscription): The ASR transcription.

        Raises:
            InferenceException: If the inference fails.
        """

        media_path: str = str(media.path)
        try:
            segments, info = self.model.transcribe(media_path, **params.dict())
        except Exception as e:
            raise InferenceException(self.model_name) from e

        asr_segments = [AsrSegment.from_whisper(seg) for seg in segments]
        asr_transcription_info = AsrTranscriptionInfo.from_whisper(info)
        transcription = "".join([seg.text for seg in asr_segments])
        asr_transcription = AsrTranscription(text=transcription)

        return WhisperOutput(
            segments=asr_segments,
            transcription_info=asr_transcription_info,
            transcription=asr_transcription,
        )

    async def transcribe_batch(
        self, media_batch: List[Video], params: WhisperParams = WhisperParams()
    ) -> WhisperBatchOutput:
        """
        Transcribe the batch of media with the whisper model.

        Args:
            media_batch (List[Video]): The batch of media to transcribe.
            params (WhisperParams): The parameters for the whisper model.

        Returns:
            WhisperBatchOutput: The transcription output as a dictionary:
                segments (List[List[AsrSegment]]): The ASR segments for each media.
                transcription_info (List[AsrTranscriptionInfo]): The ASR transcription info for each media.
                transcription (List[AsrTranscription]): The ASR transcription for each media.

        Raises:
            InferenceException: If the inference fails.
        """

        segments: List[List[AsrSegment]] = []
        infos: List[AsrTranscriptionInfo] = []
        transcriptions: List[AsrTranscription] = []
        for media in media_batch:
            output = await self.transcribe(media, params)
            segments.append(cast(List[AsrSegment], output["segments"]))
            infos.append(cast(AsrTranscriptionInfo, output["transcription_info"]))
            transcriptions.append(cast(AsrTranscription, output["transcription"]))

        return WhisperBatchOutput(
            segments=segments, transcription_info=infos, transcription=transcriptions
        )
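For reference, the deployment is a thin wrapper around faster-whisper itself; a minimal standalone sketch of the same transcribe flow (model size, device, and file path are placeholders):

from faster_whisper import WhisperModel

model = WhisperModel("medium", device="cuda", compute_type="float16")
segments, info = model.transcribe("sample_video.mp4")

# faster-whisper returns a lazy generator; iterating it runs the actual decoding,
# which is why the deployment materializes it into AsrSegment objects.
for seg in segments:
    print(f"[{seg.start:.2f} -> {seg.end:.2f}] {seg.text}")
print(info.language, info.language_probability)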