Skip to content

Commit

Permalink
Fix BUG: fix RuntimeError: value cannot be converted to type at: :Hal…
Browse files Browse the repository at this point in the history
…f without overflow happened on some devices
  • Loading branch information
Jacob-Zhou committed Nov 15, 2024
1 parent f671a18 commit 7bb106f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
1 change: 1 addition & 0 deletions lmcsc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
OOV_CHAR = "□"

MIN = -1e32
HALF_MIN = -1e4
MAX = 1e32
EPS = 1e-7

Expand Down
12 changes: 9 additions & 3 deletions lmcsc/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
GenerateBeamDecoderOnlyOutput,
)

from lmcsc.common import MIN
from lmcsc.common import HALF_MIN, MIN
from lmcsc.obversation_generator import BaseObversationGenerator

def token_transformation_to_probs(self, observed_sequence: str) -> Tuple[List[int], List[float], dict]:
Expand Down Expand Up @@ -289,7 +289,10 @@ def distortion_guided_beam_search(

# template for the distortion model
template_weight = self.probs_template * self.distortion_model_smoothing
template_weight[self.token_length > 1] = MIN
if template_weight.dtype == torch.float16:
template_weight[self.token_length > 1] = HALF_MIN
else:
template_weight[self.token_length > 1] = MIN

# clear the cache
self.cache = {}
Expand Down Expand Up @@ -479,7 +482,10 @@ def distortion_guided_beam_search(
else:
this_text = ""
if this_text in predicted_sequences_set:
next_token_scores[batch_i][candidate_i] = MIN
if next_token_scores.dtype == torch.float16:
next_token_scores[batch_i][candidate_i] = HALF_MIN
else:
next_token_scores[batch_i][candidate_i] = MIN
else:
predicted_sequences_set.add(this_text)
if len(predicted_sequences_set) > (max(2, 1 + n_eos_tokens) * num_beams):
Expand Down

0 comments on commit 7bb106f

Please sign in to comment.