Skip to content

Commit

Permalink
update: test_muon_rank
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr committed Dec 3, 2024
1 parent 756d7ea commit dede2ed
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions tests/test_optimizers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import numpy as np
import pytest
import torch
Expand All @@ -13,6 +15,7 @@
WSAM,
DynamicLossScaler,
Lookahead,
Muon,
PCGrad,
load_optimizer,
)
Expand Down Expand Up @@ -772,3 +775,23 @@ def test_muon_zero_power_via_newton_schulz_5():

with pytest.raises(ValueError):
_ = zero_power_via_newton_schulz_5(x[0], num_steps=6)


@pytest.mark.parametrize('rank', ['1', '0'])
def test_muon_rank(rank):
os.environ['RANK'] = rank

model = nn.Sequential(
nn.Conv1d(1, 1, 1),
nn.Conv1d(1, 1, 1),
nn.Conv1d(1, 1, 1),
)

optimizer = Muon(model.parameters())
optimizer.zero_grad()

model[0].weight.grad = torch.randn(1, 1, 1)
model[1].weight.grad = torch.randn(1, 1, 1)
model[2].weight.grad = torch.randn(1, 1, 1)

optimizer.step()

0 comments on commit dede2ed

Please sign in to comment.