Skip to content

Commit

Permalink
Add ThunderCompiler'ed optimizer tests (#1208)
Browse files Browse the repository at this point in the history
  • Loading branch information
crcrpar authored Oct 3, 2024
1 parent 67d8df6 commit 15753c9
Showing 1 changed file with 74 additions and 4 deletions.
78 changes: 74 additions & 4 deletions thunder/tests/test_dynamo.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
import pytest
import torch
import torch.fx
from thunder.tests.framework import instantiate, NOTHING, DynamoThunderExecutor, IS_WINDOWS

from thunder import dtypes
from thunder.dynamo import ThunderCompiler
from thunder import last_traces

import torch
import pytest
from thunder.tests.bf16 import device_supports_bf16
from thunder.tests.framework import (
instantiate,
NOTHING,
DynamoThunderExecutor,
IS_WINDOWS,
requiresCUDA,
)
from thunder.tests.make_tensor import make_tensor


# This will be applied to all tests in this file.
Expand Down Expand Up @@ -365,3 +373,65 @@ def func(x):
actual_grad = torch.autograd.grad(actual, x, g)
expected_grad = torch.autograd.grad(expected, x, g)
torch.testing.assert_close(actual_grad, expected_grad)


@instantiate(
dtypes=(dtypes.float32,),
executors=(DynamoThunderExecutor,),
decorators=(
pytest.mark.parametrize(
"optim",
(
torch.optim.SGD,
torch.optim.Adam,
torch.optim.AdamW,
),
ids=(
"sgd",
"adam",
"adamw",
),
),
pytest.mark.skipif(
IS_WINDOWS,
reason="torch.compile Windows support is still WIP - https://github.com/pytorch/pytorch/issues/122094",
),
),
)
@requiresCUDA
def test_thundercompiler_optim_step(executor, device, dtype, optim):
from thunder.tests.distributed.helper import ToyModel

if not device_supports_bf16(device):
pytest.skip(f"{device} does not support bf16")

tdtype = dtypes.to_torch_dtype(dtype)
model = ToyModel().to(device=device, dtype=tdtype)
optimizer = optim(model.parameters())
jitted_step = executor.make_callable(optimizer.step)

ref_model = ToyModel().to(device=device, dtype=tdtype)
ref_model.load_state_dict(model.state_dict())
ref_optimizer = optim(ref_model.parameters())
ref_optimizer.load_state_dict(optimizer.state_dict())

for i in range(2):
x = make_tensor((1, ToyModel.N_IN), dtype=tdtype, device=device)
x_ref = x.clone().detach()

y = model(x)
y.mean().backward()
jitted_step()
optimizer.zero_grad()

y_ref = ref_model(x_ref)
y_ref.mean().backward()
ref_optimizer.step()
ref_optimizer.zero_grad()

# There could be numerical error, see https://github.com/NVIDIA/Fuser/issues/2664
torch.testing.assert_close(
tuple(model.parameters()),
tuple(ref_model.parameters()),
msg=lambda s: f"{i+1}-iter {s}",
)

0 comments on commit 15753c9

Please sign in to comment.