Skip to content

Commit

Permalink
Merge pull request #20 from probcomp/gg/renormalize
Browse files Browse the repository at this point in the history
Renormalize LMNextToken.sample() probs to fix floating point errors
  • Loading branch information
alex-lew authored Jan 4, 2025
2 parents f172d8b + f40ee1b commit a191fca
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions hfppl/distributions/lmcontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ async def log_prob(self, x):

async def sample(self):
probs = np.exp(self.ctx.next_token_logprobs)
probs /= np.sum(probs) # Renormalize to fix floating point errors
token_id = np.random.choice(len(probs), p=(probs))
self.ctx.tokens.append(token_id)
logprob = self.ctx.next_token_logprobs[token_id]
Expand Down

0 comments on commit a191fca

Please sign in to comment.