remove empty context manager in tp tests
xrsrke committed Feb 10, 2024
1 parent 8a98cfc commit 29672db
Showing 1 changed file with 39 additions and 48 deletions.
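Editor's note: the commit drops the `does_not_raise` pattern, in which `contextlib.nullcontext` is aliased so a test can be parametrized with either a no-op context or a `pytest.raises(...)` context. Since every configuration exercised here was expected not to raise, the context manager was always empty and the extra `expectation` parameter could be removed. A minimal, self-contained sketch of the pattern being removed (the reciprocal test below is illustrative, not from this repository):

from contextlib import nullcontext as does_not_raise

import pytest


@pytest.mark.parametrize(
    "value, expectation",
    [
        (2, does_not_raise()),                  # valid input: no exception expected
        (0, pytest.raises(ZeroDivisionError)),  # invalid input: expect a raise
    ],
)
def test_reciprocal(value, expectation):
    # `expectation` is either a no-op context or a pytest.raises context
    with expectation:
        assert 1 / value > 0

The pattern earns its keep only when at least one parametrization actually expects an exception; with none, as in the test below, dropping it simplifies both the test and the helper's signature.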
tests/test_tensor_parallel.py: 39 additions, 48 deletions
@@ -1,6 +1,4 @@
 import os
-from contextlib import nullcontext as does_not_raise
-from typing import Any
 
 import pytest
 import torch
@@ -153,16 +151,10 @@ def test_row_linear(tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode
     if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication:
         pytest.skip("ALL_REDUCE mode does not support async communication")
 
-    # NOTE: we expect all the current configurations don't raise any exceptions
-    expectation = does_not_raise()
-    init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)(
-        tp_mode=tp_mode, async_communication=async_communication, expectation=expectation
-    )
+    init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)(tp_mode=tp_mode, async_communication=async_communication)
 
 
-def _test_row_linear(
-    parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool, expectation: Any
-):
+def _test_row_linear(parallel_context: ParallelContext, tp_mode: TensorParallelLinearMode, async_communication: bool):
     if async_communication:
         os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
     out_features = 3
@@ -223,48 +215,47 @@ def _test_row_linear(

     # Test that we get the same output after forward pass
     # TODO @kunhao: We may want to have our custom error type
-    with expectation:
-        sharded_output = row_linear(random_sharded_input)
-        reference_output = reference_linear(random_input)
+    sharded_output = row_linear(random_sharded_input)
+    reference_output = reference_linear(random_input)
 
-        if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
-            sharded_reference_output = reference_output
-        elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
-            assert batch_size % parallel_context.tp_pg.size() == 0
-            sharded_batch_size = batch_size // parallel_context.tp_pg.size()
-            sharded_reference_output = reference_output[
-                dist.get_rank(parallel_context.tp_pg)
-                * sharded_batch_size : (dist.get_rank(parallel_context.tp_pg) + 1)
-                * sharded_batch_size
-            ]
-        else:
-            raise ValueError(f"Unsupported mode: {tp_mode}")
+    if tp_mode is TensorParallelLinearMode.ALL_REDUCE:
+        sharded_reference_output = reference_output
+    elif tp_mode is TensorParallelLinearMode.REDUCE_SCATTER:
+        assert batch_size % parallel_context.tp_pg.size() == 0
+        sharded_batch_size = batch_size // parallel_context.tp_pg.size()
+        sharded_reference_output = reference_output[
+            dist.get_rank(parallel_context.tp_pg)
+            * sharded_batch_size : (dist.get_rank(parallel_context.tp_pg) + 1)
+            * sharded_batch_size
+        ]
+    else:
+        raise ValueError(f"Unsupported mode: {tp_mode}")
 
-        # TODO @thomasw21: Tune tolerance
-        torch.testing.assert_close(
-            sharded_output,
-            sharded_reference_output,
-        )
+    # TODO @thomasw21: Tune tolerance
+    torch.testing.assert_close(
+        sharded_output,
+        sharded_reference_output,
+    )
 
-        # Test that we get the same gradient after backward pass
-        sharded_output.sum().backward()
-        reference_output.sum().backward()
-        torch.testing.assert_close(
-            row_linear.weight.grad,
-            reference_linear.weight.grad[
-                :,
-                dist.get_rank(parallel_context.tp_pg)
-                * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
-                * in_features_per_rank,
-            ],
-        )
-        if dist.get_rank(parallel_context.tp_pg) == 0:
-            torch.testing.assert_close(
-                row_linear.bias.grad,
-                reference_linear.bias.grad,
-            )
-        else:
-            assert row_linear.bias is None
+    # Test that we get the same gradient after backward pass
+    sharded_output.sum().backward()
+    reference_output.sum().backward()
+    torch.testing.assert_close(
+        row_linear.weight.grad,
+        reference_linear.weight.grad[
+            :,
+            dist.get_rank(parallel_context.tp_pg)
+            * in_features_per_rank : (dist.get_rank(parallel_context.tp_pg) + 1)
+            * in_features_per_rank,
+        ],
+    )
+    if dist.get_rank(parallel_context.tp_pg) == 0:
+        torch.testing.assert_close(
+            row_linear.bias.grad,
+            reference_linear.bias.grad,
+        )
+    else:
+        assert row_linear.bias is None
 
 
 @pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
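Editor's note: the slicing that the test asserts can be checked without any distributed setup. Below is a minimal sketch with plain integers standing in for the tensor-parallel process group (the world size, rank, and dimensions are assumed values, not taken from the repository): in REDUCE_SCATTER mode each rank keeps `batch_size // world_size` rows of the output, and in a row-parallel linear each rank's weight gradient is a column slice of the full weight gradient.

import torch

# Hypothetical tensor-parallel group: 4 ranks, this process is rank 1.
world_size, rank = 4, 1
batch_size, in_features, out_features = 8, 12, 3
in_features_per_rank = in_features // world_size

# REDUCE_SCATTER: the reference output is sharded along the batch dimension.
reference_output = torch.randn(batch_size, out_features)
assert batch_size % world_size == 0
sharded_batch_size = batch_size // world_size
sharded_reference_output = reference_output[
    rank * sharded_batch_size : (rank + 1) * sharded_batch_size
]
assert sharded_reference_output.shape == (sharded_batch_size, out_features)

# Row-parallel weight gradient: rank `rank` owns columns
# [rank * in_features_per_rank, (rank + 1) * in_features_per_rank) of the
# full (out_features, in_features) gradient, mirroring the assert in the test.
full_weight_grad = torch.randn(out_features, in_features)
local_weight_grad = full_weight_grad[
    :, rank * in_features_per_rank : (rank + 1) * in_features_per_rank
]
assert local_weight_grad.shape == (out_features, in_features_per_rank)

This also explains the bias check at the end of the test: in a row-parallel linear the bias is added once after the reduction, so only rank 0 holds it (and its gradient), while every other rank asserts `row_linear.bias is None`.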
