From 2501e2b12d4f86971c03536501ac068b2775a00b Mon Sep 17 00:00:00 2001 From: Luca Wehrstedt Date: Wed, 5 Aug 2020 15:01:13 -0700 Subject: [PATCH] [RPC tests] Run DdpUnderDistAutogradTest and DdpComparisonTest with fork too (#42528) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42528 It seems it was an oversight that they weren't run. This allows to simplify our auto-generation logic as now all test suites are run in both modes. ghstack-source-id: 109229969 Test Plan: CI Reviewed By: pritamdamania87 Differential Revision: D22922151 fbshipit-source-id: 0766a6970c927efb04eee4894b73d4bcaf60b97f --- .../_internal/distributed/rpc_utils.py | 43 ++++++++----------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/torch/testing/_internal/distributed/rpc_utils.py b/torch/testing/_internal/distributed/rpc_utils.py index a3d3f0f1239f1..a039947342a6d 100644 --- a/torch/testing/_internal/distributed/rpc_utils.py +++ b/torch/testing/_internal/distributed/rpc_utils.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 import unittest from enum import Flag, auto -from typing import Dict, List, NamedTuple, Type +from typing import Dict, List, Type from torch.testing._internal.common_distributed import MultiProcessTestCase from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_TSAN @@ -76,25 +76,20 @@ class MultiProcess(Flag): } -class Test(NamedTuple): - test_class: Type[RpcAgentTestFixture] - mp_type: MultiProcess - - # This list contains test suites that are agent-agnostic and that only verify # compliance with the generic RPC interface specification. These tests should # *not* make use of implementation details of a specific agent (options, # attributes, ...). These test suites will be instantiated multiple times, once # for each agent (except the faulty agent, which is special). GENERIC_TESTS = [ - Test(RpcTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(DistAutogradTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(DistOptimizerTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(JitRpcTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(JitDistAutogradTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(RemoteModuleTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(DdpUnderDistAutogradTest, MultiProcess.SPAWN), - Test(DdpComparisonTest, MultiProcess.SPAWN), + RpcTest, + DistAutogradTest, + DistOptimizerTest, + JitRpcTest, + JitDistAutogradTest, + RemoteModuleTest, + DdpUnderDistAutogradTest, + DdpComparisonTest, ] @@ -102,7 +97,7 @@ class Test(NamedTuple): # These suites should be standalone, and separate from the ones in the generic # list (not subclasses of those!). PROCESS_GROUP_TESTS = [ - Test(ProcessGroupAgentRpcTest, MultiProcess.FORK | MultiProcess.SPAWN) + ProcessGroupAgentRpcTest ] @@ -110,7 +105,7 @@ class Test(NamedTuple): # These suites should be standalone, and separate from the ones in the generic # list (not subclasses of those!). TENSORPIPE_TESTS = [ - Test(TensorPipeAgentRpcTest, MultiProcess.FORK | MultiProcess.SPAWN) + TensorPipeAgentRpcTest ] @@ -120,16 +115,16 @@ class Test(NamedTuple): # suites in this list, which were designed to test such behaviors, and not the # ones in the generic list. FAULTY_AGENT_TESTS = [ - Test(FaultyAgentRpcTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(FaultyAgentDistAutogradTest, MultiProcess.FORK | MultiProcess.SPAWN), - Test(JitFaultyAgentRpcTest, MultiProcess.FORK | MultiProcess.SPAWN), + FaultyAgentRpcTest, + FaultyAgentDistAutogradTest, + JitFaultyAgentRpcTest, ] def generate_tests( prefix: str, mixin: Type[RpcAgentTestFixture], - tests: List[Test], + tests: List[Type[RpcAgentTestFixture]], mp_type_filter: MultiProcess, module_name: str, ) -> Dict[str, Type[RpcAgentTestFixture]]: @@ -149,12 +144,12 @@ def generate_tests( is necessary for pickling to work on them. """ ret: Dict[str, Type[RpcAgentTestFixture]] = {} - for test in tests: + for test_class in tests: for mp_type in MultiProcess: - if mp_type & mp_type_filter & test.mp_type: + if mp_type & mp_type_filter: mp_helper, suffix = MP_HELPERS_AND_SUFFIXES[mp_type] - name = f"{prefix}{test.test_class.__name__}{suffix}" - class_ = type(name, (test.test_class, mixin, mp_helper), dict()) + name = f"{prefix}{test_class.__name__}{suffix}" + class_ = type(name, (test_class, mixin, mp_helper), dict()) class_.__module__ = module_name ret[name] = class_ return ret