Skip to content

Commit

Permalink
Improve return interface of word-based sampling.
Browse files: browse the repository at this point in the history
  • Loading branch information
gabegrand committed Aug 12, 2024
1 parent 86de0c1 commit 85476dd
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions hfppl/chunks.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -123,7 +123,7 @@ async def sample_word_2(
break

# Sample punctuation, if desired
punctuation = ""
mid_punctuation, end_punctuation = "", ""

mask = set()
if allow_mid_punctuation:
Expand All @@ -132,7 +132,10 @@ async def sample_word_2(
mask = mask | context.lm.masks.END_PUNCTUATION

if mask and await self.sample(context.mask_dist(mask)):
punctuation_token = await self.sample(context.next_token())
punctuation = context.lm.vocab[punctuation_token.token_id]
token = await self.sample(context.next_token())
if token.token_id in context.lm.masks.MID_PUNCTUATION:
mid_punctuation = context.lm.vocab[token.token_id]
if token.token_id in context.lm.masks.END_PUNCTUATION:
end_punctuation = context.lm.vocab[token.token_id]

return word, punctuation
return word, mid_punctuation, end_punctuation

0 comments on commit 85476dd

Please sign in to comment.