[FEATURE] Serialize Parallel Plan (#587)
merrymercy authored Jul 4, 2022
1 parent 57437d4 commit 4a73f3a
Showing 15 changed files with 357 additions and 55 deletions.
1 change: 1 addition & 0 deletions alpa/__init__.py
@@ -14,6 +14,7 @@
from alpa.parallel_method import (ShardParallel, PipeshardParallel,
DataParallel, Zero2Parallel, Zero3Parallel,
CreateStateParallel)
from alpa.parallel_plan import plan_to_method
from alpa.pipeline_parallel.primitive_def import mark_pipeline_boundary
from alpa.pipeline_parallel.layer_construction import (manual_remat,
automatic_remat,
1 change: 1 addition & 0 deletions alpa/create_state_parallel.py
@@ -38,6 +38,7 @@ def __init__(self,
super().__init__(mesh_group=mesh_group,
pipeshard_config=pipeshard_config,
num_batch=1,
layer_option=None,
in_tree=in_tree,
out_tree=out_tree,
static_argnums=static_argnums)
29 changes: 26 additions & 3 deletions alpa/mesh_executable.py
@@ -19,15 +19,16 @@
from jax._src.lib import xla_bridge as xb, xla_client as xc, xla_extension as xe
from jax.core import ShapedArray
from jax.interpreters import pxla
from jax.tree_util import tree_flatten, tree_unflatten, PyTreeDef
from jax.tree_util import tree_flatten, tree_unflatten, tree_leaves, PyTreeDef
import numpy as np
import ray

from alpa.device_mesh import (LocalPhysicalDeviceMesh,
DistributedPhysicalDeviceMesh, RemoteArrayRef,
next_array_uuids)
from alpa.global_env import global_config
from alpa.parallel_plan import PlacementSpec, StagePlan
from alpa.parallel_plan import (PlacementSpec, StagePlan, ClusterInfo,
ParallelPlan)
from alpa.shard_parallel.auto_sharding import (AutoShardingOption,
get_input_output_sharding_specs,
make_replicated_spec, HloStatus,
@@ -76,6 +77,10 @@ def get_output_placement_specs(self):
"""
raise NotImplementedError()

def get_parallel_plan(self):
"""Get the overall parallel plan."""
raise NotImplementedError()

def preshard_dynamic_args(self, *args):
"""Pre-shard the input arguments."""
raise NotImplementedError()
@@ -205,6 +210,7 @@ def __init__(self,
self.out_tree = out_tree
self.flop_count = flop_count
self.stage_plan = stage_plan
self.auto_sharding_option = stage_plan.auto_sharding_option
self.auto_sharding_objective = stage_plan.auto_sharding_objective

# Read sharding specs
@@ -324,6 +330,13 @@ def get_output_placement_specs(self):
self.output_sharding_specs,
self.out_tree)

def get_parallel_plan(self):
"""Get the overall parallel plan."""
cluster_info = ClusterInfo(self.physical_mesh.num_hosts,
self.physical_mesh.num_devices_per_host)
return ParallelPlan(cluster_info, None, self.auto_sharding_option, None,
tree_leaves(self.get_input_placement_specs()))

def preshard_dynamic_args(self, *args):
"""Pre-shard the input arguments."""
input_bufs = self.physical_mesh.shard_args_to_bufs(
@@ -517,6 +530,7 @@ def __init__(self,
self.out_tree = out_tree
self.flop_count = flop_count
self.stage_plan = stage_plan
self.auto_sharding_option = stage_plan.auto_sharding_option
self.auto_sharding_objective = stage_plan.auto_sharding_objective

# Read sharding specs
@@ -753,6 +767,14 @@ def get_output_placement_specs(self):
self.output_sharding_specs,
self.out_tree)

def get_parallel_plan(self):
"""Get the overall parallel plan."""
cluster_info = ClusterInfo(self.physical_mesh.num_hosts,
self.physical_mesh.num_devices_per_host)
return ParallelPlan(cluster_info, self.num_micro_batches,
self.auto_sharding_option, None,
tree_leaves(self.get_input_placement_specs()))

def get_total_allocation_size(self):
"""Get the total allocated memory size of this executable."""
if isinstance(self.physical_mesh, DistributedPhysicalDeviceMesh):
@@ -1185,7 +1207,8 @@ def get_index_select_mesh_executable(avals, sharding_specs, index, dim,
as_option = AutoShardingOption()
strategy_config = StagePlan(global_config.compile_random_seed,
device_mesh.shape, 1 << 60,
as_option.all_reduce_threshold, None, -1)
as_option.all_reduce_threshold,
AutoShardingOption(), None, -1)
out_tree = tree_flatten(avals)[1]
executable = NormalMeshDriverExecutable(device_mesh,
hlo_module,
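The new get_parallel_plan methods above return a ParallelPlan describing the cluster, the auto-sharding option, and the input placement specs. A minimal usage sketch (not part of this commit), assuming the alpa.parallelize/get_executable API of this era; add_one, x, and plan.pkl are illustrative names:

import pickle

import jax.numpy as jnp

import alpa

# alpa.init(cluster="ray")  # may be needed depending on the Alpa version and cluster setup

@alpa.parallelize(method=alpa.ShardParallel())
def add_one(x):
    return x + 1

x = jnp.ones((64, 64))
executable = add_one.get_executable(x)   # compile without executing
plan = executable.get_parallel_plan()    # new API added in this commit
print(plan.cluster_info, plan.num_micro_batches)

with open("plan.pkl", "wb") as f:        # hypothetical output path
    pickle.dump(plan, f)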
63 changes: 61 additions & 2 deletions alpa/monkey_patch.py
@@ -10,7 +10,7 @@
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.xla_bridge import get_backend as default_get_backend
from jax.core import Primitive
from jax.interpreters import partial_eval as pe
from jax.interpreters import partial_eval as pe, pxla
from jax.interpreters import xla, mlir
from jax.interpreters.xla import (xops, jaxpr_subcomp, extend_name_stack,
register_translation, wrap_name,
@@ -109,7 +109,6 @@ def _rng_normal_lowering(ctx, mu, sigma, *, shape):
mlir.register_lowering(rng_normal_p, _rng_normal_lowering)


# Monkey patch random generator to use the stateful random generator.
def fast_normal(key, shape=(), dtype=dtypes.float_, mu=0.0, sigma=1.0):
shape = core.as_named_shape(shape)
mu = jnp.asarray(mu, dtype)
@@ -126,6 +125,7 @@ def remove_fold_in(key, data):
return key


# Monkey patch random generator to use the stateful random generator.
jax._src.random.uniform = fast_uniform
jax.random.uniform = fast_uniform
jax._src.random.normal = fast_normal
@@ -136,6 +136,7 @@ def remove_fold_in(key, data):
jax.random.fold_in = remove_fold_in


# Monkey patch remat to use identity instead of while loop
def _zeros(c, xla_shape):
if xla_shape.is_array():
shape, dtype = xla_shape.dimensions(), xla_shape.numpy_dtype()
@@ -228,6 +229,64 @@ def _remat_translation_rule(ctx,
del dict_val[pe.remat_call_p]
register_translation(pe.remat_call_p, _remat_translation_rule)


# Support pickle ShardingSpec
def sharding_spec_getstate(self):
sharding = []
for x in self.sharding:
if isinstance(x, pxla.NoSharding):
sharding.append((0,))
elif isinstance(x, pxla.Chunked):
sharding.append((1, x.chunks))
elif isinstance(x, pxla.Unstacked):
sharding.append((2, x.size))
else:
raise ValueError(f"Invalid sharding: {x}")
mesh_mapping = []
for x in self.mesh_mapping:
if isinstance(x, pxla.ShardedAxis):
mesh_mapping.append((0, x.axis))
elif isinstance(x, pxla.Replicated):
mesh_mapping.append((1, x.replicas))
else:
raise ValueError(f"Invalid sharding: {x}")
return (sharding, mesh_mapping)


def sharding_spec_setstate(self, state_tuple):
sharding_encoding, mesh_mapping_encoding = state_tuple

sharding = []
for x in sharding_encoding:
if x[0] == 0:
sharding.append(pxla.NoSharding())
elif x[0] == 1:
sharding.append(pxla.Chunked(x[1]))
elif x[0] == 2:
sharding.append(pxla.Unstacked(x[1]))
else:
raise ValueError(f"Invalid sharding: {x}")

mesh_mapping = []
for x in mesh_mapping_encoding:
if x[0] == 0:
mesh_mapping.append(pxla.ShardedAxis(x[1]))
elif x[0] == 1:
mesh_mapping.append(pxla.Replicated(x[1]))
else:
raise ValueError(f"Invalid sharding: {x}")

# pylint: disable=unnecessary-dunder-call
self.__init__(
sharding=sharding,
mesh_mapping=mesh_mapping,
)


setattr(pxla.ShardingSpec, "__getstate__", sharding_spec_getstate)
setattr(pxla.ShardingSpec, "__setstate__", sharding_spec_setstate)

# Monkey patch tree map to disable some warnings
jax._src.tree_util.tree_multimap = jax._src.tree_util.tree_map
jax.tree_multimap = jax._src.tree_util.tree_map

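With __getstate__ and __setstate__ patched onto pxla.ShardingSpec as above, sharding specs (and the PlacementSpec and ParallelPlan objects that contain them) become picklable. A minimal round-trip sketch under that assumption; importing alpa is what applies the patch:

import pickle

import alpa  # noqa: F401  (importing alpa applies the monkey patches above)
from jax.interpreters import pxla

spec = pxla.ShardingSpec(
    sharding=(pxla.Chunked([2]), pxla.NoSharding()),
    mesh_mapping=(pxla.ShardedAxis(0), pxla.Replicated(2)),
)
restored = pickle.loads(pickle.dumps(spec))
assert restored == spec  # ShardingSpec compares by sharding and mesh_mapping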
5 changes: 3 additions & 2 deletions alpa/parallel_method.py
@@ -164,7 +164,7 @@ class PipeshardParallel(ParallelMethod):
Possible choices are {"manual", alpa.AutoLayerOption,
alpa.ManualLayerOption}
stage_option: Options of grouping layers into pipeline stages.
Possible choices are {"uniform", "auto", alpa.AutoStageOption,,
Possible choices are {"uniform", "auto", alpa.AutoStageOption,
alpa.ManualStageOption}
"""

@@ -178,7 +178,8 @@ def __init__(
stage_option: Optional[Union[StageOption, str]] = None):
self.devices = devices
self.num_micro_batches = num_micro_batches
self.as_option = default_auto_sharding_option or AutoShardingOption()
self.as_option = (default_auto_sharding_option or
AutoShardingOption(prefer_reduce_scatter=True))
self.pipeline_schedule = pipeline_schedule
if layer_option == "manual":
layer_option = ManualLayerOption()
39 changes: 31 additions & 8 deletions alpa/parallel_plan.py
@@ -8,7 +8,6 @@
import numpy as np
from jax.core import ShapedArray
from jax.interpreters import pxla
from jax.tree_util import PyTreeDef


@dataclass
@@ -26,23 +25,47 @@ class StagePlan:
logical_mesh_shape: Tuple[int]
all_gather_threshold: int
all_reduce_threshold: int
auto_sharding_option: "AutoShardingOption"
auto_sharding_solution_vector: np.ndarray
auto_sharding_objective: int


@dataclass
class PipelinePlan:
"""The parallel plan for a pipeline."""
forward_stage_layer_ids: Sequence[Sequence[int]]
submesh_physical_shapes: Sequence[Sequence[int]]
submesh_logical_shapes: Sequence[Sequence[int]]
submesh_autosharding_option_dicts: Sequence[dict]
pipeline_schedule: str
layer_option: "LayerOption"
manual_stage_option: "ManualStageOption"


@dataclass
class ClusterInfo:
num_hosts: int
num_devices_per_host: int


@dataclass
class ParallelPlan:
"""The global parallel plan."""
cluster_info: ClusterInfo
num_micro_batches: int
auto_sharding_option: "AutoShardingOption"
pipeline_plan: PipelinePlan
stage_plans: Sequence[StagePlan]
input_placement: PyTreeDef
version: str
input_placement_specs: Sequence[PlacementSpec]


def plan_to_method(plan: ParallelPlan) -> "ParallelMethod":
"""Convert a parallel plan to a parallel method."""
# pylint: disable=import-outside-toplevel
from alpa.parallel_method import ShardParallel, PipeshardParallel

if plan.pipeline_plan is None:
return ShardParallel(num_micro_batches=plan.num_micro_batches,
auto_sharding_option=plan.auto_sharding_option)
else:
return PipeshardParallel(
num_micro_batches=plan.num_micro_batches,
default_auto_sharding_option=plan.auto_sharding_option,
pipeline_schedule=plan.pipeline_plan.pipeline_schedule,
layer_option=plan.pipeline_plan.layer_option,
stage_option=plan.pipeline_plan.manual_stage_option)
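plan_to_method closes the loop: a plan dumped from executable.get_parallel_plan() can be loaded later and converted back into a ParallelMethod for recompilation. A hedged sketch reusing the hypothetical plan.pkl from the earlier example:

import pickle

import alpa
from alpa import plan_to_method

with open("plan.pkl", "rb") as f:   # hypothetical path from the earlier sketch
    plan = pickle.load(f)

method = plan_to_method(plan)       # ShardParallel or PipeshardParallel

@alpa.parallelize(method=method)
def add_one(x):
    return x + 1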
9 changes: 6 additions & 3 deletions alpa/pipeline_parallel/compile_executable.py
@@ -82,6 +82,7 @@ def compile_pipeshard_executable(
mesh_group=virtual_mesh.launched_physical_mesh_group,
pipeshard_config=pipeshard_config,
num_batch=num_microbatch,
layer_option=layer_option,
in_tree=in_tree,
out_tree=out_tree_thunk(),
static_argnums=static_argnums)
@@ -146,8 +147,7 @@ def compile_pipeshard_executable_internal(

# Construct pipeline stages by merging layers
(jax_pipeline_stages, stage_to_mesh, sliced_virtual_meshes,
logical_mesh_shapes,
autosharding_option_dicts) = cluster_layers_and_slice_mesh(
manual_stage_option) = cluster_layers_and_slice_mesh(
jax_pipeline_layers, virtual_mesh, donation_mapping, acc_grad_outvars,
num_microbatch, micro_batch_size, jax_apply_layers,
apply_grad_global_info, pipeline_schedule, default_as_option,
@@ -202,7 +202,8 @@ def compile_pipeshard_executable_internal(
xla_stages, total_flops = shard_each_stage(
jax_all_stages, sliced_virtual_meshes, schedule, n_stages, num_meshes,
grad_in_to_out, global_invars, acc_grad_outvars, donate_invars_dict,
num_microbatch, logical_mesh_shapes, autosharding_option_dicts,
num_microbatch, manual_stage_option.submesh_logical_shapes,
manual_stage_option.submesh_autosharding_option_dicts,
default_as_option, output_sharding_dict, name_base, gensym_func)
total_flops *= num_microbatch
debug_compilation_time("shard stages")
@@ -224,6 +225,8 @@
schedule=schedule,
is_batch=batch_invars,
num_batch=num_microbatch,
default_auto_sharding_option=default_as_option,
manual_stage_option=manual_stage_option,
flop_count=total_flops).compile()

debug_compilation_time("runtime emitter")
6 changes: 4 additions & 2 deletions alpa/pipeline_parallel/computation.py
@@ -24,7 +24,8 @@
from alpa.shard_parallel.auto_sharding import (run_auto_sharding_pass,
run_spmd_partitioner_pass,
get_input_output_sharding_specs,
hlo_sharding_to_sharding_spec)
hlo_sharding_to_sharding_spec,
AutoShardingOption)
from alpa.global_env import global_config
from alpa.util import (OrderedSet, clone_jaxpr, get_compile_options,
jaxpr_to_hlo_module, setup_computation_alias,
@@ -214,7 +215,8 @@ def dummy_computation(cls, name, logical_mesh_shape, gensym_func):
backend_name = "gpu"
backend = xb.get_backend(backend_name)
stage_plan = StagePlan(global_config.compile_random_seed,
logical_mesh_shape, 1, 1, None, 0)
logical_mesh_shape, 1, 1, AutoShardingOption(),
None, 0)
compiled = compile_dummy_zero_constant(backend,
np.prod(logical_mesh_shape))
sharding_annotated_module = compiled.hlo_modules()[0]