Skip to content

Commit

Permalink
refactor into separate gpu tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
SanderGi committed Jun 10, 2024
1 parent 3c9c056 commit 32964e5
Showing 1 changed file with 70 additions and 47 deletions.
117 changes: 70 additions & 47 deletions common/seamless_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class SeamlessM4TPipeline(BaseModel):
class SeamlessM4TInputs(BaseModel):
audio: str | None = None # required for ASR, S2ST, and S2TT
text: str | None = None # required for T2ST and T2TT
task: typing.Literal["T2ST", "T2TT", "ASR"] = "ASR" # we do not need S2ST and S2TT
src_lang: str | None = None # required for T2ST and T2TT
tgt_lang: str | None = None # ignored for ASR (only src_lang is used)
# seamless uses ISO 639-3 codes for languages
Expand All @@ -37,68 +36,92 @@ class SeamlessM4TInputs(BaseModel):
speaker_id: int = 0 # only used for T2ST, value in [0, 200)


@app.task(name="seamless.t2st")
@gooey_gpu.endpoint
def seamless_text2speech_translation(
    pipeline: SeamlessM4TPipeline,
    inputs: SeamlessM4TInputs,
) -> None:
    """Translate ``inputs.text`` into speech and upload the result as WAV.

    Args:
        pipeline: Supplies ``model_id`` (passed to ``load_pipe``) and
            ``upload_urls``; the audio is uploaded to ``upload_urls[0]``.
        inputs: ``text`` and ``src_lang`` are required. ``tgt_lang`` falls
            back to ``src_lang`` and then to ``"eng"``. ``speaker_id`` is
            forwarded to ``model.generate``.

    Raises:
        AssertionError: If ``inputs.text`` or ``inputs.src_lang`` is None.
    """
    # Validate before loading the model so bad requests fail fast.
    assert inputs.text is not None, "text is required for T2ST"
    assert inputs.src_lang is not None, "src_lang is required for T2ST"

    _, processor, model = load_pipe(pipeline.model_id)
    tgt_lang = inputs.tgt_lang or inputs.src_lang or "eng"

    text_inputs = processor(
        text=inputs.text, src_lang=inputs.src_lang, return_tensors="pt"
    )

    # The first element of the generate() output is taken as the waveform.
    waveform = model.generate(
        **text_inputs, tgt_lang=tgt_lang, speaker_id=inputs.speaker_id
    )[0]
    audio_array_from_text = waveform.cpu().numpy().squeeze()

    byte_io = io.BytesIO()
    # NOTE(review): `write` is imported elsewhere in this file (presumably
    # scipy.io.wavfile.write) and 16000 is passed as its rate argument --
    # confirm this matches the model's output sample rate.
    write(byte_io, 16000, audio_array_from_text)
    # getvalue() returns the whole buffer regardless of where `write` left
    # the stream position, unlike read().
    audio_bytes = byte_io.getvalue()
    gooey_gpu.upload_audio_from_bytes(audio_bytes, pipeline.upload_urls[0])


@app.task(name="seamless.t2tt")
@gooey_gpu.endpoint
def seamless_text2text_translation(
    pipeline: SeamlessM4TPipeline,
    inputs: SeamlessM4TInputs,
) -> AsrOutput | None:
    """Translate ``inputs.text`` into text in the target language.

    Args:
        pipeline: Supplies ``model_id``, which is passed to ``load_pipe``.
        inputs: ``text`` and ``src_lang`` are required. ``tgt_lang`` falls
            back to ``src_lang`` and then to ``"eng"``.

    Returns:
        An ``AsrOutput`` whose ``text`` is the decoded translation.

    Raises:
        AssertionError: If ``inputs.text`` or ``inputs.src_lang`` is None.
    """
    # Validate before loading the model so bad requests fail fast.
    assert inputs.text is not None, "text is required for T2TT"
    assert inputs.src_lang is not None, "src_lang is required for T2TT"

    _, processor, model = load_pipe(pipeline.model_id)
    tgt_lang = inputs.tgt_lang or inputs.src_lang or "eng"

    text_inputs = processor(
        text=inputs.text, src_lang=inputs.src_lang, return_tensors="pt"
    )

    # generate_speech=False is passed so that token ids are decoded below
    # instead of handling a waveform.
    output_tokens = model.generate(
        **text_inputs, tgt_lang=tgt_lang, generate_speech=False
    )
    translated_text_from_text = processor.decode(
        output_tokens[0].tolist()[0], skip_special_tokens=True
    )

    return AsrOutput(text=translated_text_from_text)


@app.task(name="seamless")
@gooey_gpu.endpoint
def seamless_asr(
    pipeline: SeamlessM4TPipeline,
    inputs: SeamlessM4TInputs,
) -> AsrOutput | None:
    """Transcribe the audio at ``inputs.audio`` with the ASR pipeline.

    Args:
        pipeline: Supplies ``model_id``, which is passed to ``load_pipe``.
        inputs: ``audio`` (a URL fetched with ``requests.get``) is required.
            ``src_lang``, when set, temporarily overrides the tokenizer's
            source language. ``tgt_lang`` falls back to ``src_lang`` and then
            to ``"eng"``. ``chunk_length_s``, ``stride_length_s`` and
            ``batch_size`` are forwarded to the pipeline call.

    Returns:
        Whatever the pipeline call returns for the downloaded audio.

    Raises:
        AssertionError: If ``inputs.audio`` is None.
        requests.HTTPError: If the audio download returns an error status.
    """
    assert inputs.audio is not None, "audio is required for ASR"

    pipe, _, _ = load_pipe(pipeline.model_id)
    tgt_lang = inputs.tgt_lang or inputs.src_lang or "eng"

    response = requests.get(inputs.audio)
    # Fail on HTTP errors instead of handing an error page to the decoder.
    response.raise_for_status()
    audio = response.content

    previous_src_lang = pipe.tokenizer.src_lang
    if inputs.src_lang:
        pipe.tokenizer.src_lang = inputs.src_lang

    try:
        return pipe(
            audio,
            # see https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor#scrollTo=Ca4YYdtATxzo&line=5&uniqifier=1
            chunk_length_s=inputs.chunk_length_s,
            stride_length_s=inputs.stride_length_s,
            batch_size=inputs.batch_size,
            generate_kwargs=dict(tgt_lang=tgt_lang),
        )
    finally:
        # The tokenizer belongs to the object returned by load_pipe; restore
        # its source language even when inference raises so the override
        # does not persist on that object.
        pipe.tokenizer.src_lang = previous_src_lang


@lru_cache
Expand Down

0 comments on commit 32964e5

Please sign in to comment.