
Commit

ananyahjha93 committed May 15, 2024
1 parent ddad073 commit 4d31ac8
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions tests/grad_norm_test.py
@@ -212,26 +212,40 @@ def _naive_train_loop(
    norm_vector_a = []
    norm_vector_b = []

    from olmo.torch_util import seed_all

    for epoch in range(max_epochs):
        for idx, batch in enumerate(data_loader):
            step_count = epoch * len_dataloader + idx

            optimizer_b.zero_grad()
            optimizer_a.zero_grad()

            from olmo.torch_util import seed_all
            seed_all(step_count)

            # send exact same batch to both models
            logits_b = model_b(batch['input_ids'].to('cuda')).logits
            logits_a = model_a(batch['input_ids'].to('cuda')).logits

            loss_b = _lm_loss(logits_b, batch['input_ids'].to('cuda').clone())
            loss_a = _lm_loss(logits_a, batch['input_ids'].to('cuda').clone())

            loss_b.backward()

            # norm_vector_b.append(optimizer_b.clip_grads_and_collect_metrics(step_count)["total_grad_norm"])

            _apply_scheduler(cfg, step_count, scheduler_b, optimizer_b)
            optimizer_b.step()

            ####################################################################

            optimizer_a.zero_grad()
            seed_all(step_count)

            logits_a = model_a(batch['input_ids'].to('cuda')).logits
            loss_a = _lm_loss(logits_a, batch['input_ids'].to('cuda').clone())

            loss_a.backward()

            # norm_vector_a.append(clip_grad_norm_(model_a.parameters(), max_norm))

            _apply_scheduler(cfg, step_count, scheduler_a, optimizer_a)
            optimizer_a.step()

            ##########################################################
            # # deepcopy model_b and apply clipping at a param level
            # model_b_params = []
@@ -243,18 +257,6 @@ def _naive_train_loop(
            # total_norm = clip_grad_norm_(model_b_params, max_norm=1.)
            ##########################################################

            # norm_vector_b.append(optimizer_b.clip_grads_and_collect_metrics(step_count)["total_grad_norm"])
            # norm_vector_a.append(clip_grad_norm_(model_a.parameters(), max_norm))

            # assert (total_norm - norm_vector_b[-1]).abs() < 1e-4

            # apply olmo scheduler updates
            _apply_scheduler(cfg, step_count, scheduler_b, optimizer_b)
            _apply_scheduler(cfg, step_count, scheduler_a, optimizer_a)

            optimizer_b.step()
            optimizer_a.step()

            if step_count % 100 == 0:
                print('Step: {:4d}, Loss_b: {:.4f}, Loss_a: {:.4f}'.format(step_count, loss_b, loss_a))
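
The loop shown in this diff runs each model in its own pass but calls seed_all(step_count) immediately before each forward, so model_a and model_b see identical RNG state (e.g. dropout) for the same batch. A minimal sketch of that pattern, assuming placeholder model, optimizer, and loss names; only seed_all and the HF-style .logits access are taken from the diff itself:

from olmo.torch_util import seed_all

def run_one_model(model, optimizer, batch, step_count):
    # Re-seed so every model sees the same RNG state for this step,
    # then run an independent forward/backward/step for this model only.
    optimizer.zero_grad()
    seed_all(step_count)
    logits = model(batch['input_ids'].to('cuda')).logits
    loss = logits.float().mean()  # placeholder loss, not the test's _lm_loss
    loss.backward()
    optimizer.step()
    return loss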

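The commented-out lines in the diff compare two ways of obtaining the total gradient norm: OLMo's optimizer.clip_grads_and_collect_metrics (for model_b) and torch.nn.utils.clip_grad_norm_ (for model_a). A hedged sketch of that check, reusing only names that appear in the diff; max_norm and the 1e-4 tolerance come from the commented code and are not verified here:

from torch.nn.utils import clip_grad_norm_

# After loss.backward() on both models and before optimizer.step():
norm_b = optimizer_b.clip_grads_and_collect_metrics(step_count)["total_grad_norm"]
norm_a = clip_grad_norm_(model_a.parameters(), max_norm)
assert (norm_b - norm_a).abs() < 1e-4, "grad norms diverged between the two clipping paths"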
