Merge pull request #28 from TalkBank/feat/cantonese-utterance-segmentation

Feat/cantonese utterance segmentation

co-authored-by: Sebantian Song <[email protected]>
Jemoka and Sebantian Song authored Feb 23, 2025
2 parents 7d4e7cb + 74e7255 commit dd3fa24
Showing 10 changed files with 194 additions and 15 deletions.
2 changes: 1 addition & 1 deletion batchalign/models/__init__.py
@@ -1,4 +1,4 @@
-from .utterance import BertUtteranceModel
+from .utterance import BertUtteranceModel, BertCantoneseUtteranceModel
 from .whisper import WhisperASRModel, WhisperFAModel
 from .speaker import NemoSpeakerModel
 from .utils import ASRAudioFile
2 changes: 1 addition & 1 deletion batchalign/models/resolve.py
@@ -8,7 +8,7 @@
     "utterance": {
         'eng': "talkbank/CHATUtterance-en",
         "zho": "talkbank/CHATUtterance-zh_CN",
-        "yue": "talkbank/CHATUtterance-zh_CN",
+        "yue": "PolyU-AngelChanLab/Cantonese-Utterance-Segmentation",
     },
     "whisper": {
         'eng': ("talkbank/CHATWhisper-en-large-v1", "openai/whisper-large-v2"),
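Downstream, the ASR pipelines look this checkpoint up by language. A minimal sketch of the lookup, assuming the resolve(key, lang) helper exactly as it is used in rev.py and whisper.py below:

    from batchalign.models import resolve

    model_name = resolve("utterance", "yue")
    # expected: "PolyU-AngelChanLab/Cantonese-Utterance-Segmentation"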
2 changes: 2 additions & 0 deletions batchalign/models/utterance/__init__.py
@@ -1,2 +1,4 @@
 from .infer import BertUtteranceModel
+from .cantonese_infer import BertCantoneseUtteranceModel
+
164 changes: 164 additions & 0 deletions batchalign/models/utterance/cantonese_infer.py
@@ -0,0 +1,164 @@
import re
import string
import random

# tokenization utilities
import nltk
from nltk import word_tokenize, sent_tokenize

# torch
import torch
from torch.utils.data import dataset
from torch.utils.data.dataloader import DataLoader
from torch.optim import AdamW

# import huggingface utils
from transformers import AutoTokenizer, BertForTokenClassification
from transformers import DataCollatorForTokenClassification

# tqdm
from tqdm import tqdm

# inference device
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# seed model
class BertCantoneseUtteranceModel(object):

    def __init__(self, model):
        # seed tokenizer and model
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = BertForTokenClassification.from_pretrained(model).to(DEVICE)
        self.max_length = 512
        self.overlap = 20

        # eval mode
        self.model.eval()
        print(f"Model and tokenizer initialized on device: {DEVICE}")
        print(f"Max length set to {self.max_length} with overlap of {self.overlap}")

    def __call__(self, passage):
        # Step 1: Clean up the passage: lowercase it and strip existing
        # halfwidth and fullwidth punctuation, since the model restores
        # punctuation from scratch
        passage = passage.lower()
        for mark in ['.', ',', '!', '?', '。', ',', '!', '?', '(', ')', ':', '*']:
            passage = passage.replace(mark, '')
        # strip stray Latin "l" characters
        passage = passage.replace('l', '')


        # Step 2: Define keywords and split the passage on them
        keywords = ['呀', '啦', '喎', '嘞', '㗎喇', '囉', '㗎', '啊', '嗯']  # sentence-final particles used as split points

        chunks = []
        start = 0

        while start < len(passage):
            # Find the position of each keyword in the passage starting from the current `start`
            keyword_positions = [(keyword, passage.find(keyword, start)) for keyword in keywords]
            # Filter out keywords that are not found (find() returns -1 if not found)
            keyword_positions = [kp for kp in keyword_positions if kp[1] != -1]

            if keyword_positions:
                # Find the keyword that appears first in the passage from the current start
                first_keyword, keyword_pos = min(keyword_positions, key=lambda x: x[1])
                chunk = passage[start:keyword_pos + len(first_keyword)]
                chunks.append(chunk)
                start = keyword_pos + len(first_keyword)
            else:
                # No more keywords found; add the rest of the passage as the last chunk
                chunks.append(passage[start:])
                break

        # Debugging: print the number of chunks and their content
        print(f"Created {len(chunks)} chunks based on keywords.")
        for i, chunk in enumerate(chunks):
            print(f"Chunk {i + 1}: {chunk[:100]}...")  # print the first 100 characters of each chunk

        # Step 3: Process each chunk and restore punctuation
        final_passage = []
        for chunk_index, chunk in enumerate(chunks):
            print(f"Processing chunk {chunk_index + 1}/{len(chunks)}...")

            # Step 3.1: Split the chunk into characters (Chinese tokenization)
            tokenized_chunk = list(chunk)

            # Step 3.2: Pass the chunk through the tokenizer and model
            tokd = self.tokenizer.batch_encode_plus([tokenized_chunk],
                                                    return_tensors='pt',
                                                    truncation=True,
                                                    padding=True,
                                                    max_length=self.max_length,
                                                    is_split_into_words=True).to(DEVICE)

            try:
                # pass it through the model
                res = self.model(**tokd).logits
            except Exception as e:
                print(f"Error during model inference: {e}")
                return []

            # argmax for classification
            classified_targets = torch.argmax(res, dim=2).cpu()

            # initialize the result token list for the current chunk
            res_toks = []
            prev_word_idx = None

            # iterate over tokenized words
            wids = tokd.word_ids(0)
            for indx, elem in enumerate(wids):
                if elem is None or elem == prev_word_idx:
                    continue

                prev_word_idx = elem
                action = classified_targets[0][indx]

                # get the word corresponding to the token
                w = tokenized_chunk[elem]

                # fix the one-word-hanging issue: if the next token also
                # predicts an action, defer punctuation to it
                will_action = False
                if indx < len(wids) - 2 and classified_targets[0][indx + 1] > 0:
                    will_action = True

                if not will_action:
                    # perform the edits predicted by the model
                    if action == 1:    # capitalize the first letter
                        w = w[0].upper() + w[1:]
                    elif action == 2:  # add a period
                        w = w + '.'
                    elif action == 3:  # add a question mark
                        w = w + '?'
                    elif action == 4:  # add an exclamation mark
                        w = w + '!'
                    elif action == 5:  # add a comma
                        w = w + ','

                # append the modified word to the result list
                res_toks.append(w)

            # convert the token list back to a string and append it to final_passage
            final_passage.append(self.tokenizer.convert_tokens_to_string(res_toks))

        # Step 4: Join the processed chunks into the final passage
        final_text = ' '.join(final_passage)

        print("Text processing completed. Generating final output...")

        # split the final text into sentences based on the restored punctuation
        try:
            split_passage = sent_tokenize(final_text)
        except LookupError:
            nltk.download('punkt')
            split_passage = sent_tokenize(final_text)

        return split_passage
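For context, a minimal usage sketch of the new class; the import is the one exported by batchalign/models/__init__.py above, the checkpoint name comes from resolve.py, and the sample passage is hypothetical:

    from batchalign.models import BertCantoneseUtteranceModel

    engine = BertCantoneseUtteranceModel("PolyU-AngelChanLab/Cantonese-Utterance-Segmentation")
    # chunks on sentence-final particles (呀, 啦, ...), restores punctuation,
    # then returns a list of segmented utterance strings
    utterances = engine("你食咗飯未呀我哋一陣出去食啦好唔好")
    print(utterances)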
1 change: 1 addition & 0 deletions batchalign/models/whisper/infer_asr.py
@@ -33,6 +33,7 @@
import logging
L = logging.getLogger("batchalign")

# DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else torch.device('cpu')
# PYTORCH_ENABLE_MPS_FALLBACK=1
8 changes: 6 additions & 2 deletions batchalign/pipelines/asr/rev.py
@@ -10,7 +10,7 @@

 from batchalign.errors import *

-from batchalign.models import BertUtteranceModel, resolve
+from batchalign.models import BertUtteranceModel, BertCantoneseUtteranceModel, resolve

 import time
 import pathlib
@@ -49,7 +49,11 @@ def __init__(self, key:str=None, lang="eng", num_speakers=2):
         self.__client = apiclient.RevAiAPIClient(key)
         if resolve("utterance", lang) != None:
             L.debug("Initializing utterance model...")
-            self.__engine = BertUtteranceModel(resolve("utterance", lang))
+            if lang != "yue":
+                self.__engine = BertUtteranceModel(resolve("utterance", lang))
+            else:
+                # we have a special inference procedure for Cantonese
+                self.__engine = BertCantoneseUtteranceModel(resolve("utterance", lang))
             L.debug("Done.")
         else:
             self.__engine = None
7 changes: 5 additions & 2 deletions batchalign/pipelines/asr/utils.py
@@ -94,7 +94,10 @@ def retokenize_with_engine(intermediate_output, engine):
         tmp = []

         for s in new_ut:
-            tmp.append((s, utterance.pop(0)[1]))
+            try:
+                tmp.append((s, utterance.pop(0)[1]))
+            except IndexError:
+                continue

         final_outputs.append((speaker, tmp+[[delim, [None, None]]]))
@@ -159,7 +162,7 @@ def process_generation(output, lang="eng", utterance_engine=None):
             final_words.append([part.strip(), [cur, cur+div]])
             cur += div

-    lang_2 = pycountry.languages.get(alpha_3=lang).alpha_2
+    lang_2 = "yue" if lang == "yue" else pycountry.languages.get(alpha_3=lang).alpha_2
     def catched_num2words(i):
         if not i.isdigit():
             return i
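The special case is needed because Cantonese has an ISO 639-3 code ("yue") but no ISO 639-1 two-letter code, so pycountry cannot supply an alpha_2 value for it. A quick illustration (sketch):

    import pycountry

    pycountry.languages.get(alpha_3="eng").alpha_2                # "en"
    hasattr(pycountry.languages.get(alpha_3="yue"), "alpha_2")    # False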
8 changes: 6 additions & 2 deletions batchalign/pipelines/asr/whisper.py
@@ -1,7 +1,7 @@
 from batchalign.document import *
 from batchalign.pipelines.base import *
 from batchalign.pipelines.asr.utils import *
-from batchalign.models import WhisperASRModel, BertUtteranceModel
+from batchalign.models import WhisperASRModel, BertUtteranceModel, BertCantoneseUtteranceModel

 import pycountry

@@ -44,7 +44,11 @@ def __init__(self, model=None, lang="eng"):

         if resolve("utterance", self.__lang) != None:
             L.debug("Initializing utterance model...")
-            self.__engine = BertUtteranceModel(resolve("utterance", self.__lang))
+            if lang != "yue":
+                self.__engine = BertUtteranceModel(resolve("utterance", lang))
+            else:
+                # we have a special inference procedure for Cantonese
+                self.__engine = BertCantoneseUtteranceModel(resolve("utterance", lang))
             L.debug("Done.")
         else:
             self.__engine = None
6 changes: 3 additions & 3 deletions batchalign/version
@@ -1,3 +1,3 @@
-0.7.14
-February 19th, 2025
-machine translation!
+0.7.15
+February 23rd, 2025
+Whisper ASR with Cantonese and tokenization!
9 changes: 5 additions & 4 deletions scratchpad.py
@@ -20,14 +20,15 @@
 # engine = infer.BertUtteranceModel("talkbank/CHATUtterance-zh_CN")
 # engine("我 现在 想 听 你说 一些 你 自己 经 历 过 的 故 事 好不好 然后 呢 我们 会 一起 讨 论 有 六 种 不同 的 情 景 然后 在 每 一个 情 景 中 都 需要 你 去 讲 一个 关 于 你 自己 的 一个 故 事 小 故 事")

-# doc = Document.new(media_path="/Users/houjun/Downloads/trial.mp3", lang="zho")
-# print(doc)
-# pipe = BatchalignPipeline.new("asr", lang="zho", num_speakers=2, engine="rev")
+# doc = Document.new(media_path="/Users/houjun/Documents/Projects/talkbank-alignment/cantonese/input/Untitled.mp3", lang="yue")
+# # print(doc)
+# pipe = BatchalignPipeline.new("asr", lang="yue", num_speakers=2, asr="whisper")
 # res = pipe(doc)
+# res

 # # with open("schema.json", 'w') as df:
 # #     json.dump(Document.model_json_schema(), df, indent=4)

-# res
 # ########### The Batchalign Core Test Harness ###########
 # from batchalign.formats.chat.parser import chat_parse_utterance
 # from batchalign.formats.chat.generator import check_utterances_ordered
