Skip to content

Commit

Permalink
BetterTransformer implementation example
Browse files Browse the repository at this point in the history
  • Loading branch information
chainyo committed Oct 13, 2023
1 parent f9f1837 commit 489239f
Show file tree
Hide file tree
Showing 4 changed files with 211 additions and 215 deletions.
2 changes: 1 addition & 1 deletion .env
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ DEBUG=True
# docker run command as a volume.
# e.g. WHISPER_MODEL="/app/models/custom"
# docker cmd: -v /path/to/custom/model:/app/models/custom
WHISPER_MODEL="large-v2"
WHISPER_MODEL="openai/whisper-large-v2"
# The compute_type parameter is used to control the precision of the model. You can choose between:
# "int8", "int8_float16", "int8_bfloat16", "int16", "float16", "bfloat16", "float32".
# The default value is "float16".
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,24 @@ classifiers = [
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"accelerate",
"aiohttp>=3.8.4",
"aiofiles>=23.1.0",
"bitsandbytes",
"ctranslate2>=3.18.0",
"faster-whisper @ git+https://github.com/Wordcab/faster-whisper@master",
"ffmpeg-python>=0.2.0",
"librosa>=0.9.0",
"loguru>=0.6.0",
"numpy==1.23.1",
"onnxruntime>=1.15.0",
"optimum",
"pydantic>=1.10.9",
"python-dotenv>=1.0.0",
"tensorshare>=0.1.1",
"torch>=2.0.0",
"torchaudio>=2.0.1",
"transformers @ git+https://github.com/huggingface/transformers",
"wget>=3.2.0",
"yt-dlp>=2023.3.4",
]
Expand Down
29 changes: 27 additions & 2 deletions src/wordcab_transcribe/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from enum import Enum
from typing import List, Literal, NamedTuple, Optional, Union

from faster_whisper.transcribe import Segment
from pydantic import BaseModel, HttpUrl, field_validator
from tensorshare import TensorShare

Expand Down Expand Up @@ -529,10 +528,36 @@ class MultiChannelTranscriptionOutput(BaseModel):
segments: List[MultiChannelSegment]


class Chunk(BaseModel):
    """A single transcription chunk produced by the Whisper pipeline.

    Attributes:
        start: Chunk start time, in seconds.
        end: Chunk end time, in seconds.
        text: Transcribed text for this chunk.
        words: Word-level details. The annotation permits only ``None``
            (the pipeline output carries no word timestamps here), so the
            field is defaulted — callers no longer have to pass
            ``words=None`` explicitly. Passing it remains valid.
    """

    start: float
    end: float
    text: str
    # Only None is a valid value today; defaulting keeps construction terse.
    words: None = None


class TranscriptionOutput(BaseModel):
    """Transcription output model for the API.

    Attributes:
        segments: Ordered list of transcription chunks.
    """

    segments: List[Chunk]

    @classmethod
    def get_segments_from_outputs(cls, outputs: List[dict]) -> "TranscriptionOutput":
        """Create a TranscriptionOutput object from a list of outputs.

        Args:
            outputs: Raw pipeline outputs; each dict carries a
                ``"timestamp"`` pair (start, end) and a ``"text"`` entry.

        Returns:
            A TranscriptionOutput wrapping one Chunk per output dict.
        """

        def _to_chunk(item: dict) -> Chunk:
            # Unpack the (start, end) pair before reading the text, matching
            # the key-access order of the original implementation.
            begin, finish = item["timestamp"]
            return Chunk(start=begin, end=finish, text=item["text"], words=None)

        return cls(segments=[_to_chunk(item) for item in outputs])


class TranscribeRequest(BaseModel):
Expand Down
Loading

0 comments on commit 489239f

Please sign in to comment.