catch overlapping port from find_free_port
xrsrke committed Feb 15, 2024
1 parent 558b341 commit 98046f8
Showing 4 changed files with 16 additions and 9 deletions.
1 change: 1 addition & 0 deletions .github/workflows/3d_parallelism_unit_tests.yaml
@@ -58,5 +58,6 @@ jobs:
   --color=yes \
   --durations=0 \
   --ignore tests/kernels \
+  --ignore tests/fp8 \
   --verbose \
   tests/
11 changes: 9 additions & 2 deletions src/nanotron/distributed.py
@@ -240,7 +240,7 @@ def get_rank(group: Optional[ProcessGroup] = None) -> int: # pylint: disable=fu
     return result


-def initialize_torch_distributed(port: Optional[int] = None):
+def initialize_torch_distributed():
     """Initializes torch distributed with the environment variables"""
     rank = int(os.getenv("RANK", "0"))
     world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -259,7 +259,14 @@ def initialize_torch_distributed(port: Optional[int] = None):
         backend = "gloo"

     # Call the init process.
-    port = find_free_port() if port is None else port
+    # port = find_free_port() if port is None else port
+
+    port = os.getenv("MASTER_PORT")
+    if port is None:
+        port = find_free_port()
+    else:
+        port = int(port)
+
     init_method = f"env://localhost:{port}"
     dist.init_process_group(
         init_method=init_method, backend=backend, world_size=world_size, rank=rank, timeout=dist.default_pg_timeout
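
In isolation, the new branch in initialize_torch_distributed simply prefers an externally provided MASTER_PORT over a locally probed one. A minimal standalone sketch of that resolution order; find_free_port below is a simplified stand-in for the helper in src/nanotron/utils.py, and resolve_master_port is an illustrative name, not part of the commit:

# Standalone sketch (not nanotron code): MASTER_PORT wins if set, otherwise a
# random free port is probed locally.
import os
import random
import socket


def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int:
    port = random.randint(min_port, max_port)
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(("localhost", port))  # raises OSError if the port is already taken
        return port


def resolve_master_port() -> int:
    port = os.getenv("MASTER_PORT")
    return find_free_port() if port is None else int(port)


if __name__ == "__main__":
    os.environ["MASTER_PORT"] = "29500"   # e.g. exported by the test launcher
    print(resolve_master_port())          # 29500 on every process
    del os.environ["MASTER_PORT"]
    print(resolve_master_port())          # an independently probed port per process
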
5 changes: 2 additions & 3 deletions src/nanotron/parallel/context.py
@@ -1,5 +1,5 @@
 import os
-from typing import Literal, Optional, Tuple
+from typing import Literal, Tuple

 import numpy as np
 import torch
@@ -15,7 +15,6 @@ def __init__(
         tensor_parallel_size: int,
         pipeline_parallel_size: int,
         data_parallel_size: int,
-        port: Optional[int] = None,
         backend: DistributedBackend = "nccl",
     ):
         """Initialize parallel context."""
@@ -49,7 +48,7 @@ def __init__(
         assert backend == "nccl", "Only nccl backend is supported for now."

         if not dist.is_initialized():
-            dist.initialize_torch_distributed(port)
+            dist.initialize_torch_distributed()

         world_size = int(os.getenv("WORLD_SIZE", "1"))
         ranks = list(range(world_size))
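
For callers, the visible change is that ParallelContext no longer accepts a port argument; the rendezvous port is now resolved inside initialize_torch_distributed from MASTER_PORT (or probed). A hypothetical caller sketch, assuming the usual nanotron.parallel import path and a torchrun launch whose world size matches the product of the three parallel sizes:

# Hypothetical caller sketch; run under torchrun with WORLD_SIZE=8 (2 * 1 * 4).
import os

from nanotron.parallel import ParallelContext

os.environ.setdefault("MASTER_PORT", "29500")  # optional: pin the same port on every rank

parallel_context = ParallelContext(
    tensor_parallel_size=2,
    pipeline_parallel_size=1,
    data_parallel_size=4,
    # port=1234,  # removed by this commit; no longer accepted
)
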
8 changes: 4 additions & 4 deletions src/nanotron/utils.py
@@ -2,10 +2,10 @@
 import inspect
 import math
 import os
-from contextlib import ExitStack, contextmanager
-from typing import Callable, ContextManager, List, Optional
 import random
 import socket
+from contextlib import ExitStack, contextmanager
+from typing import Callable, ContextManager, List, Optional

 import torch
 from packaging import version
@@ -159,5 +159,5 @@ def find_free_port(min_port: int = 2000, max_port: int = 65000) -> int:
             sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
             sock.bind(("localhost", port))
             return port
-    except OSError as e:
-        raise e
+    except OSError:
+        raise Exception("Address already in use")
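
Behaviourally, find_free_port now surfaces a port collision as a plain Exception("Address already in use") instead of re-raising the OSError. An illustrative caller-side sketch of how a test harness could retry the probe instead of failing outright; find_free_port_with_retries is not part of the commit:

# Illustrative caller-side retry wrapper; not part of the commit.
from typing import Optional

from nanotron.utils import find_free_port


def find_free_port_with_retries(attempts: int = 10) -> int:
    """Retry the random-port probe a few times before giving up."""
    last_error: Optional[Exception] = None
    for _ in range(attempts):
        try:
            return find_free_port()
        except Exception as exc:  # the helper now raises "Address already in use"
            last_error = exc
    raise RuntimeError(f"no free port found after {attempts} attempts") from last_error
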
