Skip to content

Commit

Permalink
Seach error finish up, parallel viterbi across batch
Browse files Browse the repository at this point in the history
  • Loading branch information
Haoran Zhang committed Feb 24, 2025
1 parent 6bd7b05 commit c5fad0e
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 167 deletions.
13 changes: 10 additions & 3 deletions users/zhang/experiments/WER_PPL/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def __init__(
self,
names: List[str],
results: List[Tuple[tk.Variable, tk.Path]],

lm_scales: List[Optional[float]],
search_errors: List[Optional[tk.Path]],
# Reserved for plot setting
):
self.out_summary = self.output_path("summary.csv")
Expand All @@ -29,6 +31,7 @@ def __init__(
self.names = names
self.results = results
self.lm_scales = lm_scales
self.search_errors = search_errors

def tasks(self) -> Iterator[Task]:
yield Task("create_table", mini_task=True)#, rqmt={"cpu": 1, "time": 1, "mem": 4})
Expand Down Expand Up @@ -110,19 +113,23 @@ def create_table(self):
for _, wer_path in self.results:
with open(wer_path.get_path(), "r") as f:
wers.append(json.load(f))
res = dict(zip(self.names, zip(ppls,wers,self.lm_scales)))
res = dict(zip(self.names, zip(ppls,wers,self.lm_scales,self.search_errors)))
# Define filenames
csv_filename = self.out_summary.get_path()

# Prepare the data as a list of lists
table_data = [["Model Name", "Perplexity", "Best Epoch", "lm_scale", "Dev-Clean WER", "Dev-Other WER", "Test-Clean WER",
table_data = [["Model Name", "Perplexity", "Best Epoch", "lm_scale", "search_error", "Dev-Clean WER", "Dev-Other WER", "Test-Clean WER",
"Test-Other WER"]]
for key, values in res.items():
ppl = values[0]
scores = values[1]["best_scores"]
best_epoch = values[1]["best_epoch"]
lm_scale = values[2]
table_data.append([key, ppl, best_epoch, f"{lm_scale:.2f}", scores.get("dev-clean","-"), scores.get("dev-other","-"), scores.get("test-clean","-"),
with open(values[3].get_path()) as f:
import re
search_error = f.readline()
match = re.search(r"([-+]?\d*\.\d+|\d+)%", search_error)
table_data.append([key, ppl, best_epoch, f"{lm_scale:.2f}", match.group(0), scores.get("dev-clean","-"), scores.get("dev-other","-"), scores.get("test-clean","-"),
scores.get("test-other","-")])

# Save to a CSV file manually
Expand Down
263 changes: 140 additions & 123 deletions users/zhang/experiments/ctc.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def py():
exp_names_postfix = ""
prune_num = 0
for n_order in [ # 2, 3,
4, 5
4, #5
# 6 Slow recog
]:
exp_names_postfix += str(n_order) + "_"
Expand Down Expand Up @@ -355,7 +355,7 @@ def train_exp(
# empirical_prior=emp_prior if with_prior and empirical_prior else None,
# dev_sets=["dev-other"],
# )
score = recog_exp(
score, _ = recog_exp(
prefix + f"/tune/lm/{str(dc_lm).replace('.', '').replace('-', 'm')}",
task_copy,
model_with_checkpoint,
Expand All @@ -382,39 +382,40 @@ def train_exp(
lm_weight_tune = lm_weight_tune["best_tune"]
lm_scale = default_lm + lm_weight_tune
original_params["lm_weight_tune"] = best_lm_tune # This will be implicitly used by following exps, i.e through decoding_config
for dc_prior in prior_tune_ls:
params["prior_weight"] = default_prior + dc_prior
task_copy = copy.deepcopy(task)
# score = recog_training_exp(
# prefix + f"/tune/prior/{str(dc_prior).replace('.', '').replace('-', 'm')}",
# task_copy,
# model_with_checkpoint,
# recog_def=decoder_def,
# decoding_config=params,
# recog_post_proc_funcs=recog_post_proc_funcs,
# exclude_epochs=exclude_epochs,
# search_mem_rqmt=search_mem_rqmt,
# prior_from_max=prior_from_max,
# empirical_prior=emp_prior if with_prior and empirical_prior else None,
# dev_sets=["dev-other"],
# search_rqmt=search_rqmt,
# )
score, _ = recog_exp(
prefix + f"/tune/prior/{str(dc_prior).replace('.', '').replace('-', 'm')}",
task_copy,
model_with_checkpoint,
epoch=recog_epoch,
recog_def=decoder_def,
decoding_config=params,
recog_post_proc_funcs=recog_post_proc_funcs,
exclude_epochs=exclude_epochs,
search_mem_rqmt=search_mem_rqmt,
prior_from_max=prior_from_max,
empirical_prior=emp_prior if with_prior and empirical_prior else None,
dev_sets=["dev-other"],
search_rqmt=search_rqmt,
)
prior_scores.append(score)
if with_prior:
for dc_prior in prior_tune_ls:
params["prior_weight"] = default_prior + dc_prior
task_copy = copy.deepcopy(task)
# score = recog_training_exp(
# prefix + f"/tune/prior/{str(dc_prior).replace('.', '').replace('-', 'm')}",
# task_copy,
# model_with_checkpoint,
# recog_def=decoder_def,
# decoding_config=params,
# recog_post_proc_funcs=recog_post_proc_funcs,
# exclude_epochs=exclude_epochs,
# search_mem_rqmt=search_mem_rqmt,
# prior_from_max=prior_from_max,
# empirical_prior=emp_prior if with_prior and empirical_prior else None,
# dev_sets=["dev-other"],
# search_rqmt=search_rqmt,
# )
score, _ = recog_exp(
prefix + f"/tune/prior/{str(dc_prior).replace('.', '').replace('-', 'm')}",
task_copy,
model_with_checkpoint,
epoch=recog_epoch,
recog_def=decoder_def,
decoding_config=params,
recog_post_proc_funcs=recog_post_proc_funcs,
exclude_epochs=exclude_epochs,
search_mem_rqmt=search_mem_rqmt,
prior_from_max=prior_from_max,
empirical_prior=emp_prior if with_prior and empirical_prior else None,
dev_sets=["dev-other"],
search_rqmt=search_rqmt,
)
prior_scores.append(score)
if len(prior_scores):
best_prior_tune = GetBestTuneValue(prior_scores, prior_tune_ls).out_best_tune
tk.register_output(prefix + "/tune/prior_best", best_prior_tune)
Expand Down Expand Up @@ -849,53 +850,50 @@ def model_recog_lm(
configs["lm"] = lm

configs.update(hyp_params)

# import pdb # ---------
#pdb.set_trace()
decoder = ctc_decoder(**configs)
enc_spatial_dim_torch = enc_spatial_dim.dyn_size_ext.raw_tensor.cpu()
if use_logsoftmax:
decoder_results = decoder(label_log_prob, enc_spatial_dim_torch)
else:
decoder_results = decoder(logits.raw_tensor.cpu(), enc_spatial_dim_torch)
#--------------------------------------test---------------------------------------------------------
import pdb
pdb.set_trace()
# from flashlight.lib.text.decoder import LexiconDecoderLM
def CTClm_score(seq):
"""
When lm is word level, sequence need to be converted to word...?
"""
state = decoder.lm.start(True)
score = 0
for token in seq:
state, cur_score = decoder.lm.score(state, token)
score += cur_score
score += decoder.lm.finish(state)[1]
return score

ctc_scores = [[l2.score for l2 in l1] for l1 in decoder_results]
ctc_scores = torch.tensor(ctc_scores)
sim_scores = []
ctc_losses = []
viterbi_scores = []
ctc_loss = torch.nn.CTCLoss(model.blank_idx, "none")
for i in range(label_log_prob.shape[0]):
seq = decoder_results[i][0].tokens # These are not padded
log_prob = label_log_prob[i]
viterbi_score = ctc_viterbi_one_seq(log_prob, seq, int(enc_spatial_dim_torch[i].item()), # int(enc_spatial_dim_torch.max())
blank_idx=model.blank_idx)[1]
viterbi_scores.append(viterbi_score)
word_seq = [decoder.word_dict.get_index(word) for word in decoder_results[i][0].words]
lm_score = CTClm_score(word_seq)
sim_scores.append(viterbi_score + hyp_params["lm_weight"]*lm_score)
ctc_losses.append(ctc_loss(log_prob, seq, [log_prob.shape[0]],
[seq.shape[0]]))
pdb.set_trace()
print(f"Average difference of ctc_decoder score and viterbi score: {np.mean(np.array(ctc_scores[:,0])-sim_scores)}")


# assert scores.raw_tensor[0,:] - ctc_viterbi_one_seq(label_log_prob[0], decoder_results[0][0].tokens, int(enc_spatial_dim_torch.max()),
# blank_idx=model.blank_idx) < tolerance, "CTCdecoder does use viterbi decoding!"
# -----------------------------------------------------------------------------------------------
# ctc_scores = [[l2.score for l2 in l1] for l1 in decoder_results]
# ctc_scores = torch.tensor(ctc_scores)
# ctc_scores_forced = []
# sim_scores = []
# lm_scores = []
# ctc_losses = []
# viterbi_scores = []
# sentences = []
# ctc_loss = torch.nn.CTCLoss(model.blank_idx, "none")
# for i in range(label_log_prob.shape[0]):
# seq = decoder_results[i][0].tokens # These are not padded
# log_prob = label_log_prob[i] # These are padded
# alignment, viterbi_score = ctc_viterbi_one_seq(log_prob, seq, int(enc_spatial_dim_torch[i].item()), # int(enc_spatial_dim_torch.max())
# blank_idx=model.blank_idx)
# alignment = torch.cat((alignment, torch.tensor([0 for _ in range(log_prob.shape[0] - alignment.shape[0])])))
# mask = torch.arange(model.target_dim.size + 1, device=log_prob.device).unsqueeze(0).expand(log_prob.shape) == alignment.unsqueeze(1)
# decoder_result = decoder(log_prob.masked_fill(~mask, float('-inf')).unsqueeze(0), enc_spatial_dim_torch[i].unsqueeze(0))
# ctc_scores_forced.append(decoder_result[0][0].score)
# viterbi_scores.append(viterbi_score)
# sentence = " ".join(list(decoder_results[i][0].words))
# word_seq = [decoder.word_dict.get_index(word) for word in decoder_results[i][0].words]
# lm_score = CTClm_score(word_seq, decoder)
# sentences.append(sentence)
# lm_scores.append(lm_score)
# sim_scores.append(viterbi_score + hyp_params["lm_weight"]*lm_score)
# ctc_losses.append(ctc_loss(log_prob, seq, [log_prob.shape[0]],
# [seq.shape[0]]))
# pdb.set_trace()
# pdb.set_trace()
# print(f"Average difference of ctc_decoder score and viterbi score: {np.mean(np.array(ctc_scores[:,0])-sim_scores)}")
#
#
# # assert scores.raw_tensor[0,:] - ctc_viterbi_one_seq(label_log_prob[0], decoder_results[0][0].tokens, int(enc_spatial_dim_torch.max()),
# # blank_idx=model.blank_idx) < tolerance, "CTCdecoder does use viterbi decoding!"
# # -----------------------------------------------------------------------------------------------
if use_lexicon:
print("Use words directly!")
if CHECK_DECODER_CONSISTENCY:
Expand Down Expand Up @@ -1013,6 +1011,50 @@ def ctc_viterbi_one_seq(ctc_log_probs, seq, t_max, blank_idx):
res = torch.tensor(res).flip(0)
return res, score

# TODO: dynmically get the pad_index
def trim_padded_sequence(sequence: torch.Tensor, pad_index: int = 0) -> torch.Tensor:
"""
Removes trailing pad_index elements from a padded 1D sequence tensor.
Args:
sequence (torch.Tensor): 1D tensor containing the sequence with padding.
pad_index (int): The padding index used in the sequence.
Returns:
torch.Tensor: The trimmed sequence without trailing padding.
"""
# Find indices where the sequence is not the pad_index.
non_pad_indices = (sequence != pad_index).nonzero(as_tuple=True)[0]

if non_pad_indices.numel() == 0:
# If the entire sequence is padding, return an empty tensor.
return sequence.new_empty(0)

# The last non-padding index; add 1 for slicing (because slicing is exclusive).
last_index = non_pad_indices[-1].item() + 1

return sequence[:last_index]

def viterbi_batch(inputs, blank_idx):
'''
inputs: (targets.raw_tensor[i, :], label_log_prob[i])
'''
target, log_prob, spatial_dim = inputs
seq = trim_padded_sequence(target)
return ctc_viterbi_one_seq(log_prob, seq, int(spatial_dim.item()), # int(enc_spatial_dim_torch.max())
blank_idx=blank_idx)

def CTClm_score(seq, decoder):
"""
When lm is word level, sequence need to be converted to word...?
"""
state = decoder.lm.start(False)
score = 0
for token in seq:
state, cur_score = decoder.lm.score(state, token)
score += cur_score
score += decoder.lm.finish(state)[1]
return score

def scoring(
*,
Expand Down Expand Up @@ -1097,70 +1139,45 @@ def scoring(
enc_spatial_dim_torch = enc_spatial_dim.dyn_size_ext.raw_tensor.cpu()
scores = []
# `````````````````test```````````````````````
def CTClm_score(seq):
"""
When lm is word level, sequence need to be converted to word...?
"""
state = decoder.lm.start(True)
score = 0
for token in seq:
state, cur_score = decoder.lm.score(state, token)
score += cur_score
score += decoder.lm.finish(state)[1]
return score

# TODO: dynmically get the pad_index
def trim_padded_sequence(sequence: torch.Tensor, pad_index: int = 0) -> torch.Tensor:
"""
Removes trailing pad_index elements from a padded 1D sequence tensor.
Args:
sequence (torch.Tensor): 1D tensor containing the sequence with padding.
pad_index (int): The padding index used in the sequence.
Returns:
torch.Tensor: The trimmed sequence without trailing padding.
"""
# Find indices where the sequence is not the pad_index.
non_pad_indices = (sequence != pad_index).nonzero(as_tuple=True)[0]

if non_pad_indices.numel() == 0:
# If the entire sequence is padding, return an empty tensor.
return sequence.new_empty(0)

# The last non-padding index; add 1 for slicing (because slicing is exclusive).
last_index = non_pad_indices[-1].item() + 1

return sequence[:last_index]
#ctc_scores = [[l2.score for l2 in l1] for l1 in decoder_results]
# ctc_scores = torch.tensor(ctc_scores)
# TODO: add ctc_loss for the case CTCdecoder use sum
# ctc_loss = torch.nn.CTCLoss(model.blank_idx, "none")
viterbi_scores = []
lm_scores = []

import pdb
pdb.set_trace()
#pdb.set_trace()
'''Parrallezing viterbi across batch'''
from concurrent.futures import ProcessPoolExecutor
import multiprocessing
multiprocessing.set_start_method('spawn', force=True)
from functools import partial
target_list = list(targets.raw_tensor.cpu())
log_prob_list = list(label_log_prob.cpu())
spatial_dim_list = list(enc_spatial_dim_torch.cpu())
viterbi_batch_partial = partial(viterbi_batch, blank_idx=model.blank_idx)
cpu_cores = multiprocessing.cpu_count()
print(f"using {min(cpu_cores,32)} workers")
with ProcessPoolExecutor(max_workers=min(cpu_cores,32)) as executor:
alignments, viterbi_scores = zip(*list(executor.map(viterbi_batch_partial, zip(target_list, log_prob_list, spatial_dim_list))))

lm_scores = []
for i in range(label_log_prob.shape[0]):
seq = trim_padded_sequence(targets.raw_tensor[i,:])
target_words = " ".join([decoder.tokens_dict.get_entry(idx) for idx in seq])
target_words = target_words.replace("@@ ","")
#target_words = target_words.replace(" <s>", "")
log_prob = label_log_prob[i]
viterbi_score = ctc_viterbi_one_seq(log_prob, seq, int(enc_spatial_dim_torch[i].item()), # int(enc_spatial_dim_torch.max())
blank_idx=model.blank_idx)[1]
viterbi_scores.append(viterbi_score)
word_seq = [decoder.word_dict.get_index(word) for word in target_words.split()]
lm_score = CTClm_score(word_seq)
lm_score = CTClm_score(word_seq, decoder)
lm_scores.append(lm_score)
scores.append(viterbi_score + hyp_params["lm_weight"]*lm_score)
pdb.set_trace()
beam_dim = Dim(1, name="beam_dim")
scores = Tensor("scores", dims=[batch_dim, beam_dim], dtype="float32",
scores.append(viterbi_scores[i] + hyp_params["lm_weight"]*lm_score)
score_dim = Dim(1, name="score_dim")
scores = Tensor("scores", dims=[batch_dim, score_dim], dtype="float32",
raw_tensor=torch.tensor(scores).reshape([label_log_prob.shape[0], 1]))

torch.cuda.empty_cache()
#````````````````````````````````````````````

return scores
return scores, score_dim


# RecogDef API
Expand Down
Loading

0 comments on commit c5fad0e

Please sign in to comment.