Skip to content

Commit

Permalink
Merge pull request #546 from kmyusk/master
Browse files Browse the repository at this point in the history
Fix periodic inference during training
Loading branch information
karpathy authored Jun 8, 2024
2 parents d396cd1 + 1ee0b43 commit 0774519
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions train_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,21 +772,18 @@ def get_lr(it):
if (args.sample_every > 0 \
and (step % args.sample_every == 0 or last_step)) \
and master_process:
# TODO I'm not sure why this sampling code (which worked fine)
# doesn't work anymore when placed here debug later
if False:
model.eval()
# before we end, let's also do one round of inference
# we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence
start_ids = [enc.eot_token]
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
max_new_tokens = 32
temperature = 1.0
top_k = 40
y = raw_model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
print0('---------------')
print0(enc.decode(y[0].tolist()))
print0('---------------')
model.eval()
# before we end, let's also do one round of inference
# we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence
start_ids = [enc.eot_token]
xg = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
max_new_tokens = 32
temperature = 1.0
top_k = 40
yg = raw_model.generate(xg, max_new_tokens, temperature=temperature, top_k=top_k)
print0('---------------')
print0(enc.decode(yg[0].tolist()))
print0('---------------')

# bit confusing: we want to make sure to eval and sample on 0th iteration
# but also after the very last iteration. so we loop for step <= num_iterations
Expand Down

0 comments on commit 0774519

Please sign in to comment.