Skip to content

Commit

Permalink
[RPC tests] Run DdpUnderDistAutogradTest and DdpComparisonTest with f…
Browse files Browse the repository at this point in the history
…ork too (pytorch#42528)

Summary:
Pull Request resolved: pytorch#42528

It seems it was an oversight that they weren't run. This allows to simplify our auto-generation logic as now all test suites are run in both modes.
ghstack-source-id: 109229969

Test Plan: CI

Reviewed By: pritamdamania87

Differential Revision: D22922151

fbshipit-source-id: 0766a6970c927efb04eee4894b73d4bcaf60b97f
  • Loading branch information
lw authored and facebook-github-bot committed Aug 5, 2020
1 parent 4da602b commit 2501e2b
Showing 1 changed file with 19 additions and 24 deletions.
43 changes: 19 additions & 24 deletions torch/testing/_internal/distributed/rpc_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
import unittest
from enum import Flag, auto
from typing import Dict, List, NamedTuple, Type
from typing import Dict, List, Type

from torch.testing._internal.common_distributed import MultiProcessTestCase
from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_TSAN
Expand Down Expand Up @@ -76,41 +76,36 @@ class MultiProcess(Flag):
}


class Test(NamedTuple):
test_class: Type[RpcAgentTestFixture]
mp_type: MultiProcess


# This list contains test suites that are agent-agnostic and that only verify
# compliance with the generic RPC interface specification. These tests should
# *not* make use of implementation details of a specific agent (options,
# attributes, ...). These test suites will be instantiated multiple times, once
# for each agent (except the faulty agent, which is special).
GENERIC_TESTS = [
Test(RpcTest, MultiProcess.FORK | MultiProcess.SPAWN),
Test(DistAutogradTest, MultiProcess.FORK | MultiProcess.SPAWN),
Test(DistOptimizerTest, MultiProcess.FORK | MultiProcess.SPAWN),
Test(JitRpcTest, MultiProcess.FORK | MultiProcess.SPAWN),
Test(JitDistAutogradTest, MultiProcess.FORK | MultiProcess.SPAWN),
Test(RemoteModuleTest, MultiProcess.FORK | MultiProcess.SPAWN),
Test(DdpUnderDistAutogradTest, MultiProcess.SPAWN),
Test(DdpComparisonTest, MultiProcess.SPAWN),
RpcTest,
DistAutogradTest,
DistOptimizerTest,
JitRpcTest,
JitDistAutogradTest,
RemoteModuleTest,
DdpUnderDistAutogradTest,
DdpComparisonTest,
]


# This list contains test suites that will only be run on the ProcessGroupAgent.
# These suites should be standalone, and separate from the ones in the generic
# list (not subclasses of those!).
PROCESS_GROUP_TESTS = [
Test(ProcessGroupAgentRpcTest, MultiProcess.FORK | MultiProcess.SPAWN)
ProcessGroupAgentRpcTest
]


# This list contains test suites that will only be run on the TensorPipeAgent.
# These suites should be standalone, and separate from the ones in the generic
# list (not subclasses of those!).
TENSORPIPE_TESTS = [
Test(TensorPipeAgentRpcTest, MultiProcess.FORK | MultiProcess.SPAWN)
TensorPipeAgentRpcTest
]


Expand All @@ -120,16 +115,16 @@ class Test(NamedTuple):
# suites in this list, which were designed to test such behaviors, and not the
# ones in the generic list.
FAULTY_AGENT_TESTS = [
Test(FaultyAgentRpcTest, MultiProcess.FORK | MultiProcess.SPAWN),
Test(FaultyAgentDistAutogradTest, MultiProcess.FORK | MultiProcess.SPAWN),
Test(JitFaultyAgentRpcTest, MultiProcess.FORK | MultiProcess.SPAWN),
FaultyAgentRpcTest,
FaultyAgentDistAutogradTest,
JitFaultyAgentRpcTest,
]


def generate_tests(
prefix: str,
mixin: Type[RpcAgentTestFixture],
tests: List[Test],
tests: List[Type[RpcAgentTestFixture]],
mp_type_filter: MultiProcess,
module_name: str,
) -> Dict[str, Type[RpcAgentTestFixture]]:
Expand All @@ -149,12 +144,12 @@ def generate_tests(
is necessary for pickling to work on them.
"""
ret: Dict[str, Type[RpcAgentTestFixture]] = {}
for test in tests:
for test_class in tests:
for mp_type in MultiProcess:
if mp_type & mp_type_filter & test.mp_type:
if mp_type & mp_type_filter:
mp_helper, suffix = MP_HELPERS_AND_SUFFIXES[mp_type]
name = f"{prefix}{test.test_class.__name__}{suffix}"
class_ = type(name, (test.test_class, mixin, mp_helper), dict())
name = f"{prefix}{test_class.__name__}{suffix}"
class_ = type(name, (test_class, mixin, mp_helper), dict())
class_.__module__ = module_name
ret[name] = class_
return ret

0 comments on commit 2501e2b

Please sign in to comment.