diff --git a/README.md b/README.md index fc3c237..5ae8a99 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,7 @@ We provide a simple Python API for the corrector: ```python from lmcsc import LMCorrector +import torch corrector = LMCorrector( model="Qwen/Qwen2.5-0.5B", diff --git a/lmcsc/generation.py b/lmcsc/generation.py index 2248c73..e13a15f 100644 --- a/lmcsc/generation.py +++ b/lmcsc/generation.py @@ -102,7 +102,7 @@ def get_distortion_probs( return batch_indices, beam_indices, token_indices, distortion_probs, original_token_lengths, force_eos @torch.jit.script -def distortion_probs_to_cuda( +def distortion_probs_to_cuda_jit( template_tensor: torch.Tensor, force_eos: torch.Tensor, batch_size: int, @@ -112,7 +112,8 @@ def distortion_probs_to_cuda( _batch_indices: List[int], _beam_indices: List[int], _token_indices: List[int], - _distortion_probs: List[float]) -> torch.Tensor: + _distortion_probs: torch.Tensor, + ) -> torch.Tensor: """ Transfers distortion probabilities to a CUDA tensor. @@ -139,9 +140,7 @@ def distortion_probs_to_cuda( distortion_probs = template_tensor.masked_fill(force_eos[:, None], MIN).view(batch_size, num_beams, vocab_size) # Update distortion probabilities with the provided values - distortion_probs[_batch_indices, _beam_indices, _token_indices] = torch.tensor( - _distortion_probs, device=template_tensor.device - ) + distortion_probs[_batch_indices, _beam_indices, _token_indices] = _distortion_probs return distortion_probs.view(batch_beam_size, vocab_size) @@ -390,7 +389,7 @@ def distortion_guided_beam_search( # get the observed sequences and calculate the distortion probs force_eos = torch.tensor(force_eos, device=input_ids.device, dtype=torch.bool) - distortion_probs = distortion_probs_to_cuda( + distortion_probs = distortion_probs_to_cuda_jit( template_weight, force_eos, batch_size, @@ -400,7 +399,7 @@ def distortion_guided_beam_search( _batch_indices, _beam_indices, _token_indices, - _distortion_probs + torch.tensor(_distortion_probs, device=template_weight.device, dtype=template_weight.dtype) ) # calculate the length reward diff --git a/lmcsc/model.py b/lmcsc/model.py index 386455b..7ac9e28 100644 --- a/lmcsc/model.py +++ b/lmcsc/model.py @@ -99,7 +99,7 @@ def decorate_model_instance(self): self.set_convert_ids_to_tokens() self.tokenizer.padding_side = "left" - self.model.probs_template = torch.ones((self.model.vocab_size,)).to( + self.model.probs_template = torch.ones((self.model.vocab_size,), dtype=self.model.dtype).to( self.model.device )