diff --git a/torch/testing/_internal/distributed/rpc_utils.py b/torch/testing/_internal/distributed/rpc_utils.py index a3d3f0f1239f19..a039947342a6db 100644 --- a/torch/testing/_internal/distributed/rpc_utils.py +++ b/torch/testing/_internal/distributed/rpc_utils.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import unittest from enum import Flag, auto -from typing import Dict, List, NamedTuple, Type +from typing import Dict, List, Type from torch.testing._internal.common_distributed import MultiProcessTestCase from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_TSAN @@ -76,25 +76,20 @@ class MultiProcess(Flag): } -class Test(NamedTuple): - test_class: Type[RpcAgentTestFixture] - mp_type: MultiProcess - - # This list contains test suites that are agent-agnostic and that only verify # compliance with the generic RPC interface specification. These tests should # *not* make use of implementation details of a specific agent (options, # attributes, ...). These test suites will be instantiated multiple times, once # for each agent (except the faulty agent, which is special). GENERIC_TESTS = [ - Test(RpcTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(DistAutogradTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(DistOptimizerTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(JitRpcTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(JitDistAutogradTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(RemoteModuleTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(DdpUnderDistAutogradTest, MultiProcess.SPAWN), - Test(DdpComparisonTest, MultiProcess.SPAWN), + RpcTest, + DistAutogradTest, + DistOptimizerTest, + JitRpcTest, + JitDistAutogradTest, + RemoteModuleTest, + DdpUnderDistAutogradTest, + DdpComparisonTest, ] @@ -102,7 +97,7 @@ class Test(NamedTuple): # These suites should be standalone, and separate from the ones in the generic # list (not subclasses of those!). PROCESS_GROUP_TESTS = [ - Test(ProcessGroupAgentRpcTest, MultiProcess.FORK | MultiProcess.SPAWN) + ProcessGroupAgentRpcTest ] @@ -110,7 +105,7 @@ class Test(NamedTuple): # These suites should be standalone, and separate from the ones in the generic # list (not subclasses of those!). TENSORPIPE_TESTS = [ - Test(TensorPipeAgentRpcTest, MultiProcess.FORK | MultiProcess.SPAWN) + TensorPipeAgentRpcTest ] @@ -120,16 +115,16 @@ class Test(NamedTuple): # suites in this list, which were designed to test such behaviors, and not the # ones in the generic list. FAULTY_AGENT_TESTS = [ - Test(FaultyAgentRpcTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(FaultyAgentDistAutogradTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(JitFaultyAgentRpcTest, MultiProcess.FORK | MultiProcess.SPAWN), + FaultyAgentRpcTest, + FaultyAgentDistAutogradTest, + JitFaultyAgentRpcTest, ] def generate_tests( prefix: str, mixin: Type[RpcAgentTestFixture], - tests: List[Test], + tests: List[Type[RpcAgentTestFixture]], mp_type_filter: MultiProcess, module_name: str, ) -> Dict[str, Type[RpcAgentTestFixture]]: @@ -149,12 +144,12 @@ def generate_tests( is necessary for pickling to work on them. """ ret: Dict[str, Type[RpcAgentTestFixture]] = {} - for test in tests: + for test_class in tests: for mp_type in MultiProcess: - if mp_type & mp_type_filter & test.mp_type: + if mp_type & mp_type_filter: mp_helper, suffix = MP_HELPERS_AND_SUFFIXES[mp_type] - name = f"{prefix}{test.test_class.__name__}{suffix}" - class_ = type(name, (test.test_class, mixin, mp_helper), dict()) + name = f"{prefix}{test_class.__name__}{suffix}" + class_ = type(name, (test_class, mixin, mp_helper), dict()) class_.__module__ = module_name ret[name] = class_ return ret