Skip to content

Commit

Permalink
skip ALL_REDUCE with async comm
Browse files Browse the repository at this point in the history
  • Loading branch information
NouamaneTazi committed Jan 31, 2024
1 parent abe42c6 commit 0a754a1
Showing 1 changed file with 4 additions and 17 deletions.
21 changes: 4 additions & 17 deletions tests/test_tensor_parallel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from contextlib import nullcontext as does_not_raise
from typing import Any

import pytest
Expand Down Expand Up @@ -147,25 +146,13 @@ def _test_column_linear(


@pytest.mark.parametrize("tp,dp,pp", [pytest.param(i, 1, 1) for i in range(1, min(4, available_gpus()) + 1)])
@pytest.mark.parametrize(
"tp_mode,async_communication,expectation",
[
pytest.param(TensorParallelLinearMode.ALL_REDUCE, False, does_not_raise()),
pytest.param(TensorParallelLinearMode.REDUCE_SCATTER, False, does_not_raise()),
pytest.param(TensorParallelLinearMode.REDUCE_SCATTER, True, does_not_raise()),
pytest.param(
TensorParallelLinearMode.ALL_REDUCE,
True,
pytest.raises(
ValueError,
match=r"Cf this: https://github.com/huggingface/nanotron/blob/bf82cded9eef1ba77864b48e65bffefad4076339/src/nanotron/core/parallel/tensor_parallel/nn.py#L132",
),
),
],
)
@pytest.mark.parametrize("tp_mode", list(TensorParallelLinearMode))
@pytest.mark.parametrize("async_communication", [False, True])
def test_row_linear(
tp: int, dp: int, pp: int, tp_mode: TensorParallelLinearMode, async_communication: bool, expectation: Any
):
if tp_mode is TensorParallelLinearMode.ALL_REDUCE and async_communication:
pytest.skip("ALL_REDUCE mode does not support async communication")
init_distributed(tp=tp, dp=dp, pp=pp)(_test_row_linear)(
tp_mode=tp_mode, async_communication=async_communication, expectation=expectation
)
Expand Down

0 comments on commit 0a754a1

Please sign in to comment.