Skip to content

Commit 151bae2

Browse files
PaulZhang12 authored and facebook-github-bot committed
Training Benchmark Fix World Size 4 (#1797)
Summary: Pull Request resolved: #1797 Training benchmark was broken with multiprocessing issues on servicelab. This diff is tested to ensure world size 4 is able to run on servicelab. World size 8 is not able to finish in the 40 minutes time. More investigation is required and there is also tons of cost in starting processes and tearing down for each sharding paradigm. However, this will allow some level of testing for future diffs to prevent training regression Differential Revision: D54880542 fbshipit-source-id: e8b001471c316c3d1436f4a42a67ab0ae2b51502
1 parent 62d0742 commit 151bae2

File tree

1 file changed

+9
-23
lines changed

1 file changed

+9
-23
lines changed

torchrec/distributed/benchmark/benchmark_utils.py

+9-23
Original file line numberDiff line numberDiff line change
@@ -614,29 +614,16 @@ def multi_process_benchmark(
614614
# pyre-ignore
615615
**kwargs,
616616
) -> BenchmarkResult:
617+
617618
def setUp() -> None:
618-
os.environ["MASTER_ADDR"] = str("localhost")
619-
os.environ["MASTER_PORT"] = str(get_free_port())
620-
os.environ["GLOO_DEVICE_TRANSPORT"] = "TCP"
621-
os.environ["NCCL_SOCKET_IFNAME"] = "lo"
622-
623-
torch.use_deterministic_algorithms(True)
624-
if torch.cuda.is_available():
625-
torch.backends.cudnn.allow_tf32 = False
626-
torch.backends.cuda.matmul.allow_tf32 = False
627-
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
628-
629-
def tearDown() -> None:
630-
torch.use_deterministic_algorithms(False)
631-
del os.environ["GLOO_DEVICE_TRANSPORT"]
632-
del os.environ["NCCL_SOCKET_IFNAME"]
633-
if torch.cuda.is_available():
634-
os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
619+
if "MASTER_ADDR" not in os.environ:
620+
os.environ["MASTER_ADDR"] = str("localhost")
621+
os.environ["MASTER_PORT"] = str(get_free_port())
635622

636-
setUp()
637623
assert "world_size" in kwargs
638624
world_size = kwargs["world_size"]
639625

626+
setUp()
640627
benchmark_res_per_rank = []
641628
ctx = mp.get_context("forkserver")
642629
qq = ctx.SimpleQueue()
@@ -659,6 +646,10 @@ def tearDown() -> None:
659646
benchmark_res_per_rank.append(res)
660647
assert len(res.max_mem_allocated) == 1
661648

649+
for p in processes:
650+
p.join()
651+
assert 0 == p.exitcode
652+
662653
total_benchmark_res = BenchmarkResult(
663654
benchmark_res_per_rank[0].short_name,
664655
benchmark_res_per_rank[0].elapsed_time,
@@ -670,11 +661,6 @@ def tearDown() -> None:
670661
# Each rank's BenchmarkResult contains 1 memory measurement
671662
total_benchmark_res.max_mem_allocated[res.rank] = res.max_mem_allocated[0]
672663

673-
for p in processes:
674-
p.join()
675-
assert 0 == p.exitcode
676-
677-
tearDown()
678664
return total_benchmark_res
679665

680666

0 commit comments

Comments (0)