diff --git a/tests/grad_norm_test.py b/tests/grad_norm_test.py
index 230ce1caf..2dff5fce3 100644
--- a/tests/grad_norm_test.py
+++ b/tests/grad_norm_test.py
@@ -212,26 +212,40 @@ def _naive_train_loop(
     norm_vector_a = []
     norm_vector_b = []
 
+    from olmo.torch_util import seed_all
+
     for epoch in range(max_epochs):
         for idx, batch in enumerate(data_loader):
             step_count = epoch * len_dataloader + idx
             optimizer_b.zero_grad()
-            optimizer_a.zero_grad()
-
-            from olmo.torch_util import seed_all
             seed_all(step_count)
 
-            # send exact same batch to both models
             logits_b = model_b(batch['input_ids'].to('cuda')).logits
-            logits_a = model_a(batch['input_ids'].to('cuda')).logits
-
             loss_b = _lm_loss(logits_b, batch['input_ids'].to('cuda').clone())
-            loss_a = _lm_loss(logits_a, batch['input_ids'].to('cuda').clone())
 
             loss_b.backward()
+
+            # norm_vector_b.append(optimizer_b.clip_grads_and_collect_metrics(step_count)["total_grad_norm"])
+
+            _apply_scheduler(cfg, step_count, scheduler_b, optimizer_b)
+            optimizer_b.step()
+
+            ####################################################################
+
+            optimizer_a.zero_grad()
+            seed_all(step_count)
+
+            logits_a = model_a(batch['input_ids'].to('cuda')).logits
+            loss_a = _lm_loss(logits_a, batch['input_ids'].to('cuda').clone())
+            loss_a.backward()
+            # norm_vector_a.append(clip_grad_norm_(model_a.parameters(), max_norm))
+
+            _apply_scheduler(cfg, step_count, scheduler_a, optimizer_a)
+            optimizer_a.step()
+
             ##########################################################
             # # deepcopy model_b and apply cliping at a param level
             # model_b_params = []
@@ -243,18 +257,6 @@ def _naive_train_loop(
             # total_norm = clip_grad_norm_(model_b_params, max_norm=1.)
             ##########################################################
 
-            # norm_vector_b.append(optimizer_b.clip_grads_and_collect_metrics(step_count)["total_grad_norm"])
-            # norm_vector_a.append(clip_grad_norm_(model_a.parameters(), max_norm))
-
-            # assert (total_norm - norm_vector_b[-1]).abs() < 1e-4
-
-            # apply olmo scheduler updates
-            _apply_scheduler(cfg, step_count, scheduler_b, optimizer_b)
-            _apply_scheduler(cfg, step_count, scheduler_a, optimizer_a)
-
-            optimizer_b.step()
-            optimizer_a.step()
-
             if step_count % 100 == 0:
                 print('Step: {:4d}, Loss_b: {:.4f}, Loss_a: {:.4f}'.format(step_count, loss_b, loss_a))
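
Note on the pattern in this diff: instead of interleaving the two models' forward/backward passes and stepping both optimizers at the end, each model now runs its full step back to back on the same batch, with the RNG re-seeded before each forward pass so both models draw identical randomness (e.g. dropout masks), and each model's scheduler update and optimizer step follow its own backward pass. Below is a minimal sketch of that pattern, using toy models and torch.manual_seed as a stand-in for olmo.torch_util.seed_all; the models, optimizers, and loss are illustrative, not the test's actual setup, and the scheduler update is omitted for brevity.

import torch
import torch.nn as nn

# Two identical toy models (stand-ins for model_a / model_b in the test).
model_b = nn.Sequential(nn.Linear(16, 16), nn.Dropout(0.1), nn.Linear(16, 4))
model_a = nn.Sequential(nn.Linear(16, 16), nn.Dropout(0.1), nn.Linear(16, 4))
model_a.load_state_dict(model_b.state_dict())  # start from the same weights

opt_b = torch.optim.AdamW(model_b.parameters(), lr=1e-3)
opt_a = torch.optim.AdamW(model_a.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(8, 16)
y = torch.randint(0, 4, (8,))

for step in range(10):
    # model_b's full step on the batch
    opt_b.zero_grad()
    torch.manual_seed(step)   # fix randomness for this step
    loss_b = loss_fn(model_b(x), y)
    loss_b.backward()
    opt_b.step()

    # model_a's full step, re-seeded so it draws the same dropout masks
    opt_a.zero_grad()
    torch.manual_seed(step)   # same seed => same randomness as model_b's pass
    loss_a = loss_fn(model_a(x), y)
    loss_a.backward()
    opt_a.step()

    # with identical weights, data, and randomness, the two losses should match
    assert torch.allclose(loss_b, loss_a)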