diff --git a/common/seamless_asr.py b/common/seamless_asr.py index 6078645..78b7300 100644 --- a/common/seamless_asr.py +++ b/common/seamless_asr.py @@ -25,7 +25,6 @@ class SeamlessM4TPipeline(BaseModel): class SeamlessM4TInputs(BaseModel): audio: str | None = None # required for ASR, S2ST, and S2TT text: str | None = None # required for T2ST and T2TT - task: typing.Literal["T2ST", "T2TT", "ASR"] = "ASR" # we do not need S2ST and S2TT src_lang: str | None = None # required for T2ST and T2TT tgt_lang: str | None = None # ignored for ASR (only src_lang is used) # seamless uses ISO 639-3 codes for languages @@ -37,36 +36,46 @@ class SeamlessM4TInputs(BaseModel): speaker_id: int = 0 # only used for T2ST, value in [0, 200) -@app.task(name="seamless") +@app.task(name="seamless.t2st") @gooey_gpu.endpoint -def seamless_asr( +def seamless_text2speech_translation( pipeline: SeamlessM4TPipeline, inputs: SeamlessM4TInputs, -) -> AsrOutput | None: - pipe, processor, model = load_pipe(pipeline.model_id) +) -> None: + _, processor, model = load_pipe(pipeline.model_id) tgt_lang = inputs.tgt_lang or inputs.src_lang or "eng" - if inputs.task == "ASR": - assert inputs.audio is not None - - audio = requests.get(inputs.audio).content + assert inputs.text is not None + assert inputs.src_lang is not None + text_inputs = processor( + text=inputs.text, src_lang=inputs.src_lang, return_tensors="pt" + ) - previous_src_lang = pipe.tokenizer.src_lang - if inputs.src_lang: - pipe.tokenizer.src_lang = inputs.src_lang + audio_array_from_text = ( + model.generate(**text_inputs, tgt_lang=tgt_lang, speaker_id=inputs.speaker_id)[ + 0 + ] + .cpu() + .numpy() + .squeeze() + ) - prediction = pipe( - audio, - # see https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor#scrollTo=Ca4YYdtATxzo&line=5&uniqifier=1 - chunk_length_s=inputs.chunk_length_s, - stride_length_s=inputs.stride_length_s, - batch_size=inputs.batch_size, - generate_kwargs=dict(tgt_lang=tgt_lang), - ) + bytes_wav = bytes() + byte_io = io.BytesIO(bytes_wav) + write(byte_io, 16000, audio_array_from_text) + audio_bytes = byte_io.read() + gooey_gpu.upload_audio_from_bytes(audio_bytes, pipeline.upload_urls[0]) + return - pipe.tokenizer.src_lang = previous_src_lang - return prediction +@app.task(name="seamless.t2tt") +@gooey_gpu.endpoint +def seamless_text2text_translation( + pipeline: SeamlessM4TPipeline, + inputs: SeamlessM4TInputs, +) -> AsrOutput | None: + _, processor, model = load_pipe(pipeline.model_id) + tgt_lang = inputs.tgt_lang or inputs.src_lang or "eng" assert inputs.text is not None assert inputs.src_lang is not None @@ -74,31 +83,45 @@ def seamless_asr( text=inputs.text, src_lang=inputs.src_lang, return_tensors="pt" ) - if inputs.task == "T2ST": - audio_array_from_text = ( - model.generate( - **text_inputs, tgt_lang=tgt_lang, speaker_id=inputs.speaker_id - )[0] - .cpu() - .numpy() - .squeeze() - ) - - bytes_wav = bytes() - byte_io = io.BytesIO(bytes_wav) - write(byte_io, 16000, audio_array_from_text) - audio_bytes = byte_io.read() - gooey_gpu.upload_audio_from_bytes(audio_bytes, pipeline.upload_urls[0]) - return - if inputs.task == "T2TT": - output_tokens = model.generate( - **text_inputs, tgt_lang=tgt_lang, generate_speech=False - ) - translated_text_from_text = processor.decode( - output_tokens[0].tolist()[0], skip_special_tokens=True - ) - - return AsrOutput(text=translated_text_from_text) + output_tokens = model.generate( + **text_inputs, tgt_lang=tgt_lang, generate_speech=False + ) + translated_text_from_text = processor.decode( + output_tokens[0].tolist()[0], skip_special_tokens=True + ) + + return AsrOutput(text=translated_text_from_text) + + +@app.task(name="seamless") +@gooey_gpu.endpoint +def seamless_asr( + pipeline: SeamlessM4TPipeline, + inputs: SeamlessM4TInputs, +) -> AsrOutput | None: + pipe, _, _ = load_pipe(pipeline.model_id) + tgt_lang = inputs.tgt_lang or inputs.src_lang or "eng" + + assert inputs.audio is not None + + audio = requests.get(inputs.audio).content + + previous_src_lang = pipe.tokenizer.src_lang + if inputs.src_lang: + pipe.tokenizer.src_lang = inputs.src_lang + + prediction = pipe( + audio, + # see https://colab.research.google.com/drive/1rS1L4YSJqKUH_3YxIQHBI982zso23wor#scrollTo=Ca4YYdtATxzo&line=5&uniqifier=1 + chunk_length_s=inputs.chunk_length_s, + stride_length_s=inputs.stride_length_s, + batch_size=inputs.batch_size, + generate_kwargs=dict(tgt_lang=tgt_lang), + ) + + pipe.tokenizer.src_lang = previous_src_lang + + return prediction @lru_cache