This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[RUNTIME] Fix ray placement group (#655)
merrymercy authored Aug 16, 2022
1 parent 9802b3b commit e342ad2
Showing 5 changed files with 147 additions and 151 deletions.
alpa/device_mesh.py: 21 changes (8 additions & 13 deletions)
@@ -10,7 +10,7 @@
 |
 PhysicalDeviceMesh (one device mesh)
 |
-MeshHostWorker (one host in a devie mesh)
+MeshHostWorker (one host in a device mesh)
 Besides, we have two additional classes: VirtualPhysicalMesh and
 LogicalDeviceMesh. They are only used during compilation time. They are used to
@@ -957,7 +957,7 @@ def _launch_xla_servers(self):
         self.service_server = xla_client._xla.get_distributed_runtime_service(
             self.server_address, self.num_hosts)
         logger.debug(f"Success to start XLA gRPC server on port: {port}...")
-        time.sleep(0.5)
+        time.sleep(0.4)
 
         # Launch workers
         self.workers = []
@@ -969,8 +969,6 @@ def _launch_xla_servers(self):
         device_bundle_idx_list = get_bundle_idx(placement_group, self.node_ips)
 
         for i in range(self.num_hosts):
-            bundle_index = device_bundle_idx_list[i]
-
             # Set XLA environment variables
             env_vars = {
                 "ALPA_IS_WORKER":
@@ -1010,14 +1008,16 @@ def _launch_xla_servers(self):
                         ""),  # For libnccl-net.so
             })
 
+            bundle_index = device_bundle_idx_list[i]
+
             # Launch the DaemonMoveWorker
-            cls = ray.remote(num_cpus=1e-3)(DaemonMoveWorker)
+            cls = ray.remote(num_cpus=0)(DaemonMoveWorker)
             move_worker = cls.options(
                 placement_group=placement_group,
                 placement_group_bundle_index=bundle_index).remote()
 
             # Launch the MeshHostWorker
-            cls = ray.remote(num_cpus=1e-3,
+            cls = ray.remote(num_cpus=0,
                              num_gpus=self.num_devices_per_host)(MeshHostWorker)
             worker = cls.options(placement_group=placement_group,
                                  placement_group_bundle_index=bundle_index,
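
The hunk above is the heart of the fix: each per-host actor is created with an explicit bundle index, so the DaemonMoveWorker and MeshHostWorker for host i land in the bundle that owns host i's GPUs, and they no longer reserve CPU slots inside the placement group. A minimal, self-contained sketch of this Ray pattern (the actor class and resource numbers below are illustrative, not Alpa's):

    import ray
    from ray.util.placement_group import placement_group

    ray.init()

    # Reserve one bundle per host; the CPU/GPU counts here are illustrative.
    pg = placement_group([{"CPU": 1, "GPU": 2}, {"CPU": 1, "GPU": 2}],
                         strategy="SPREAD")
    ray.get(pg.ready())  # block until all bundles are actually reserved

    @ray.remote(num_cpus=0, num_gpus=2)
    class HostWorker:
        def where(self):
            return ray.get_runtime_context().node_id

    # Pin worker i to bundle i so it is co-located with that bundle's GPUs,
    # mirroring the placement_group_bundle_index usage in the diff above.
    workers = [
        HostWorker.options(placement_group=pg,
                           placement_group_bundle_index=i).remote()
        for i in range(2)
    ]
    print(ray.get([w.where.remote() for w in workers]))
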
@@ -2008,7 +2008,6 @@ class DeviceCluster:
 
     def __init__(self):
         # pylint: disable=import-outside-toplevel
-        self.placement_group = None
         ray_global_node = ray_worker._global_node
         try:
             self.head_info = ray_global_node.address_info
@@ -2032,12 +2031,9 @@ def __init__(self):
             assert number.is_integer()
             self.host_num_devices.append(int(number))
 
-    def create_placment_group(self):
-        """
-        Create a placement group for the current device cluster.
-        """
+        # Create placement group
         self.placement_group = create_placement_group(self.num_hosts,
-                                                      self.host_num_devices[0])
+                                                      self.host_num_devices)
 
     def delete_placement_group(self):
         """remove the placement group for the current device cluster."""
@@ -2132,7 +2128,6 @@ def init_global_cluster(cluster: str):
         ray.init(address="auto", ignore_reinit_error=True)
         update_jax_platform("cpu")
         global_cluster = DeviceCluster()
-        global_cluster.create_placment_group()
         global_virtual_physical_mesh = (
             global_cluster.get_virtual_physical_mesh())
 
alpa/pipeline_parallel/stage_construction.py: 4 changes (2 additions & 2 deletions)
@@ -14,7 +14,7 @@
 from ray.exceptions import RayActorError
 import tqdm
 
-from alpa.device_mesh import DeviceCluster, VirtualPhysicalMesh
+from alpa.device_mesh import get_global_cluster, VirtualPhysicalMesh
 from alpa.global_env import global_config
 from alpa.pipeline_parallel.computation import (
     JaxPipelineComputation, merge_marked_jaxprs_with_named_call)
@@ -460,7 +460,7 @@ def get_compute_cost(
         num_hosts, num_devices = submesh
         tic = time()
         if global_config.profile_with_whole_ray_cluster:
-            whole_cluster_virtual_mesh = DeviceCluster(
+            whole_cluster_virtual_mesh = get_global_cluster(
             ).get_virtual_physical_mesh()
             sliced_virtual_meshes = (
                 whole_cluster_virtual_mesh.slice_profiling_submeshes(
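
With the cluster and its placement group now created once in init_global_cluster, the stage-construction profiler reuses that object through get_global_cluster() instead of building a second DeviceCluster that knows nothing about the placement group. Roughly (a sketch, assuming the global cluster has already been initialized, typically via init_global_cluster("ray")):

    # Before: a fresh DeviceCluster with no attached placement group
    # whole_cluster_virtual_mesh = DeviceCluster().get_virtual_physical_mesh()

    # After: reuse the module-level cluster created by init_global_cluster()
    from alpa.device_mesh import get_global_cluster
    whole_cluster_virtual_mesh = get_global_cluster().get_virtual_physical_mesh()
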
alpa/pipeline_parallel/stage_profiling.py: 33 changes (19 additions & 14 deletions)
@@ -37,7 +37,8 @@
                        run_backend_compilation,
                        hlo_sharding_to_sharding_spec)
 from alpa.util import (clone_jaxpr, get_shard_shape, jaxpr_to_hlo_module,
-                       OrderedSet, retrieve_placement_group)
+                       OrderedSet, retrieve_placement_group,
+                       get_num_available_gpus)
 
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -225,7 +226,7 @@ class CompileWorkerPool(BaseWorkerPoolWrapper):
 
     def __init__(self, num_cpus, debug_mode=False):
         super().__init__()
-        worker_cls = ray.remote(num_cpus=0)(CompileWorker)
+        worker_cls = ray.remote(num_cpus=1)(CompileWorker)
         self.actors = [worker_cls.remote() for _ in range(num_cpus)]
         self.pool = ActorPool(self.actors)
         self.local_worker = CompileWorker() if debug_mode else None
@@ -344,11 +345,9 @@ def restart(self, forced):
 class ProfileWorkerPool(BaseWorkerPoolWrapper):
     """A pool of ProfileWorker for distributed profiling."""
 
-    def __init__(self, virtual_meshes):
+    def __init__(self, virtual_meshes, placement_group):
         super().__init__()
-        worker_cls = ray.remote(num_cpus=1e-3)(ProfileWorker)
-        # retrieve the placement group
-        placement_group = retrieve_placement_group()
+        worker_cls = ray.remote(ProfileWorker)
         self.actors = [
             worker_cls.options(placement_group=placement_group).remote(mesh)
             for mesh in virtual_meshes
@@ -413,19 +412,23 @@ class HloCostModelProfileWorkerPool(BaseWorkerPoolWrapper):
     cost model to estimate the cost.
     """
 
-    def __init__(self, num_cpus, num_gpus, prof_result, mesh_num_devices,
+    def __init__(self, num_cpus, placement_group, prof_result, mesh_num_devices,
                  num_micro_batches):
         super().__init__()
+        num_gpus = get_num_available_gpus(placement_group)
         gpu_per_cpu = 1
         while gpu_per_cpu * num_cpus > num_gpus:
            gpu_per_cpu /= 2
         env_vars = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"}
-        worker_cls = ray.remote(num_cpus=1,
+        worker_cls = ray.remote(num_cpus=0,
                                 num_gpus=gpu_per_cpu)(HloCostModelProfileWorker)
         self.actors = [
-            worker_cls.options(runtime_env={
-                "env_vars": env_vars
-            }).remote(prof_result, mesh_num_devices, num_micro_batches)
+            worker_cls.options(
+                runtime_env={
+                    "env_vars": env_vars
+                },
+                placement_group=placement_group,
+            ).remote(prof_result, mesh_num_devices, num_micro_batches)
             for _ in range(num_cpus)
         ]
         self.pool = ActorPool(self.actors)
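
Instead of reading the whole cluster's GPU count from ray.available_resources(), the pool now asks the placement group how many GPUs it actually reserved, then hands each cost-model worker a fractional GPU share. A hedged sketch of both pieces (get_num_available_gpus is an alpa.util helper not shown in this diff; the version below is hypothetical):

    def get_num_available_gpus_sketch(pg):
        """Hypothetical: count the GPUs reserved across the group's bundles."""
        return int(sum(bundle.get("GPU", 0) for bundle in pg.bundle_specs))

    def gpu_share_per_worker(num_workers, num_gpus):
        """Same halving rule as in the diff above: shrink each worker's GPU
        share until num_workers of them fit into num_gpus GPUs
        (assumes num_gpus >= 1)."""
        share = 1.0
        while share * num_workers > num_gpus:
            share /= 2
        return share

    # Example: 8 CPU workers sharing 4 reserved GPUs -> 0.5 GPU each.
    assert gpu_share_per_worker(8, 4) == 0.5
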
@@ -473,20 +476,22 @@ def profile_all(stages, compiled_outputs: Sequence[CompileOutput], meshes,
     # pylint: disable=unused-argument
     compute_cost, max_n_succ_stages, is_profiled = mesh_cached_result
 
+    placement_group = retrieve_placement_group()
+
     if auto_stage_option.use_hlo_cost_model:
         num_cpus = int(
             min(max(ray.available_resources()["CPU"] // 2, 1), len(stages)))
-        num_gpus = int(ray.available_resources()["GPU"])
         mesh_num_devices = meshes[0].num_devices
         prof_database = ProfilingResultDatabase()
         prof_database.load(auto_stage_option.profiling_database_filename)
         prof_result = prof_database.query("default", meshes[0].shape)
-        profile_workers = HloCostModelProfileWorkerPool(num_cpus, num_gpus,
+        profile_workers = HloCostModelProfileWorkerPool(num_cpus,
+                                                        placement_group,
                                                         prof_result,
                                                         mesh_num_devices,
                                                         num_micro_batches)
     else:
-        profile_workers = ProfileWorkerPool(meshes)
+        profile_workers = ProfileWorkerPool(meshes, placement_group)
 
     succ_compile_ct = 0
     for stage_id, (compiled_output,
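
profile_all now looks the placement group up once and threads it through to whichever worker pool it builds, rather than having ProfileWorkerPool call retrieve_placement_group internally. retrieve_placement_group is another alpa.util helper that this diff only imports; a plausible, purely illustrative shape for it:

    from ray.util import get_current_placement_group

    def retrieve_placement_group_sketch():
        """Hypothetical: prefer the placement group of the current Ray task or
        actor, else fall back to the one owned by Alpa's global DeviceCluster."""
        pg = get_current_placement_group()
        if pg is not None:
            return pg
        from alpa.device_mesh import get_global_cluster
        return get_global_cluster().placement_group
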
