diff --git a/src/nanotron/distributed.py b/src/nanotron/distributed.py
index 238dca9b..889b2330 100644
--- a/src/nanotron/distributed.py
+++ b/src/nanotron/distributed.py
@@ -259,7 +259,21 @@ def initialize_torch_distributed():
         backend = "gloo"
 
     # Call the init process.
-    port = find_free_port()
+    pytest_worker_id = os.environ.get("PYTEST_XDIST_WORKER")
+    if pytest_worker_id is None:
+        # Not running under pytest-xdist: any free port will do.
+        port = find_free_port()
+    else:
+        def string_to_unique_number(s, min_port=2000, max_port=65000):
+            """Deterministically map string *s* to a port in [min_port, max_port)."""
+            import hashlib
+            # Hash the string and fold the digest into the port range.
+            hash_number = int(hashlib.sha256(s.encode()).hexdigest(), base=16)
+            range_size = max_port - min_port
+            return min_port + (hash_number % range_size)
+
+        # Each xdist worker gets a stable, distinct port so parallel test
+        # workers do not race each other for the same rendezvous address.
+        port = string_to_unique_number(pytest_worker_id)
+
     init_method = f"tcp://localhost:{port}"
     dist.init_process_group(init_method=init_method, backend=backend, world_size=world_size, rank=rank, timeout=dist.default_pg_timeout)
     return True