Skip to content

Commit

Permalink
FIX BUG: move list to tensor outside the torch.jit.script.
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Zhou committed Jan 6, 2025
1 parent 1e269be commit c1c61f5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ We provide a simple Python API for the corrector:

```python
from lmcsc import LMCorrector
import torch

corrector = LMCorrector(
model="Qwen/Qwen2.5-0.5B",
Expand Down
13 changes: 6 additions & 7 deletions lmcsc/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def get_distortion_probs(
return batch_indices, beam_indices, token_indices, distortion_probs, original_token_lengths, force_eos

@torch.jit.script
def distortion_probs_to_cuda(
def distortion_probs_to_cuda_jit(
template_tensor: torch.Tensor,
force_eos: torch.Tensor,
batch_size: int,
Expand All @@ -112,7 +112,8 @@ def distortion_probs_to_cuda(
_batch_indices: List[int],
_beam_indices: List[int],
_token_indices: List[int],
_distortion_probs: List[float]) -> torch.Tensor:
_distortion_probs: torch.Tensor,
) -> torch.Tensor:
"""
Transfers distortion probabilities to a CUDA tensor.
Expand All @@ -139,9 +140,7 @@ def distortion_probs_to_cuda(
distortion_probs = template_tensor.masked_fill(force_eos[:, None], MIN).view(batch_size, num_beams, vocab_size)

# Update distortion probabilities with the provided values
distortion_probs[_batch_indices, _beam_indices, _token_indices] = torch.tensor(
_distortion_probs, device=template_tensor.device
)
distortion_probs[_batch_indices, _beam_indices, _token_indices] = _distortion_probs

return distortion_probs.view(batch_beam_size, vocab_size)

Expand Down Expand Up @@ -390,7 +389,7 @@ def distortion_guided_beam_search(
# get the observed sequences and calculate the distortion probs
force_eos = torch.tensor(force_eos, device=input_ids.device, dtype=torch.bool)

distortion_probs = distortion_probs_to_cuda(
distortion_probs = distortion_probs_to_cuda_jit(
template_weight,
force_eos,
batch_size,
Expand All @@ -400,7 +399,7 @@ def distortion_guided_beam_search(
_batch_indices,
_beam_indices,
_token_indices,
_distortion_probs
torch.tensor(_distortion_probs, device=template_weight.device, dtype=template_weight.dtype)
)

# calculate the length reward
Expand Down
2 changes: 1 addition & 1 deletion lmcsc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def decorate_model_instance(self):
self.set_convert_ids_to_tokens()

self.tokenizer.padding_side = "left"
self.model.probs_template = torch.ones((self.model.vocab_size,)).to(
self.model.probs_template = torch.ones((self.model.vocab_size,), dtype=self.model.dtype).to(
self.model.device
)

Expand Down

0 comments on commit c1c61f5

Please sign in to comment.