Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for gradient accumulation test #125

Merged
merged 3 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, model):
]
self.log_logit_mean = False
self.device = 0
self.precision = "amp_bfloat16"
self.precision = "float32"
self.wd = 0.033
self.lr = 3e-3
self.beta1 = 0.9
Expand Down Expand Up @@ -100,7 +100,7 @@ def create_train_fixtures(model = "open_lm_11m"):
args.train_num_samples = args.batch_size

# increase learning rate and remove warmup for maximize change to model weights
args.lr = 2
args.lr = 1e-3
args.warmup = 0

# create base models
Expand Down
59 changes: 44 additions & 15 deletions tests/test_grad_accum.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import torch
import torch.multiprocessing as mp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from open_lm.model import create_model

from open_lm.train import train_one_epoch
from open_lm.main import random_seed
from tests.shared import create_train_fixtures


def _grad_acc_helper(test_fsdp, accs=[2, 1], threshold=1e-7):
def _grad_acc_helper(test_fsdp, accs=[1, 2], threshold=1e-7):
if test_fsdp:
world_size = 1
mp.spawn(
Expand All @@ -36,22 +37,34 @@ def _grad_acc_helper_fsdp(rank, world_size, accs, threshold):


def _grad_acc_helper_single(test_fsdp, accs=[2, 1], threshold=1e-7):
args, model, data, optimizer, scheduler, loss = create_train_fixtures()
random_seed()
# List of tuples with (args, model, data, optimizer, scheduler, loss)
fixtures = []
for _ in accs:
random_seed()
(args, model, data, optimizer, scheduler, loss) = create_train_fixtures()

if test_fsdp:
args.fsdp = True
args.fsdp_amp = True
# HACK: Currently, AdamW optimizer leads to different results with gradient accumulation.
optimizer = torch.optim.SGD([p for p in model.parameters() if p.requires_grad], lr=args.lr)

# create models
random_seed()
model_accum_grad = copy.deepcopy(model).to(args.device)
model_no_accum_grad = copy.deepcopy(model_accum_grad).to(args.device)
if test_fsdp:
args.fsdp = True
args.fsdp_amp = True
fixtures.append((args, model, data, optimizer, scheduler, loss))

model1 = fixtures[0][1]
for fixture in fixtures[1:]:
model2 = fixture[1]
for p1, p2 in zip(model1.parameters(), model2.parameters()):
assert torch.allclose(p1, p2, atol=threshold), "Parameter mismatch at init"

# train on mock data with/without grad accumulation for one epoch
for model, accum_freq in zip([model_accum_grad, model_no_accum_grad], accs):
for fixture, accum_freq in zip(fixtures, accs):
args, model, data, optimizer, scheduler, loss = fixture
if test_fsdp:
model = FSDP(model)
args.accum_freq = accum_freq
random_seed()
train_one_epoch(
model=model,
data=data,
Expand All @@ -65,15 +78,31 @@ def _grad_acc_helper_single(test_fsdp, accs=[2, 1], threshold=1e-7):
args=args,
)

# check that models weights are similar (within some threshold)
for p1, p2 in zip(model_accum_grad.parameters(), model_no_accum_grad.parameters()):
assert torch.allclose(p1, p2, atol=threshold)
model1 = fixtures[0][1]
failed_grad = []
failed_weight = []
for fixture in fixtures[1:]:
model2 = fixture[1]
for (n1, p1), (n2, p2) in zip(model1.named_parameters(), model2.named_parameters()):
if not torch.allclose(p1.grad, p2.grad, atol=threshold):
failed_grad.append(n1)
print(f"Gradient mismatch at {n1}, {n2}")

if not torch.allclose(p1, p2, atol=threshold):
failed_weight.append(n1)
print(f"Weight mismatch at {n1}, {n2}")
assert not failed_grad, f"Failed gradient checks at: {failed_grad}"
assert not failed_weight, f"Failed weight checks at: {failed_weight}"


def test_no_accumulation_matches():
_grad_acc_helper(test_fsdp=False, accs=[1, 1])


def test_grad_acc():
_grad_acc_helper(test_fsdp=False)
_grad_acc_helper(test_fsdp=False, accs=[1, 2])


@pytest.mark.gpu
def test_grad_acc_fsdp():
_grad_acc_helper(test_fsdp=True)
_grad_acc_helper(test_fsdp=True, accs=[1, 2])