diff --git a/lmcsc/common.py b/lmcsc/common.py index 32337c4..93debd3 100644 --- a/lmcsc/common.py +++ b/lmcsc/common.py @@ -5,6 +5,7 @@ OOV_CHAR = "□" MIN = -1e32 +HALF_MIN = -1e4 MAX = 1e32 EPS = 1e-7 diff --git a/lmcsc/generation.py b/lmcsc/generation.py index c677687..88ba22e 100644 --- a/lmcsc/generation.py +++ b/lmcsc/generation.py @@ -17,7 +17,7 @@ GenerateBeamDecoderOnlyOutput, ) -from lmcsc.common import MIN +from lmcsc.common import HALF_MIN, MIN from lmcsc.obversation_generator import BaseObversationGenerator def token_transformation_to_probs(self, observed_sequence: str) -> Tuple[List[int], List[float], dict]: @@ -289,7 +289,10 @@ def distortion_guided_beam_search( # template for the distortion model template_weight = self.probs_template * self.distortion_model_smoothing - template_weight[self.token_length > 1] = MIN + if template_weight.dtype == torch.float16: + template_weight[self.token_length > 1] = HALF_MIN + else: + template_weight[self.token_length > 1] = MIN # clear the cache self.cache = {} @@ -479,7 +482,10 @@ def distortion_guided_beam_search( else: this_text = "" if this_text in predicted_sequences_set: - next_token_scores[batch_i][candidate_i] = MIN + if next_token_scores.dtype == torch.float16: + next_token_scores[batch_i][candidate_i] = HALF_MIN + else: + next_token_scores[batch_i][candidate_i] = MIN else: predicted_sequences_set.add(this_text) if len(predicted_sequences_set) > (max(2, 1 + n_eos_tokens) * num_beams):