Commit

clean up

xrsrke committed Feb 15, 2024
1 parent 98046f8 commit d96c7fa
Showing 2 changed files with 13 additions and 172 deletions.
2 changes: 2 additions & 0 deletions src/nanotron/utils.py
@@ -160,4 +160,6 @@ def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int:
sock.bind(("localhost", port))
return port
except OSError:
# NOTE: we raise the same message that PyTorch distributed raises,
# so that rerun_if_address_is_in_use() can catch it!
raise Exception("Address already in use")
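
For context, rerun_if_address_is_in_use() is a retry decorator defined in tests/helpers/utils.py; below is a minimal sketch of how such a decorator could key off this exception message and retry a test. The decorator name comes from the comment above, but the signature, retry policy, and matching logic here are assumptions, not the helper's actual implementation.

# Illustrative sketch only: retry a test when it fails with the
# "Address already in use" message raised above. The real
# rerun_if_address_is_in_use() may differ in signature and policy.
import functools
import time


def rerun_if_address_is_in_use(max_try: int = 5, delay: float = 0.5):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(1, max_try + 1):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    # Retry only on the message that find_free_port() and
                    # torch.distributed raise when a port is taken.
                    if "Address already in use" in str(e) and attempt < max_try:
                        time.sleep(delay)
                        continue
                    raise

        return wrapper

    return decorator
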
183 changes: 11 additions & 172 deletions tests/helpers/utils.py
@@ -1,13 +1,12 @@
import contextlib
import os
import random
import re
import time
from inspect import signature
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch.cuda
from nanotron.parallel import ParallelContext
import torch.multiprocessing as mp
from nanotron.utils import find_free_port
from packaging import version


@@ -60,69 +59,6 @@ def mock_os_environ(remove_keys: List[str] = None, update_key_values: Dict[str,
env.update(reverse_change)


class init_process_and_run_func:
    """Initialize distributed process groups and run function."""

    def __init__(self, func, args, kwargs, tp: int, dp: int, pp: int):
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.tp = tp
        self.dp = dp
        self.pp = pp
        self.__name__ = self.__class__.__name__
        self.__qualname__ = self.__class__.__qualname__

    def __call__(self):
        with mock_os_environ(update_key_values={"WORLD_SIZE": f"{self.tp * self.dp * self.pp}"}):
            # NOTE: seed with the current time so that each unit test doesn't generate the same port
            random.seed(time.time())
            parallel_context = ParallelContext(
                data_parallel_size=self.dp, pipeline_parallel_size=self.pp, tensor_parallel_size=self.tp
            )

            assert "parallel_context" not in self.kwargs
            self.kwargs["parallel_context"] = parallel_context

            self.func(*self.args, **self.kwargs)


# def init_distributed(tp: int, dp: int, pp: int):
# def _init_distributed(func):
# """Wrapper to help initialize distributed nanotron.

# :param func: parallel function that runs on all the process, it requires one of its keyword argument to be "parallel_context"
# """
# nb_gpus = tp * dp * pp
# run_id = uuid.uuid4()

# config = torch.distributed.launcher.LaunchConfig(
# min_nodes=1,
# max_nodes=1,
# nproc_per_node=nb_gpus,
# rdzv_backend="c10d",
# rdzv_configs={"timeout": 60},
# # Setting port to `0` allows `torch` to randomly pick a port: https://pytorch.org/docs/stable/elastic/run.html#stacked-single-node-multi-worker
# # Works only for single node workload.
# rdzv_endpoint="localhost:0",
# run_id=str(run_id),
# max_restarts=0,
# # TODO @thomasw21: Tune as we increase the number of tests
# monitor_interval=1,
# tee=torch.distributed.elastic.multiprocessing.Std(3),
# )

# def wrapper(*args, **kwargs):
# return elastic_launch(
# config=config,
# entrypoint=init_process_and_run_func(func, tp=tp, dp=dp, pp=pp, args=args, kwargs=kwargs),
# )()

# return wrapper

# return _init_distributed


def is_dict_equal(first: Dict, second: Dict, sub_paths: Optional[List[str]] = None) -> Tuple[bool, Optional[str]]:
    """Return whether the two dictionaries match, plus an explanatory message when they don't."""
    if sub_paths is None:
@@ -282,9 +218,6 @@ def _run_until_success(*args, **kwargs):
while max_try is None or try_count < max_try:
try:
try_count += 1
# if try_count == max_try:
# raise ValueError("Maximum number of attempts is reached, no more retrying...")

ret = func(*args, **kwargs)
return ret
except exception_type as e:
@@ -310,118 +243,24 @@ def _run_until_success(*args, **kwargs):
return _wrapper


# class init_process_and_run_func_for_spawn:
# """Initialize distributed process groups and run function."""

# def __init__(self, func, args, kwargs, tp: int, dp: int, pp: int):
# self.func = func
# self.args = args
# self.kwargs = kwargs
# self.tp = tp
# self.dp = dp
# self.pp = pp
# self.__name__ = self.__class__.__name__
# self.__qualname__ = self.__class__.__qualname__

# def __call__(self):
# from nanotron.utils import find_free_port
# port = find_free_port()
# with mock_os_environ(update_key_values={
# "WORLD_SIZE": f"{self.tp * self.dp * self.pp}",
# "MASTER_ADDR": "localhost",
# "MASTER_PORT": str(port)
# }):
# # NOTE: we use a different random seed, so that each unit tests don't generate the same port
# # random.seed(time.time())
# parallel_context = ParallelContext(
# data_parallel_size=self.dp, pipeline_parallel_size=self.pp, tensor_parallel_size=self.tp
# )

# assert "parallel_context" not in self.kwargs
# self.kwargs["parallel_context"] = parallel_context

# self.func(*self.args, **self.kwargs)

# class ProcessSpawner:
# def __init__(self, func, tp, pp, dp, **kwargs):
# self.func = func
# self.tp = tp
# self.pp = pp
# self.dp = dp
# self.kwargs = kwargs
# self.world_size = tp * pp * dp
# self.port = find_free_port()

# @staticmethod
# def setup_dist_env(rank, world_size, port):
# os.environ["WORLD_SIZE"] = str(world_size)
# os.environ["RANK"] = str(rank)
# os.environ["LOCAL_RANK"] = str(rank)
# os.environ["MASTER_ADDR"] = "localhost"
# os.environ["MASTER_PORT"] = str(port)

# def func_wrapper(self, rank):
# # Setup distributed environment for this process
# ProcessSpawner.setup_dist_env(rank, self.world_size, self.port)
# # Call the actual function with adjusted parameters
# self.func(rank=rank, tp=self.tp, pp=self.pp, dp=self.dp, port=self.port, **self.kwargs)

# def spawn(self):
# wrapped_func = partial(self.func_wrapper)
# mp.spawn(wrapped_func, nprocs=self.world_size)


# def global_wrapper(rank, func, tp, pp, dp, port, *args, **kwargs):
# setup_dist_env(rank, tp * pp * dp, port)
# func(tp=tp, pp=pp, dp=dp, *args, **kwargs)


# def global_wrapper(rank, func, tp, pp, dp, port, *args, **kwargs):
# setup_dist_env(rank, tp * pp * dp, port)
# func(tp=tp, pp=pp, dp=dp, **kwargs)


# def spawn(func: Callable, tp: int, pp: int, dp: int, **kwargs):
# from nanotron.utils import find_free_port

# world_size = tp * pp * dp
# port = find_free_port()

# mp.spawn(global_wrapper, args=(func, tp, pp, dp, port, kwargs), nprocs=world_size)


def setup_dist_env(rank, world_size, port):
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)


def global_wrapper(rank, func, tp, pp, dp, port, kwargs):
    def setup_dist_env(rank, world_size, port):
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["RANK"] = str(rank)
        # NOTE: the tests run on a single node, so the global rank
        # can double as the local rank
        os.environ["LOCAL_RANK"] = str(rank)
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(port)

    world_size = tp * pp * dp
    setup_dist_env(rank, world_size, port)
    func(tp=tp, pp=pp, dp=dp, **kwargs)


def spawn(func: Callable, tp: int, pp: int, dp: int, **kwargs):
    import torch.multiprocessing as mp
    from nanotron.utils import find_free_port

    world_size = tp * pp * dp
    port = find_free_port()

    # NOTE: mp.spawn() only forwards positional args, so kwargs is passed as a plain
    # dict inside args and unpacked by global_wrapper
    args = (func, tp, pp, dp, port, kwargs)
    mp.spawn(global_wrapper, args=args, nprocs=world_size)
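
A minimal usage sketch for the spawn() helper above; the test function, its extra keyword argument, and the gloo backend are illustrative assumptions and not code from this commit. It only demonstrates the calling convention: spawn() forwards tp/pp/dp plus any kwargs to the test body after global_wrapper has exported the MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE variables.

# Hypothetical usage of spawn(); assumes it lives in the same module as global_wrapper.
def _test_world_size(tp: int, pp: int, dp: int, dummy_arg: int):
    import torch.distributed as dist

    # Env vars were already set by global_wrapper for this rank.
    dist.init_process_group(backend="gloo")
    assert dist.get_world_size() == tp * pp * dp
    assert dummy_arg == 42
    dist.destroy_process_group()


def test_world_size():
    spawn(_test_world_size, tp=1, pp=1, dp=2, dummy_arg=42)
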


def init_distributed(tp: int, dp: int, pp: int):
    def _init_distributed(func):
        def wrapper(**kwargs):
            import torch.multiprocessing as mp
            from nanotron.utils import find_free_port

            world_size = tp * pp * dp
            port = find_free_port()

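The remainder of init_distributed() is not shown in this view. Based on the removed init_process_and_run_func class above, which injected a parallel_context keyword argument into the wrapped function, one plausible calling pattern is sketched below; the test name, its kwargs, and the assumption that the hidden wrapper spawns tp * dp * pp processes and injects parallel_context are all hypothetical.

# Hypothetical usage of the init_distributed decorator factory; assumes the
# truncated wrapper spawns the processes and passes a ParallelContext as the
# "parallel_context" kwarg, mirroring the removed init_process_and_run_func.
def _test_something(parallel_context, some_flag: bool):
    assert parallel_context is not None
    assert some_flag


def test_something():
    init_distributed(tp=2, dp=1, pp=1)(_test_something)(some_flag=True)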
