From f449537648829c02c003aa5239aced774759bfb4 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Fri, 19 May 2023 01:54:12 +0000 Subject: [PATCH 01/36] Add solver --- setup.py | 1 + slapo/sharding/solver.py | 527 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 528 insertions(+) create mode 100644 slapo/sharding/solver.py diff --git a/setup.py b/setup.py index e5a401ea..768c5bb9 100644 --- a/setup.py +++ b/setup.py @@ -128,6 +128,7 @@ def setup(): long_description_content_type="text/markdown", setup_requires=[], install_requires=[ + "z3-solver", "packaging", "psutil", ], diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py new file mode 100644 index 00000000..8ce85d04 --- /dev/null +++ b/slapo/sharding/solver.py @@ -0,0 +1,527 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import operator +import torch +from torch import nn +from torch import fx +import torch.nn.functional as F +from torch.fx.passes.shape_prop import ShapeProp +import z3 + + +class ShardSpec: + def __init__(self, spec): + self.map = {"RR": 0, "RS": 1, "SR": 2} + if isinstance(spec, str): + self.spec = spec + else: + self.spec = list(self.map.keys())[list(self.map.values()).index(spec)] + + @property + def id(self): + return self.map[self.spec] + + def __str__(self): + return self.spec + + +class FxOp: + def __init__(self, node): + self.node = node + self.name = node.name + self.args = [] + self.users = [] + self.out_shape = node.meta["tensor_meta"].shape + self.out_size = self.out_shape[-2] * self.out_shape[-1] + + def add_arg(self, arg): + self.args.append(arg) + + def add_user(self, user): + self.users.append(user) + + def set_concrete_values(self, inp): + self.input_v = inp + + def generate_input_z3(self): + raise NotImplementedError + + def generate_output(self): + raise NotImplementedError + + def generate_output_z3(self): + raise NotImplementedError + + def calculate_comm_cost(self): + raise NotImplementedError + + def calculate_comm_cost_z3(self): + raise NotImplementedError + + +class PlaceholderOp(FxOp): + def generate_input_z3(self): + # input should not be sharded + return [], [] + + def generate_output(self): + return ShardSpec("RR").id + + def generate_output_z3(self): + return ShardSpec("RR").id + + def calculate_comm_cost(self): + return 0 + + def calculate_comm_cost_z3(self): + return 0 + + +class ElementwiseOp(FxOp): + def generate_input_z3(self): + return [], [] + + def generate_output(self): + return self.args[0].generate_output() + + def generate_output_z3(self): + return self.args[0].generate_output_z3() + + def calculate_comm_cost_z3(self): + return 0 + + +class BinaryOp(FxOp): + def generate_input_z3(self): + self.lhs = z3.BitVec(f"{self.name}_0", 2) + self.rhs = z3.BitVec(f"{self.name}_1", 2) + compute_constraints = [self.lhs == self.rhs] + format_constraints = [z3.ULE(self.lhs, 3), z3.ULE(self.rhs, 3)] + constraints = compute_constraints + format_constraints + return [self.lhs, self.rhs], constraints + + def set_concrete_values(self, lhs, rhs): + self.lhs_v = lhs + self.rhs_v = rhs + + def generate_output(self): + return self.lhs_v + + def generate_output_z3(self): + return self.lhs + + def calculate_comm_cost_z3(self): + # output remains the same spec as the inputs + return 0 + + +class LayerNormOp(FxOp): + pass + + +class SoftmaxOp(FxOp): + pass + + +class ViewOp(FxOp): + def generate_input_z3(self): + self.input = z3.BitVec(f"{self.name}_0", 2) + format_constraints = [z3.ULE(self.input, 3)] + return [self.input], 
format_constraints + + def generate_output(self): + return self.input_v + + def generate_output_z3(self): + return self.input + + def calculate_comm_cost_z3(self): + # output remains the same spec as the inputs + return 0 + + +class PermuteOp(FxOp): + def __init__(self, node, z3_graph): + # FIXME: Suppose permute is always (0, 2, 1, 3) + super().__init__(node) + self.z3_graph = z3_graph + self.output_map = {"RR": "RR", "RS": "RS", "SR": "RR"} + self.prev_op = self.z3_graph[self.node.args[0].name] + + def generate_input_z3(self): + return [], [] + # self.input = z3.BitVec(f"{self.name}_0", 2) + # compute_constraints = [self.input == self.prev_op.generate_output_z3()] + # format_constraints = [z3.ULE(self.input, 3)] + # constraints = compute_constraints + format_constraints + # return [self.input], constraints + + def generate_output(self): + return ShardSpec( + self.output_map[ShardSpec(self.prev_op.generate_output()).spec] + ).id + + def generate_output_z3(self): + result = 3 # invalid + for inp, out in self.output_map.items(): + result = z3.If( + self.prev_op.generate_output_z3() == ShardSpec(inp).id, + ShardSpec(out).id, + result, + ) + return result + + def calculate_comm_cost_z3(self): + # output remains the same spec as the inputs + return 0 + + +class TransposeOp(FxOp): + def __init__(self, node, z3_graph): + # FIXME: Suppose always transpose the last two dims + super().__init__(node) + self.z3_graph = z3_graph + self.output_map = {"RR": "RR", "RS": "SR", "SR": "RS"} + self.prev_op = self.z3_graph[self.node.args[0].name] + + def generate_input_z3(self): + return [], [] + # self.input = z3.BitVec(f"{self.name}_0", 2) + # compute_constraints = [] + # format_constraints = [z3.ULE(self.input, 3)] + # constraints = compute_constraints + format_constraints + # return [self.input], constraints + + def generate_output(self): + return ShardSpec( + self.output_map[ShardSpec(self.prev_op.generate_output()).spec] + ).id + + def generate_output_z3(self): + result = 3 # invalid + for inp, out in self.output_map.items(): + result = z3.If( + self.prev_op.generate_output_z3() == ShardSpec(inp).id, + ShardSpec(out).id, + result, + ) + return result + + def calculate_comm_cost_z3(self): + # output remains the same spec as the inputs + return 0 + + +class DropoutOp(FxOp): + pass + + +class MatmulOp(FxOp): + def __init__(self, node, mod=None, is_linear=False): + super().__init__(node) + self.lhs_shape = node.args[0].meta["tensor_meta"].shape + self.rhs_shape = ( + node.args[1].meta["tensor_meta"].shape + if not is_linear + else mod.weight.shape + ) + self.out_shape = ( + node.meta["tensor_meta"].shape + if not isinstance(node.meta["tensor_meta"], list) + else node.meta["tensor_meta"][0].shape + ) + self.lhs_size = self.lhs_shape[-2] * self.lhs_shape[-1] + print(self.name, self.lhs_shape, self.rhs_shape, self.out_shape) + if is_linear: + # weight is transposed + assert self.lhs_shape[-1] == self.rhs_shape[-1] + self.rhs_size = self.rhs_shape[-1] * self.rhs_shape[-2] + self.out_size = self.lhs_shape[-2] * self.rhs_shape[-2] + else: + assert self.lhs_shape[-1] == self.rhs_shape[-2] + self.rhs_size = self.rhs_shape[-2] * self.rhs_shape[-1] + self.out_size = self.lhs_shape[-2] * self.rhs_shape[-1] + self.output_map = {"RR": "RS", "RS": "RR", "SR": "SR"} + self.comm_cost_map = { # map from input spec to comm cost + "RR": 0, + "RS": self.out_size, # all_reduce + "SR": 0, + } + + def generate_input_z3(self): + self.lhs = z3.BitVec(f"{self.name}_0", 2) # input + self.rhs = z3.BitVec(f"{self.name}_1", 2) # weight 
+ + compute_constraints = [ + z3.Or( + [ + z3.And( + self.lhs == ShardSpec("RR").id, self.rhs == ShardSpec("RS").id + ), + z3.And( + self.lhs == ShardSpec("RS").id, self.rhs == ShardSpec("SR").id + ), + z3.And( + self.lhs == ShardSpec("SR").id, self.rhs == ShardSpec("RR").id + ), + ] + ) + ] + format_constraints = [z3.ULE(self.lhs, 3), z3.ULE(self.rhs, 3)] + constraints = compute_constraints + format_constraints + # force to shard + # constraints += [self.lhs != ShardSpec("RR").id, self.rhs != ShardSpec("RR").id] + return [self.lhs, self.rhs], constraints + + def set_concrete_values(self, lhs, rhs): + self.lhs_v = lhs + self.rhs_v = rhs + + def generate_output(self): + return ShardSpec(self.output_map[ShardSpec(self.lhs_v).spec]).id + + def generate_output_z3(self): + result = 3 # invalid + for inp, out in self.output_map.items(): + result = z3.If(self.lhs == ShardSpec(inp).id, ShardSpec(out).id, result) + return result + + def calculate_comm_cost(self): + return self.comm_cost_map[ShardSpec(self.lhs_v).spec] + + def calculate_comm_cost_z3(self): + result = 1e12 # invalid + for inp, cost in self.comm_cost_map.items(): + result = z3.If(self.lhs == ShardSpec(inp).id, cost, result) + return result + + +class Solver: + def __init__(self, gm, p) -> None: + self.gm = gm + self.gm.graph.eliminate_dead_code() + self.named_modules = dict(self.gm.named_modules()) + self.z3_graph = {} # {node_name: FxOp} + self.goal = [] + self.cost = None + self.num_devices = p + self.reshard_cost_map = { + "RR": {"RR": 0, "RS": 0, "SR": 0}, + "RS": {"RR": 1 / p, "RS": 0, "SR": 1 / p - 1 / (p * p)}, + "SR": {"RR": 1 / p, "RS": 1 / p - 1 / (p * p), "SR": 0}, + } + + def inference_shape(self, inputs): + sp = ShapeProp(self.gm) + sp.propagate(*inputs) + for node in self.gm.graph.nodes: + if "tensor_meta" in node.meta: + if isinstance(node.meta["tensor_meta"], list): + lst = node.meta["tensor_meta"] + else: + lst = [node.meta["tensor_meta"]] + for data in lst: + print(node.name, data) + + def calculate_reshard_cost(self, prev, curr, shape): + return int( + self.reshard_cost_map[ShardSpec(prev).spec][ShardSpec(curr).spec] * shape + ) + + def calculate_reshard_cost_z3(self, prev, curr, shape): + result = 1e12 # invalid + for in_spec, target_map in self.reshard_cost_map.items(): + tmp = 1e12 # invalid + for out_spec, val in target_map.items(): + tmp = z3.If(curr == ShardSpec(out_spec).id, int(val * shape), tmp) + result = z3.If(prev == ShardSpec(in_spec).id, tmp, result) + return result + + def construct_z3_graph(self): + print(self.gm.graph) + for node in self.gm.graph.nodes: + if "tensor_meta" not in node.meta: + continue + if node.op == "placeholder": # input + new_op = PlaceholderOp(node) + elif node.op == "call_module": + mod = self.named_modules[node.target] + if isinstance(mod, nn.Linear): + new_op = MatmulOp( + node, + mod=mod, + is_linear=True, + ) + elif isinstance(mod, (nn.LayerNorm, nn.Dropout)): + new_op = ElementwiseOp(node) + else: + raise RuntimeError(f"Unsupported module: {node.target}") + elif node.op == "call_function": + if node.target == torch.matmul: + new_op = MatmulOp(node) + elif node.target in [ + F.relu, + F.gelu, + F.softmax, + torch._C._nn.gelu, + operator.truediv, + ]: + new_op = ElementwiseOp(node) + elif node.target in [operator.add]: + new_op = BinaryOp(node) + else: + raise RuntimeError(f"Unsupported function: {node.target}") + elif node.op == "call_method": + if node.target == "view": + new_op = ViewOp(node) + elif node.target == "permute": + new_op = PermuteOp(node, self.z3_graph) + elif 
node.target == "transpose": + new_op = TransposeOp(node, self.z3_graph) + elif node.target == "contiguous": + continue + else: + raise RuntimeError(f"Unsupported method: {node.target}") + else: # output + continue + # construct edges + if not (node.op == "call_method" and node.target == "view"): + for arg in node.args: + if not isinstance(arg, fx.Node): + continue + new_op.add_arg(self.z3_graph[arg.name]) + self.z3_graph[arg.name].add_user(new_op) + self.z3_graph[node.name] = new_op + print(self.z3_graph) + + def construct_z3_problem(self): + bitvecs = {} + input_constraints = [] + comm_costs = [] + for op in self.z3_graph.values(): + # no need to include output, since output can be obtained from inputs + inputs, constraints = op.generate_input_z3() + for inp in inputs: + bitvecs[str(inp)] = inp + # input constraints + input_constraints.extend(constraints) + # communication cost + comm_costs.append(op.calculate_comm_cost_z3()) + + reshard_costs = [] + for op in self.z3_graph.values(): + for i, arg in enumerate(op.args): + name = f"{op.name}_{i}" + if name not in bitvecs: + continue + curr = bitvecs[name] + prev = arg.generate_output_z3() + reshard_costs.append( + self.calculate_reshard_cost_z3(prev, curr, arg.out_size) + ) + # final output should not be sharded + if len(op.users) == 0: + next_inp = ShardSpec("RR").id + reshard_costs.append( + self.calculate_reshard_cost_z3( + op.generate_output_z3(), next_inp, op.out_size + ) + ) + + self.cost = sum(comm_costs) + sum(reshard_costs) + self.goal += input_constraints + + def solve(self, inputs, max_iter=100): + self.inference_shape(inputs) + self.construct_z3_graph() + self.construct_z3_problem() + sol = z3.Solver() + sol.add(self.goal) + max_cost = 1e12 + for it in range(max_iter): + print(f"=================== Iter {it} ===================") + sol.push() + assert self.cost is not None + sol.add(self.cost < max_cost) + # print(sol) + sat = sol.check() + if str(sat) == "unsat": + print("Cannot find better solutions") + break + mod = sol.model() + print(mod) + results = {d.name(): mod[d] for d in mod.decls()} + max_cost = 0 + for name, op in self.z3_graph.items(): + flag = True + if isinstance(op, (MatmulOp, BinaryOp)): + lhs = results[f"{name}_0"] + rhs = results[f"{name}_1"] + op.set_concrete_values(lhs, rhs) + output = op.generate_output() + print( + f"{name}: {ShardSpec(lhs)} x {ShardSpec(rhs)} = {ShardSpec(output)}" + ) + if isinstance(op, MatmulOp): + print( + f" {name}: {op.lhs_shape} x {op.rhs_shape} = {op.out_shape}" + ) + comm_cost = op.calculate_comm_cost() + max_cost += comm_cost + print(f" Comm cost: {comm_cost}") + elif f"{name}_0" in results: + inp = results[f"{name}_0"] + op.set_concrete_values(inp) + output = op.generate_output() + print(f"{name}: {ShardSpec(inp)} -> {ShardSpec(output)}") + else: + continue # flag = False + if flag: + # resharding cost + for i, arg in enumerate(op.args): + curr = results[f"{name}_{i}"] + prev = arg.generate_output() + reshard_cost = self.calculate_reshard_cost( + prev, curr, arg.out_size + ) + max_cost += reshard_cost + print( + f" Resharding cost ({arg.name}) {ShardSpec(prev)} -> {ShardSpec(curr)}: {reshard_cost}" + ) + if len(op.users) == 0: + next_inp = ShardSpec("RR").id + reshard_cost = self.calculate_reshard_cost( + output, next_inp, op.out_size + ) + max_cost += reshard_cost + print( + f" Last resharding cost {ShardSpec(output)} -> {ShardSpec(next_inp)}: {reshard_cost}" + ) + print("Total cost:", max_cost) + sol.pop() + # generate sharding sequence + self.best_spec = results + print() + 
print("Best solution:") + for name, op in self.z3_graph.items(): + if not isinstance(op, MatmulOp): + continue + weight = self.best_spec[f"{name}_1"] + if weight == ShardSpec("RS").id: + dim = 0 # transposed + elif weight == ShardSpec("SR").id: + dim = 1 + else: + continue + if op.node.op == "call_module": + print(f'sch["{op.node.target}"].shard("weight", dim={dim})') + if dim == 0: + print(f'sch["{op.node.target}"].shard("bias", dim={dim})') + if ( + self.best_spec[f"{name}_0"] == ShardSpec("RS").id + and self.best_spec[f"{name}_1"] == ShardSpec("SR").id + ): + print( + f'sch["{op.node.target}"].sync(mode="fwd_post", sync_op_or_fn="all_reduce")' + ) \ No newline at end of file From 340121c5f518c9447ce0766dd43027a87d2deadc Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Sun, 21 May 2023 06:11:03 +0000 Subject: [PATCH 02/36] Add test_autoshard --- tests/test_autoshard.py | 52 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 tests/test_autoshard.py diff --git a/tests/test_autoshard.py b/tests/test_autoshard.py new file mode 100644 index 00000000..e18718f8 --- /dev/null +++ b/tests/test_autoshard.py @@ -0,0 +1,52 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# pylint: disable=unused-argument +""" +Test different resharding schemes on MLP. +Verified by different combinations of resharding schemes. +""" + +import os +import copy +import argparse + +import torch +from torch import nn +from torch import fx +import torch.nn.functional as F + +import slapo +from slapo.logger import get_logger + +logger = get_logger(__name__) + +# Config for verification +bs = 8 +seq_len = 1024 +hidden_size = 1024 + + +class MLP(nn.Module): + def __init__(self, dim): + super().__init__() + self.fc1 = nn.Linear(dim, 4 * dim) + self.fc2 = nn.Linear(4 * dim, dim) + + def forward(self, x): + x = self.fc1(x) + x = F.gelu(x) + x = self.fc2(x) + return x + + +with slapo.init_empty_weights(): + mlp = MLP(hidden_size) + +sch = slapo.create_schedule(mlp) +sch.trace() +assert isinstance(sch.mod, fx.GraphModule) + +from slapo.sharding import Solver + +sol = Solver(sch.mod, p=8) +sol.solve([torch.randn(bs, seq_len, hidden_size)]) From 2dff6e2e00e3026cb346402b2b715a894a856488 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Sun, 21 May 2023 06:12:12 +0000 Subject: [PATCH 03/36] Add dump_node --- setup.py | 1 + slapo/sharding/__init__.py | 1 + slapo/sharding/solver.py | 24 ++++++++++++++++++++++-- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 768c5bb9..b99ff40c 100644 --- a/setup.py +++ b/setup.py @@ -129,6 +129,7 @@ def setup(): setup_requires=[], install_requires=[ "z3-solver", + "tabulate", "packaging", "psutil", ], diff --git a/slapo/sharding/__init__.py b/slapo/sharding/__init__.py index 051919f5..4c310cf5 100644 --- a/slapo/sharding/__init__.py +++ b/slapo/sharding/__init__.py @@ -10,3 +10,4 @@ scatter_forward_output, reduce_forward_output, ) +from .solver import Solver diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 8ce85d04..14348718 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -8,6 +8,11 @@ import torch.nn.functional as F from torch.fx.passes.shape_prop import ShapeProp import z3 +from tabulate import tabulate + +from ..logger import get_logger + +logger = get_logger(__name__) class ShardSpec: @@ -299,6 +304,7 @@ def calculate_comm_cost_z3(self): class Solver: def __init__(self, gm, p) -> None: + assert isinstance(gm, 
fx.GraphModule), "gm must be a GraphModule" self.gm = gm self.gm.graph.eliminate_dead_code() self.named_modules = dict(self.gm.named_modules()) @@ -314,7 +320,13 @@ def __init__(self, gm, p) -> None: def inference_shape(self, inputs): sp = ShapeProp(self.gm) + # Tackle the case of meta device + device = self.gm.named_parameters().__next__()[1].device + inputs = [inp.to(device) for inp in inputs] sp.propagate(*inputs) + + def dump_node(self): + res = [] for node in self.gm.graph.nodes: if "tensor_meta" in node.meta: if isinstance(node.meta["tensor_meta"], list): @@ -322,7 +334,13 @@ def inference_shape(self, inputs): else: lst = [node.meta["tensor_meta"]] for data in lst: - print(node.name, data) + res.append( + [node.name, node.op, node.target, list(data.shape), data.dtype] + ) + logger.info( + "\n" + tabulate(res, headers=["name", "op", "target", "shape", "dtype"]), + ranks=0, + ) def calculate_reshard_cost(self, prev, curr, shape): return int( @@ -434,6 +452,8 @@ def construct_z3_problem(self): def solve(self, inputs, max_iter=100): self.inference_shape(inputs) + self.dump_node() + sys.exit() self.construct_z3_graph() self.construct_z3_problem() sol = z3.Solver() @@ -524,4 +544,4 @@ def solve(self, inputs, max_iter=100): ): print( f'sch["{op.node.target}"].sync(mode="fwd_post", sync_op_or_fn="all_reduce")' - ) \ No newline at end of file + ) From 54884ac1722412d017df37ff969329c02df6111b Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Sun, 21 May 2023 06:34:04 +0000 Subject: [PATCH 04/36] Add fx_op_map & param_dump --- slapo/sharding/solver.py | 50 ++++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 14348718..cb2ee90f 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -237,7 +237,6 @@ def __init__(self, node, mod=None, is_linear=False): else node.meta["tensor_meta"][0].shape ) self.lhs_size = self.lhs_shape[-2] * self.lhs_shape[-1] - print(self.name, self.lhs_shape, self.rhs_shape, self.out_shape) if is_linear: # weight is transposed assert self.lhs_shape[-1] == self.rhs_shape[-1] @@ -302,11 +301,26 @@ def calculate_comm_cost_z3(self): return result +fx_op_map = { + nn.Linear: MatmulOp, + nn.LayerNorm: LayerNormOp, + nn.Dropout: DropoutOp, + torch.matmul: MatmulOp, + F.relu: ElementwiseOp, + F.gelu: ElementwiseOp, + F.softmax: ElementwiseOp, + torch._C._nn.gelu: ElementwiseOp, + operator.truediv: ElementwiseOp, + operator.add: BinaryOp, +} + + class Solver: def __init__(self, gm, p) -> None: assert isinstance(gm, fx.GraphModule), "gm must be a GraphModule" self.gm = gm self.gm.graph.eliminate_dead_code() + logger.debug(self.gm.graph, ranks=0) self.named_modules = dict(self.gm.named_modules()) self.z3_graph = {} # {node_name: FxOp} self.goal = [] @@ -334,9 +348,20 @@ def dump_node(self): else: lst = [node.meta["tensor_meta"]] for data in lst: + if node.op == "call_module": + target = type(self.named_modules[node.target]) + else: + target = node.target res.append( - [node.name, node.op, node.target, list(data.shape), data.dtype] + [node.name, node.op, target, list(data.shape), data.dtype] ) + if node.op == "call_module": + for name, param in self.named_modules[ + node.target + ].named_parameters(): + res.append( + ["|-" + name, "", "", list(param.shape), param.dtype] + ) logger.info( "\n" + tabulate(res, headers=["name", "op", "target", "shape", "dtype"]), ranks=0, @@ -357,9 +382,10 @@ def calculate_reshard_cost_z3(self, prev, curr, shape): return result def 
construct_z3_graph(self): - print(self.gm.graph) for node in self.gm.graph.nodes: - if "tensor_meta" not in node.meta: + if ( + "tensor_meta" not in node.meta + ): # not an activation tensor, no need to care continue if node.op == "placeholder": # input new_op = PlaceholderOp(node) @@ -376,18 +402,9 @@ def construct_z3_graph(self): else: raise RuntimeError(f"Unsupported module: {node.target}") elif node.op == "call_function": - if node.target == torch.matmul: - new_op = MatmulOp(node) - elif node.target in [ - F.relu, - F.gelu, - F.softmax, - torch._C._nn.gelu, - operator.truediv, - ]: - new_op = ElementwiseOp(node) - elif node.target in [operator.add]: - new_op = BinaryOp(node) + if node.target in fx_op_map: + new_cls = fx_op_map[node.target] + new_op = new_cls(node) else: raise RuntimeError(f"Unsupported function: {node.target}") elif node.op == "call_method": @@ -453,7 +470,6 @@ def construct_z3_problem(self): def solve(self, inputs, max_iter=100): self.inference_shape(inputs) self.dump_node() - sys.exit() self.construct_z3_graph() self.construct_z3_problem() sol = z3.Solver() From d5cf94f22b02162ad04c9296eb5778473d811d37 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Sun, 21 May 2023 20:38:16 +0000 Subject: [PATCH 05/36] Add dump_z3_graph --- slapo/sharding/solver.py | 28 +++++++++++++++++++++++++--- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index cb2ee90f..d0cc4452 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -339,7 +339,7 @@ def inference_shape(self, inputs): inputs = [inp.to(device) for inp in inputs] sp.propagate(*inputs) - def dump_node(self): + def dump_fx_node(self): res = [] for node in self.gm.graph.nodes: if "tensor_meta" in node.meta: @@ -428,7 +428,27 @@ def construct_z3_graph(self): new_op.add_arg(self.z3_graph[arg.name]) self.z3_graph[arg.name].add_user(new_op) self.z3_graph[node.name] = new_op - print(self.z3_graph) + + def dump_z3_graph(self, dot_file="z3_graph.dot"): + """ + Dump the z3 graph in dot format + """ + res = "digraph z3_graph {\n" + # add nodes + for op in self.z3_graph.values(): + attr = f'label="{op.name}"' + if isinstance(op, PlaceholderOp): + attr += ",shape=box" + elif isinstance(op, MatmulOp): + attr += ",style=filled,fillcolor=yellow" + res += f" {op.name} [{attr}];\n" + # add edges + for op in self.z3_graph.values(): + for arg in op.args: + res += f" {arg.name} -> {op.name};\n" + res += "}" + with open(dot_file, "w") as f: + f.write(res) def construct_z3_problem(self): bitvecs = {} @@ -469,8 +489,10 @@ def construct_z3_problem(self): def solve(self, inputs, max_iter=100): self.inference_shape(inputs) - self.dump_node() + self.dump_fx_node() self.construct_z3_graph() + self.dump_z3_graph() + sys.exit() self.construct_z3_problem() sol = z3.Solver() sol.add(self.goal) From 9d1c1764358e0c5471b7f7e283cf943419d2e63c Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 22 May 2023 03:30:41 +0000 Subject: [PATCH 06/36] Refactor & print --- slapo/sharding/solver.py | 271 ++++++++++++++++++++++----------------- 1 file changed, 153 insertions(+), 118 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index d0cc4452..28dc30ed 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -39,6 +39,8 @@ def __init__(self, node): self.users = [] self.out_shape = node.meta["tensor_meta"].shape self.out_size = self.out_shape[-2] * self.out_shape[-1] + self.z3_inputs = [] + self.input_v = [] def add_arg(self, arg): self.args.append(arg) 
@@ -46,8 +48,9 @@ def add_arg(self, arg): def add_user(self, user): self.users.append(user) - def set_concrete_values(self, inp): - self.input_v = inp + def set_concrete_values(self, inputs): + assert isinstance(inputs, list) + self.input_v = inputs def generate_input_z3(self): raise NotImplementedError @@ -93,28 +96,33 @@ def generate_output(self): def generate_output_z3(self): return self.args[0].generate_output_z3() + def calculate_comm_cost(self): + return 0 + def calculate_comm_cost_z3(self): return 0 class BinaryOp(FxOp): def generate_input_z3(self): - self.lhs = z3.BitVec(f"{self.name}_0", 2) - self.rhs = z3.BitVec(f"{self.name}_1", 2) - compute_constraints = [self.lhs == self.rhs] - format_constraints = [z3.ULE(self.lhs, 3), z3.ULE(self.rhs, 3)] + self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2)) + self.z3_inputs.append(z3.BitVec(f"{self.name}_1", 2)) + compute_constraints = [self.z3_inputs[0] == self.z3_inputs[1]] + format_constraints = [ + z3.ULE(self.z3_inputs[0], 3), + z3.ULE(self.z3_inputs[1], 3), + ] constraints = compute_constraints + format_constraints - return [self.lhs, self.rhs], constraints - - def set_concrete_values(self, lhs, rhs): - self.lhs_v = lhs - self.rhs_v = rhs + return self.z3_inputs, constraints def generate_output(self): - return self.lhs_v + return self.input_v[0] def generate_output_z3(self): - return self.lhs + return self.z3_inputs[0] + + def calculate_comm_cost(self): + return 0 def calculate_comm_cost_z3(self): # output remains the same spec as the inputs @@ -131,15 +139,18 @@ class SoftmaxOp(FxOp): class ViewOp(FxOp): def generate_input_z3(self): - self.input = z3.BitVec(f"{self.name}_0", 2) - format_constraints = [z3.ULE(self.input, 3)] - return [self.input], format_constraints + self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2)) + format_constraints = [z3.ULE(self.z3_inputs[0], 3)] + return self.z3_inputs, format_constraints def generate_output(self): - return self.input_v + return self.input_v[0] def generate_output_z3(self): - return self.input + return self.z3_inputs[0] + + def calculate_comm_cost(self): + return 0 def calculate_comm_cost_z3(self): # output remains the same spec as the inputs @@ -156,11 +167,6 @@ def __init__(self, node, z3_graph): def generate_input_z3(self): return [], [] - # self.input = z3.BitVec(f"{self.name}_0", 2) - # compute_constraints = [self.input == self.prev_op.generate_output_z3()] - # format_constraints = [z3.ULE(self.input, 3)] - # constraints = compute_constraints + format_constraints - # return [self.input], constraints def generate_output(self): return ShardSpec( @@ -177,6 +183,9 @@ def generate_output_z3(self): ) return result + def calculate_comm_cost(self): + return 0 + def calculate_comm_cost_z3(self): # output remains the same spec as the inputs return 0 @@ -213,6 +222,9 @@ def generate_output_z3(self): ) return result + def calculate_comm_cost(self): + return 0 + def calculate_comm_cost_z3(self): # output remains the same spec as the inputs return 0 @@ -254,50 +266,54 @@ def __init__(self, node, mod=None, is_linear=False): } def generate_input_z3(self): - self.lhs = z3.BitVec(f"{self.name}_0", 2) # input - self.rhs = z3.BitVec(f"{self.name}_1", 2) # weight + self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2)) # input + self.z3_inputs.append(z3.BitVec(f"{self.name}_1", 2)) # weight compute_constraints = [ z3.Or( [ z3.And( - self.lhs == ShardSpec("RR").id, self.rhs == ShardSpec("RS").id + self.z3_inputs[0] == ShardSpec("RR").id, + self.z3_inputs[1] == ShardSpec("RS").id, ), z3.And( - self.lhs == 
ShardSpec("RS").id, self.rhs == ShardSpec("SR").id + self.z3_inputs[0] == ShardSpec("RS").id, + self.z3_inputs[1] == ShardSpec("SR").id, ), z3.And( - self.lhs == ShardSpec("SR").id, self.rhs == ShardSpec("RR").id + self.z3_inputs[0] == ShardSpec("SR").id, + self.z3_inputs[1] == ShardSpec("RR").id, ), ] ) ] - format_constraints = [z3.ULE(self.lhs, 3), z3.ULE(self.rhs, 3)] + format_constraints = [ + z3.ULE(self.z3_inputs[0], 3), + z3.ULE(self.z3_inputs[1], 3), + ] constraints = compute_constraints + format_constraints # force to shard - # constraints += [self.lhs != ShardSpec("RR").id, self.rhs != ShardSpec("RR").id] - return [self.lhs, self.rhs], constraints - - def set_concrete_values(self, lhs, rhs): - self.lhs_v = lhs - self.rhs_v = rhs + # constraints += [self.z3_inputs[0] != ShardSpec("RR").id, self.z3_inputs[1] != ShardSpec("RR").id] + return self.z3_inputs, constraints def generate_output(self): - return ShardSpec(self.output_map[ShardSpec(self.lhs_v).spec]).id + return ShardSpec(self.output_map[ShardSpec(self.input_v[0]).spec]).id def generate_output_z3(self): result = 3 # invalid for inp, out in self.output_map.items(): - result = z3.If(self.lhs == ShardSpec(inp).id, ShardSpec(out).id, result) + result = z3.If( + self.z3_inputs[0] == ShardSpec(inp).id, ShardSpec(out).id, result + ) return result def calculate_comm_cost(self): - return self.comm_cost_map[ShardSpec(self.lhs_v).spec] + return self.comm_cost_map[ShardSpec(self.input_v[0]).spec] def calculate_comm_cost_z3(self): result = 1e12 # invalid for inp, cost in self.comm_cost_map.items(): - result = z3.If(self.lhs == ShardSpec(inp).id, cost, result) + result = z3.If(self.z3_inputs[0] == ShardSpec(inp).id, cost, result) return result @@ -454,6 +470,7 @@ def construct_z3_problem(self): bitvecs = {} input_constraints = [] comm_costs = [] + reshard_costs = [] for op in self.z3_graph.values(): # no need to include output, since output can be obtained from inputs inputs, constraints = op.generate_input_z3() @@ -463,9 +480,7 @@ def construct_z3_problem(self): input_constraints.extend(constraints) # communication cost comm_costs.append(op.calculate_comm_cost_z3()) - - reshard_costs = [] - for op in self.z3_graph.values(): + # reshard cost for i, arg in enumerate(op.args): name = f"{op.name}_{i}" if name not in bitvecs: @@ -487,85 +502,71 @@ def construct_z3_problem(self): self.cost = sum(comm_costs) + sum(reshard_costs) self.goal += input_constraints - def solve(self, inputs, max_iter=100): - self.inference_shape(inputs) - self.dump_fx_node() - self.construct_z3_graph() - self.dump_z3_graph() - sys.exit() - self.construct_z3_problem() - sol = z3.Solver() - sol.add(self.goal) - max_cost = 1e12 - for it in range(max_iter): - print(f"=================== Iter {it} ===================") - sol.push() - assert self.cost is not None - sol.add(self.cost < max_cost) - # print(sol) - sat = sol.check() - if str(sat) == "unsat": - print("Cannot find better solutions") - break - mod = sol.model() - print(mod) - results = {d.name(): mod[d] for d in mod.decls()} - max_cost = 0 - for name, op in self.z3_graph.items(): - flag = True - if isinstance(op, (MatmulOp, BinaryOp)): - lhs = results[f"{name}_0"] - rhs = results[f"{name}_1"] - op.set_concrete_values(lhs, rhs) - output = op.generate_output() - print( - f"{name}: {ShardSpec(lhs)} x {ShardSpec(rhs)} = {ShardSpec(output)}" - ) - if isinstance(op, MatmulOp): - print( - f" {name}: {op.lhs_shape} x {op.rhs_shape} = {op.out_shape}" - ) - comm_cost = op.calculate_comm_cost() - max_cost += comm_cost - 
print(f" Comm cost: {comm_cost}") - elif f"{name}_0" in results: - inp = results[f"{name}_0"] - op.set_concrete_values(inp) - output = op.generate_output() - print(f"{name}: {ShardSpec(inp)} -> {ShardSpec(output)}") - else: - continue # flag = False - if flag: - # resharding cost - for i, arg in enumerate(op.args): - curr = results[f"{name}_{i}"] - prev = arg.generate_output() - reshard_cost = self.calculate_reshard_cost( - prev, curr, arg.out_size - ) - max_cost += reshard_cost - print( - f" Resharding cost ({arg.name}) {ShardSpec(prev)} -> {ShardSpec(curr)}: {reshard_cost}" - ) - if len(op.users) == 0: - next_inp = ShardSpec("RR").id - reshard_cost = self.calculate_reshard_cost( - output, next_inp, op.out_size - ) - max_cost += reshard_cost - print( - f" Last resharding cost {ShardSpec(output)} -> {ShardSpec(next_inp)}: {reshard_cost}" - ) - print("Total cost:", max_cost) - sol.pop() - # generate sharding sequence - self.best_spec = results + def calculate_new_cost(self, results): + max_cost = 0 + table = [] + for name, op in self.z3_graph.items(): + # communication cost + inputs = [] + if f"{name}_0" in results: + inputs.append(results[f"{name}_0"]) + if f"{name}_1" in results: + inputs.append(results[f"{name}_1"]) + op.set_concrete_values(inputs) + output = op.generate_output() + comm_cost = op.calculate_comm_cost() + max_cost += comm_cost + if len(inputs) == 1: + table.append( + [op.name, ShardSpec(inputs[0]), ShardSpec(output), comm_cost] + ) + elif len(inputs) == 2: + table.append( + [ + op.name, + f"{ShardSpec(inputs[0])}x{ShardSpec(inputs[1])}", + ShardSpec(output), + comm_cost, + ] + ) + elif len(inputs) > 2: + raise RuntimeError("Not supported") + # resharding cost + for i, arg in enumerate(op.args): + arg_name = f"{op.name}_{i}" + if arg_name not in results: + continue + curr = results[arg_name] + prev = arg.generate_output() + reshard_cost = self.calculate_reshard_cost(prev, curr, arg.out_size) + max_cost += reshard_cost + table.append( + [f"|-{arg.name}", ShardSpec(prev), ShardSpec(curr), reshard_cost] + ) + # final output should not be sharded + if len(op.users) == 0: + next_inp = ShardSpec("RR").id + reshard_cost = self.calculate_reshard_cost( + output, next_inp, op.out_size + ) + max_cost += reshard_cost + table.append( + ["output", ShardSpec(output), ShardSpec(next_inp), reshard_cost] + ) + table.append(["Total", "", "", max_cost]) + logger.info( + "\n" + tabulate(table, headers=["Name", "InSpec", "OutSpec", "Cost"]), + ranks=0, + ) + return max_cost + + def generate_schedule_sequence(self, results): print() print("Best solution:") for name, op in self.z3_graph.items(): if not isinstance(op, MatmulOp): continue - weight = self.best_spec[f"{name}_1"] + weight = results[f"{name}_1"] if weight == ShardSpec("RS").id: dim = 0 # transposed elif weight == ShardSpec("SR").id: @@ -577,9 +578,43 @@ def solve(self, inputs, max_iter=100): if dim == 0: print(f'sch["{op.node.target}"].shard("bias", dim={dim})') if ( - self.best_spec[f"{name}_0"] == ShardSpec("RS").id - and self.best_spec[f"{name}_1"] == ShardSpec("SR").id + results[f"{name}_0"] == ShardSpec("RS").id + and results[f"{name}_1"] == ShardSpec("SR").id ): print( f'sch["{op.node.target}"].sync(mode="fwd_post", sync_op_or_fn="all_reduce")' ) + + def solve(self, inputs, max_iter=100): + # 1. Shape propagation + self.inference_shape(inputs) + self.dump_fx_node() + # 2. Construct a simplied z3 graph from the fx graph + self.construct_z3_graph() + self.dump_z3_graph() + # 3. 
Construct the z3 constraints + self.construct_z3_problem() + # 4. Construct the z3 solver + sol = z3.Solver() + sol.add(self.goal) + max_cost = 1e12 + for it in range(max_iter): + logger.info(f"=================== Iter {it} ===================", ranks=0) + sol.push() + # 5. Update cost constraint + sol.add(self.cost < max_cost) + # 6. Solve the problem + sat = sol.check() + if str(sat) == "unsat": + logger.info("Cannot find better solutions", ranks=0) + break + mod = sol.model() + logger.info(f"new_cost: {mod.evaluate(self.cost)}", ranks=0) + logger.info(mod, ranks=0) + # Get the results + results = {d.name(): mod[d] for d in mod.decls()} + # 7. Calculate new cost from the results + max_cost = self.calculate_new_cost(results) + sol.pop() + # 8. Generate sharding sequence + self.generate_schedule_sequence(results) From fcf1a2e8c56497b6539209fb3387d7df807de44d Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 22 May 2023 04:33:55 +0000 Subject: [PATCH 07/36] Add reshard schedule --- slapo/sharding/solver.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 28dc30ed..d6b05db0 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -584,6 +584,26 @@ def generate_schedule_sequence(self, results): print( f'sch["{op.node.target}"].sync(mode="fwd_post", sync_op_or_fn="all_reduce")' ) + # reshard + for name, op in self.z3_graph.items(): + for i, arg in enumerate(op.args): + arg_name = f"{op.name}_{i}" + if arg_name not in results: + continue + curr = results[arg_name] + prev = arg.generate_output() + if int(str(curr)) != prev: + print( + f'sch["{op.name}"].sync(mode="fwd_pre", sync_op_or_fn="{ShardSpec(prev)}->{ShardSpec(curr)}")' + ) + # final output should not be sharded + if len(op.users) == 0: + next_inp = ShardSpec("RR").id + output = op.generate_output() + if output != next_inp: + print( + f'sch["{op.name}"].sync(mode="fwd_post", sync_op_or_fn="{ShardSpec(output)}->{ShardSpec(next_inp)}")' + ) def solve(self, inputs, max_iter=100): # 1. 
Shape propagation From bf542cfa2851881fa7fa252654b90b238dfb5b6e Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 22 May 2023 04:54:17 +0000 Subject: [PATCH 08/36] Fix generate_output --- slapo/sharding/solver.py | 58 +++++++++++----------------------------- 1 file changed, 16 insertions(+), 42 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index d6b05db0..7e632bad 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -55,8 +55,12 @@ def set_concrete_values(self, inputs): def generate_input_z3(self): raise NotImplementedError - def generate_output(self): - raise NotImplementedError + def generate_output(self, mod): + output = self.generate_output_z3() + if isinstance(output, int): + return output + else: + return mod.evaluate(output).as_long() def generate_output_z3(self): raise NotImplementedError @@ -73,9 +77,6 @@ def generate_input_z3(self): # input should not be sharded return [], [] - def generate_output(self): - return ShardSpec("RR").id - def generate_output_z3(self): return ShardSpec("RR").id @@ -90,9 +91,6 @@ class ElementwiseOp(FxOp): def generate_input_z3(self): return [], [] - def generate_output(self): - return self.args[0].generate_output() - def generate_output_z3(self): return self.args[0].generate_output_z3() @@ -115,9 +113,6 @@ def generate_input_z3(self): constraints = compute_constraints + format_constraints return self.z3_inputs, constraints - def generate_output(self): - return self.input_v[0] - def generate_output_z3(self): return self.z3_inputs[0] @@ -143,9 +138,6 @@ def generate_input_z3(self): format_constraints = [z3.ULE(self.z3_inputs[0], 3)] return self.z3_inputs, format_constraints - def generate_output(self): - return self.input_v[0] - def generate_output_z3(self): return self.z3_inputs[0] @@ -168,11 +160,6 @@ def __init__(self, node, z3_graph): def generate_input_z3(self): return [], [] - def generate_output(self): - return ShardSpec( - self.output_map[ShardSpec(self.prev_op.generate_output()).spec] - ).id - def generate_output_z3(self): result = 3 # invalid for inp, out in self.output_map.items(): @@ -201,16 +188,6 @@ def __init__(self, node, z3_graph): def generate_input_z3(self): return [], [] - # self.input = z3.BitVec(f"{self.name}_0", 2) - # compute_constraints = [] - # format_constraints = [z3.ULE(self.input, 3)] - # constraints = compute_constraints + format_constraints - # return [self.input], constraints - - def generate_output(self): - return ShardSpec( - self.output_map[ShardSpec(self.prev_op.generate_output()).spec] - ).id def generate_output_z3(self): result = 3 # invalid @@ -296,9 +273,6 @@ def generate_input_z3(self): # constraints += [self.z3_inputs[0] != ShardSpec("RR").id, self.z3_inputs[1] != ShardSpec("RR").id] return self.z3_inputs, constraints - def generate_output(self): - return ShardSpec(self.output_map[ShardSpec(self.input_v[0]).spec]).id - def generate_output_z3(self): result = 3 # invalid for inp, out in self.output_map.items(): @@ -502,7 +476,7 @@ def construct_z3_problem(self): self.cost = sum(comm_costs) + sum(reshard_costs) self.goal += input_constraints - def calculate_new_cost(self, results): + def calculate_new_cost(self, mod, results): max_cost = 0 table = [] for name, op in self.z3_graph.items(): @@ -513,7 +487,7 @@ def calculate_new_cost(self, results): if f"{name}_1" in results: inputs.append(results[f"{name}_1"]) op.set_concrete_values(inputs) - output = op.generate_output() + output = op.generate_output(mod) comm_cost = op.calculate_comm_cost() max_cost += comm_cost 
if len(inputs) == 1: @@ -537,7 +511,7 @@ def calculate_new_cost(self, results): if arg_name not in results: continue curr = results[arg_name] - prev = arg.generate_output() + prev = arg.generate_output(mod) reshard_cost = self.calculate_reshard_cost(prev, curr, arg.out_size) max_cost += reshard_cost table.append( @@ -560,7 +534,7 @@ def calculate_new_cost(self, results): ) return max_cost - def generate_schedule_sequence(self, results): + def generate_schedule_sequence(self, mod, results): print() print("Best solution:") for name, op in self.z3_graph.items(): @@ -590,16 +564,16 @@ def generate_schedule_sequence(self, results): arg_name = f"{op.name}_{i}" if arg_name not in results: continue - curr = results[arg_name] - prev = arg.generate_output() - if int(str(curr)) != prev: + curr = results[arg_name].as_long() + prev = arg.generate_output(mod) + if curr != prev: print( f'sch["{op.name}"].sync(mode="fwd_pre", sync_op_or_fn="{ShardSpec(prev)}->{ShardSpec(curr)}")' ) # final output should not be sharded if len(op.users) == 0: next_inp = ShardSpec("RR").id - output = op.generate_output() + output = op.generate_output(mod) if output != next_inp: print( f'sch["{op.name}"].sync(mode="fwd_post", sync_op_or_fn="{ShardSpec(output)}->{ShardSpec(next_inp)}")' @@ -634,7 +608,7 @@ def solve(self, inputs, max_iter=100): # Get the results results = {d.name(): mod[d] for d in mod.decls()} # 7. Calculate new cost from the results - max_cost = self.calculate_new_cost(results) + max_cost = self.calculate_new_cost(mod, results) sol.pop() # 8. Generate sharding sequence - self.generate_schedule_sequence(results) + self.generate_schedule_sequence(mod, results) From 35f150525c349c56d830edfd22a081a8bdd89a0a Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 22 May 2023 05:04:31 +0000 Subject: [PATCH 09/36] Remove useless functions --- slapo/sharding/solver.py | 61 +++++++++++++++------------------------- 1 file changed, 22 insertions(+), 39 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 7e632bad..42be25d4 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -40,7 +40,6 @@ def __init__(self, node): self.out_shape = node.meta["tensor_meta"].shape self.out_size = self.out_shape[-2] * self.out_shape[-1] self.z3_inputs = [] - self.input_v = [] def add_arg(self, arg): self.args.append(arg) @@ -48,10 +47,6 @@ def add_arg(self, arg): def add_user(self, user): self.users.append(user) - def set_concrete_values(self, inputs): - assert isinstance(inputs, list) - self.input_v = inputs - def generate_input_z3(self): raise NotImplementedError @@ -65,8 +60,12 @@ def generate_output(self, mod): def generate_output_z3(self): raise NotImplementedError - def calculate_comm_cost(self): - raise NotImplementedError + def calculate_comm_cost(self, mod): + cost = self.calculate_comm_cost_z3() + if isinstance(cost, int): + return cost + else: + return mod.evaluate(cost).as_long() def calculate_comm_cost_z3(self): raise NotImplementedError @@ -80,9 +79,6 @@ def generate_input_z3(self): def generate_output_z3(self): return ShardSpec("RR").id - def calculate_comm_cost(self): - return 0 - def calculate_comm_cost_z3(self): return 0 @@ -94,9 +90,6 @@ def generate_input_z3(self): def generate_output_z3(self): return self.args[0].generate_output_z3() - def calculate_comm_cost(self): - return 0 - def calculate_comm_cost_z3(self): return 0 @@ -116,9 +109,6 @@ def generate_input_z3(self): def generate_output_z3(self): return self.z3_inputs[0] - def calculate_comm_cost(self): - return 0 - def 
calculate_comm_cost_z3(self): # output remains the same spec as the inputs return 0 @@ -141,9 +131,6 @@ def generate_input_z3(self): def generate_output_z3(self): return self.z3_inputs[0] - def calculate_comm_cost(self): - return 0 - def calculate_comm_cost_z3(self): # output remains the same spec as the inputs return 0 @@ -170,9 +157,6 @@ def generate_output_z3(self): ) return result - def calculate_comm_cost(self): - return 0 - def calculate_comm_cost_z3(self): # output remains the same spec as the inputs return 0 @@ -199,9 +183,6 @@ def generate_output_z3(self): ) return result - def calculate_comm_cost(self): - return 0 - def calculate_comm_cost_z3(self): # output remains the same spec as the inputs return 0 @@ -281,9 +262,6 @@ def generate_output_z3(self): ) return result - def calculate_comm_cost(self): - return self.comm_cost_map[ShardSpec(self.input_v[0]).spec] - def calculate_comm_cost_z3(self): result = 1e12 # invalid for inp, cost in self.comm_cost_map.items(): @@ -353,14 +331,14 @@ def dump_fx_node(self): ["|-" + name, "", "", list(param.shape), param.dtype] ) logger.info( - "\n" + tabulate(res, headers=["name", "op", "target", "shape", "dtype"]), + "\n" + + tabulate(res, headers=["name", "op", "target", "shape", "dtype"]) + + "\n", ranks=0, ) - def calculate_reshard_cost(self, prev, curr, shape): - return int( - self.reshard_cost_map[ShardSpec(prev).spec][ShardSpec(curr).spec] * shape - ) + def calculate_reshard_cost(self, mod, prev, curr, shape): + return mod.evaluate(self.calculate_reshard_cost_z3(prev, curr, shape)) def calculate_reshard_cost_z3(self, prev, curr, shape): result = 1e12 # invalid @@ -486,9 +464,8 @@ def calculate_new_cost(self, mod, results): inputs.append(results[f"{name}_0"]) if f"{name}_1" in results: inputs.append(results[f"{name}_1"]) - op.set_concrete_values(inputs) output = op.generate_output(mod) - comm_cost = op.calculate_comm_cost() + comm_cost = op.calculate_comm_cost(mod) max_cost += comm_cost if len(inputs) == 1: table.append( @@ -512,7 +489,9 @@ def calculate_new_cost(self, mod, results): continue curr = results[arg_name] prev = arg.generate_output(mod) - reshard_cost = self.calculate_reshard_cost(prev, curr, arg.out_size) + reshard_cost = self.calculate_reshard_cost( + mod, prev, curr, arg.out_size + ) max_cost += reshard_cost table.append( [f"|-{arg.name}", ShardSpec(prev), ShardSpec(curr), reshard_cost] @@ -521,15 +500,18 @@ def calculate_new_cost(self, mod, results): if len(op.users) == 0: next_inp = ShardSpec("RR").id reshard_cost = self.calculate_reshard_cost( - output, next_inp, op.out_size + mod, output, next_inp, op.out_size ) max_cost += reshard_cost table.append( ["output", ShardSpec(output), ShardSpec(next_inp), reshard_cost] ) + max_cost = z3.simplify(max_cost).as_long() table.append(["Total", "", "", max_cost]) logger.info( - "\n" + tabulate(table, headers=["Name", "InSpec", "OutSpec", "Cost"]), + "\n" + + tabulate(table, headers=["Name", "InSpec", "OutSpec", "Cost"]) + + "\n", ranks=0, ) return max_cost @@ -603,12 +585,13 @@ def solve(self, inputs, max_iter=100): logger.info("Cannot find better solutions", ranks=0) break mod = sol.model() - logger.info(f"new_cost: {mod.evaluate(self.cost)}", ranks=0) + total_cost = mod.evaluate(self.cost) logger.info(mod, ranks=0) # Get the results results = {d.name(): mod[d] for d in mod.decls()} # 7. Calculate new cost from the results max_cost = self.calculate_new_cost(mod, results) + assert max_cost == total_cost.as_long() sol.pop() # 8. 
Generate sharding sequence self.generate_schedule_sequence(mod, results) From 15b3358a73e9bde8488e85dfb38c39adb2b1b300 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 22 May 2023 06:06:59 +0000 Subject: [PATCH 10/36] Add test --- slapo/sharding/solver.py | 10 +++++++++- tests/test_autoshard.py | 33 +++++++++++++++++++++------------ 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 42be25d4..4064c265 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -1,5 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +""" +Auto-parallelism solver that finds the optimal sharding scheme for a given model. +It models the problem as a program synthesis problem and uses Z3 to solve it. +""" import operator import torch @@ -17,6 +21,10 @@ class ShardSpec: def __init__(self, spec): + """ + R: replicated + S: sharded + """ self.map = {"RR": 0, "RS": 1, "SR": 2} if isinstance(spec, str): self.spec = spec @@ -277,7 +285,6 @@ def calculate_comm_cost_z3(self): F.relu: ElementwiseOp, F.gelu: ElementwiseOp, F.softmax: ElementwiseOp, - torch._C._nn.gelu: ElementwiseOp, operator.truediv: ElementwiseOp, operator.add: BinaryOp, } @@ -595,3 +602,4 @@ def solve(self, inputs, max_iter=100): sol.pop() # 8. Generate sharding sequence self.generate_schedule_sequence(mod, results) + return results, max_cost diff --git a/tests/test_autoshard.py b/tests/test_autoshard.py index e18718f8..7827257a 100644 --- a/tests/test_autoshard.py +++ b/tests/test_autoshard.py @@ -6,10 +6,6 @@ Verified by different combinations of resharding schemes. """ -import os -import copy -import argparse - import torch from torch import nn from torch import fx @@ -21,6 +17,7 @@ logger = get_logger(__name__) # Config for verification +p = 8 bs = 8 seq_len = 1024 hidden_size = 1024 @@ -39,14 +36,26 @@ def forward(self, x): return x -with slapo.init_empty_weights(): - mlp = MLP(hidden_size) +def test_mlp(): + with slapo.init_empty_weights(): + mlp = MLP(hidden_size) + + sch = slapo.create_schedule(mlp) + sch.trace() + assert isinstance(sch.mod, fx.GraphModule) + + from slapo.sharding import Solver -sch = slapo.create_schedule(mlp) -sch.trace() -assert isinstance(sch.mod, fx.GraphModule) + sol = Solver(sch.mod, p=p) + results, max_cost = sol.solve([torch.randn(bs, seq_len, hidden_size)]) + # fc1: SRxRR->SR + # fc2: SRxRR->SR->RR + assert results["fc1_0"] == 2 + assert results["fc1_1"] == 0 + assert results["fc2_0"] == 2 + assert results["fc2_1"] == 0 + assert max_cost == seq_len * hidden_size / p -from slapo.sharding import Solver -sol = Solver(sch.mod, p=8) -sol.solve([torch.randn(bs, seq_len, hidden_size)]) +if __name__ == "__main__": + test_mlp() From 5f8f2d28cb31614f6da521957412d3b80d2c3221 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 22 May 2023 06:20:06 +0000 Subject: [PATCH 11/36] Fix pylint --- slapo/sharding/solver.py | 42 +++++++++++++++++++--------------------- tests/test_autoshard.py | 1 - 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 4064c265..27cbc0dc 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -62,8 +62,7 @@ def generate_output(self, mod): output = self.generate_output_z3() if isinstance(output, int): return output - else: - return mod.evaluate(output).as_long() + return mod.evaluate(output).as_long() def generate_output_z3(self): raise NotImplementedError @@ -72,8 +71,7 @@ def 
calculate_comm_cost(self, mod): cost = self.calculate_comm_cost_z3() if isinstance(cost, int): return cost - else: - return mod.evaluate(cost).as_long() + return mod.evaluate(cost).as_long() def calculate_comm_cost_z3(self): raise NotImplementedError @@ -122,12 +120,13 @@ def calculate_comm_cost_z3(self): return 0 -class LayerNormOp(FxOp): - pass +# TODO: support more ops +# class LayerNormOp(FxOp): +# pass -class SoftmaxOp(FxOp): - pass +# class SoftmaxOp(FxOp): +# pass class ViewOp(FxOp): @@ -196,8 +195,8 @@ def calculate_comm_cost_z3(self): return 0 -class DropoutOp(FxOp): - pass +# class DropoutOp(FxOp): +# pass class MatmulOp(FxOp): @@ -279,8 +278,8 @@ def calculate_comm_cost_z3(self): fx_op_map = { nn.Linear: MatmulOp, - nn.LayerNorm: LayerNormOp, - nn.Dropout: DropoutOp, + # nn.LayerNorm: LayerNormOp, + # nn.Dropout: DropoutOp, torch.matmul: MatmulOp, F.relu: ElementwiseOp, F.gelu: ElementwiseOp, @@ -310,7 +309,7 @@ def __init__(self, gm, p) -> None: def inference_shape(self, inputs): sp = ShapeProp(self.gm) # Tackle the case of meta device - device = self.gm.named_parameters().__next__()[1].device + device = next(self.gm.named_parameters())[1].device inputs = [inp.to(device) for inp in inputs] sp.propagate(*inputs) @@ -338,9 +337,8 @@ def dump_fx_node(self): ["|-" + name, "", "", list(param.shape), param.dtype] ) logger.info( - "\n" - + tabulate(res, headers=["name", "op", "target", "shape", "dtype"]) - + "\n", + "\n %s \n", + tabulate(res, headers=["name", "op", "target", "shape", "dtype"]), ranks=0, ) @@ -383,6 +381,7 @@ def construct_z3_graph(self): else: raise RuntimeError(f"Unsupported function: {node.target}") elif node.op == "call_method": + # pylint: disable=redefined-variable-type if node.target == "view": new_op = ViewOp(node) elif node.target == "permute": @@ -422,7 +421,7 @@ def dump_z3_graph(self, dot_file="z3_graph.dot"): for arg in op.args: res += f" {arg.name} -> {op.name};\n" res += "}" - with open(dot_file, "w") as f: + with open(dot_file, "w", encoding="utf-8") as f: f.write(res) def construct_z3_problem(self): @@ -516,9 +515,8 @@ def calculate_new_cost(self, mod, results): max_cost = z3.simplify(max_cost).as_long() table.append(["Total", "", "", max_cost]) logger.info( - "\n" - + tabulate(table, headers=["Name", "InSpec", "OutSpec", "Cost"]) - + "\n", + "\n %s \n", + tabulate(table, headers=["Name", "InSpec", "OutSpec", "Cost"]), ranks=0, ) return max_cost @@ -580,9 +578,9 @@ def solve(self, inputs, max_iter=100): # 4. Construct the z3 solver sol = z3.Solver() sol.add(self.goal) - max_cost = 1e12 + max_cost = int(1e12) for it in range(max_iter): - logger.info(f"=================== Iter {it} ===================", ranks=0) + logger.info("=================== Iter %d ===================", it, ranks=0) sol.push() # 5. Update cost constraint sol.add(self.cost < max_cost) diff --git a/tests/test_autoshard.py b/tests/test_autoshard.py index 7827257a..72a2df14 100644 --- a/tests/test_autoshard.py +++ b/tests/test_autoshard.py @@ -1,6 +1,5 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -# pylint: disable=unused-argument """ Test different resharding schemes on MLP. Verified by different combinations of resharding schemes. 
From 4df19569b0e714b3a32ae37a5cb3ff92857b80d3 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 22 May 2023 23:49:01 +0000 Subject: [PATCH 12/36] Fix z3 graph --- slapo/sharding/solver.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 27cbc0dc..774a6ba9 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -326,6 +326,8 @@ def dump_fx_node(self): target = type(self.named_modules[node.target]) else: target = node.target + if isinstance(data, tuple): + continue res.append( [node.name, node.op, target, list(data.shape), data.dtype] ) @@ -389,7 +391,7 @@ def construct_z3_graph(self): elif node.target == "transpose": new_op = TransposeOp(node, self.z3_graph) elif node.target == "contiguous": - continue + new_op = ElementwiseOp(node) else: raise RuntimeError(f"Unsupported method: {node.target}") else: # output @@ -401,6 +403,10 @@ def construct_z3_graph(self): continue new_op.add_arg(self.z3_graph[arg.name]) self.z3_graph[arg.name].add_user(new_op) + else: + arg = node.args[0] + new_op.add_arg(self.z3_graph[arg.name]) + self.z3_graph[arg.name].add_user(new_op) self.z3_graph[node.name] = new_op def dump_z3_graph(self, dot_file="z3_graph.dot"): @@ -410,7 +416,7 @@ def dump_z3_graph(self, dot_file="z3_graph.dot"): res = "digraph z3_graph {\n" # add nodes for op in self.z3_graph.values(): - attr = f'label="{op.name}"' + attr = f'label="{op.name}\\n({op.__class__.__name__})"' if isinstance(op, PlaceholderOp): attr += ",shape=box" elif isinstance(op, MatmulOp): From f0e5f90cb66470e16beb6954fa581852266d8355 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Tue, 23 May 2023 03:17:22 +0000 Subject: [PATCH 13/36] Update ln and softmax rules --- slapo/sharding/solver.py | 63 ++++++++++++++++++++++++++++------------ 1 file changed, 45 insertions(+), 18 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 774a6ba9..5e07d282 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -120,16 +120,42 @@ def calculate_comm_cost_z3(self): return 0 -# TODO: support more ops -# class LayerNormOp(FxOp): -# pass +class LayerNormOp(FxOp): + def generate_input_z3(self): + self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2)) + format_constraints = [z3.ULE(self.z3_inputs[0], 3)] + # Reduction across the last dimension, so `RS` is prohibited. + format_constraints += [self.z3_inputs[0] != 1] + return self.z3_inputs, format_constraints + def generate_output_z3(self): + # The same spec as the input + return self.z3_inputs[0] + + def calculate_comm_cost_z3(self): + # No communication cost + return 0 -# class SoftmaxOp(FxOp): -# pass + +class SoftmaxOp(FxOp): + def generate_input_z3(self): + self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2)) + format_constraints = [z3.ULE(self.z3_inputs[0], 3)] + # Reduction across the last dimension, so `RS` is prohibited. 
+ format_constraints += [self.z3_inputs[0] != 1] + return self.z3_inputs, format_constraints + + def generate_output_z3(self): + # The same spec as the input + return self.z3_inputs[0] + + def calculate_comm_cost_z3(self): + # No communication cost + return 0 class ViewOp(FxOp): + # TODO: verify the behavior of general view function def generate_input_z3(self): self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2)) format_constraints = [z3.ULE(self.z3_inputs[0], 3)] @@ -139,16 +165,21 @@ def generate_output_z3(self): return self.z3_inputs[0] def calculate_comm_cost_z3(self): - # output remains the same spec as the inputs + # `view` can redistribute the dimensions, thus can be used to + # convert to any spec without communication return 0 class PermuteOp(FxOp): def __init__(self, node, z3_graph): - # FIXME: Suppose permute is always (0, 2, 1, 3) super().__init__(node) self.z3_graph = z3_graph - self.output_map = {"RR": "RR", "RS": "RS", "SR": "RR"} + permute_idx = list(node.args[1:]) + self.output_map = {} + for in_spec in ["RR", "RS", "SR"]: + spec = "R" * (len(permute_idx) - 2) + in_spec + out_spec = spec[-2:] + self.output_map[in_spec] = out_spec self.prev_op = self.z3_graph[self.node.args[0].name] def generate_input_z3(self): @@ -165,7 +196,7 @@ def generate_output_z3(self): return result def calculate_comm_cost_z3(self): - # output remains the same spec as the inputs + # permutation does not involve communication return 0 @@ -195,10 +226,6 @@ def calculate_comm_cost_z3(self): return 0 -# class DropoutOp(FxOp): -# pass - - class MatmulOp(FxOp): def __init__(self, node, mod=None, is_linear=False): super().__init__(node) @@ -278,12 +305,12 @@ def calculate_comm_cost_z3(self): fx_op_map = { nn.Linear: MatmulOp, - # nn.LayerNorm: LayerNormOp, - # nn.Dropout: DropoutOp, + nn.LayerNorm: LayerNormOp, + F.softmax: SoftmaxOp, + nn.Dropout: ElementwiseOp, torch.matmul: MatmulOp, F.relu: ElementwiseOp, F.gelu: ElementwiseOp, - F.softmax: ElementwiseOp, operator.truediv: ElementwiseOp, operator.add: BinaryOp, } @@ -372,8 +399,8 @@ def construct_z3_graph(self): mod=mod, is_linear=True, ) - elif isinstance(mod, (nn.LayerNorm, nn.Dropout)): - new_op = ElementwiseOp(node) + elif type(mod) in fx_op_map: + new_op = fx_op_map[type(mod)](node) else: raise RuntimeError(f"Unsupported module: {node.target}") elif node.op == "call_function": From efc0a443840e2a913e18be69a2cb1b3148c28e28 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Tue, 23 May 2023 03:26:25 +0000 Subject: [PATCH 14/36] Add penalty for splitting --- slapo/sharding/solver.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 5e07d282..a5fb5c2b 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -379,7 +379,11 @@ def calculate_reshard_cost_z3(self, prev, curr, shape): for in_spec, target_map in self.reshard_cost_map.items(): tmp = 1e12 # invalid for out_spec, val in target_map.items(): - tmp = z3.If(curr == ShardSpec(out_spec).id, int(val * shape), tmp) + if in_spec == "RR" and out_spec in ["RS", "SR"]: + cost = 1 # add penalty for splitting cost + else: + cost = int(val * shape) + tmp = z3.If(curr == ShardSpec(out_spec).id, cost, tmp) result = z3.If(prev == ShardSpec(in_spec).id, tmp, result) return result From 472a27c04aba81b8c315ffd038cf2a80e4133fe7 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Tue, 23 May 2023 03:33:00 +0000 Subject: [PATCH 15/36] Use total shape size to calculate communication cost --- slapo/sharding/solver.py | 23 
+---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index a5fb5c2b..8cf73d6f 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -46,7 +46,7 @@ def __init__(self, node): self.args = [] self.users = [] self.out_shape = node.meta["tensor_meta"].shape - self.out_size = self.out_shape[-2] * self.out_shape[-1] + self.out_size = int(torch.prod(torch.tensor(self.out_shape))) self.z3_inputs = [] def add_arg(self, arg): @@ -229,27 +229,6 @@ def calculate_comm_cost_z3(self): class MatmulOp(FxOp): def __init__(self, node, mod=None, is_linear=False): super().__init__(node) - self.lhs_shape = node.args[0].meta["tensor_meta"].shape - self.rhs_shape = ( - node.args[1].meta["tensor_meta"].shape - if not is_linear - else mod.weight.shape - ) - self.out_shape = ( - node.meta["tensor_meta"].shape - if not isinstance(node.meta["tensor_meta"], list) - else node.meta["tensor_meta"][0].shape - ) - self.lhs_size = self.lhs_shape[-2] * self.lhs_shape[-1] - if is_linear: - # weight is transposed - assert self.lhs_shape[-1] == self.rhs_shape[-1] - self.rhs_size = self.rhs_shape[-1] * self.rhs_shape[-2] - self.out_size = self.lhs_shape[-2] * self.rhs_shape[-2] - else: - assert self.lhs_shape[-1] == self.rhs_shape[-2] - self.rhs_size = self.rhs_shape[-2] * self.rhs_shape[-1] - self.out_size = self.lhs_shape[-2] * self.rhs_shape[-1] self.output_map = {"RR": "RS", "RS": "RR", "SR": "SR"} self.comm_cost_map = { # map from input spec to comm cost "RR": 0, From 27067ab4445c6936d8984470400c3964cb328d37 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Tue, 23 May 2023 03:57:08 +0000 Subject: [PATCH 16/36] Dump z3 graph with specs --- slapo/sharding/solver.py | 42 ++++++++++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 8cf73d6f..a0480656 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -419,10 +419,14 @@ def construct_z3_graph(self): self.z3_graph[arg.name].add_user(new_op) self.z3_graph[node.name] = new_op - def dump_z3_graph(self, dot_file="z3_graph.dot"): + def dump_z3_graph(self, mod=None, dot_file="z3_graph.dot"): """ Dump the z3 graph in dot format """ + if mod is None: + results = None + else: + results = {d.name(): mod[d] for d in mod.decls()} res = "digraph z3_graph {\n" # add nodes for op in self.z3_graph.values(): @@ -430,12 +434,28 @@ def dump_z3_graph(self, dot_file="z3_graph.dot"): if isinstance(op, PlaceholderOp): attr += ",shape=box" elif isinstance(op, MatmulOp): - attr += ",style=filled,fillcolor=yellow" + if results is None: + attr += ",style=filled,fillcolor=yellow" + else: + weight_spec = results[op.name + "_1"] + if weight_spec == ShardSpec("RR").id: + attr += ",style=filled,fillcolor=yellow" + elif weight_spec == ShardSpec("RS").id: + attr += ',shape=box,style=striped,fillcolor="#FF5733:#FFBD33"' + else: # weight_spec == ShardSpec("SR").id + attr += ',shape=box,style=wedged,fillcolor="#FF5733:#FFBD33"' res += f" {op.name} [{attr}];\n" # add edges for op in self.z3_graph.values(): - for arg in op.args: - res += f" {arg.name} -> {op.name};\n" + for i, arg in enumerate(op.args): + if results is None: + label = "" + else: + if op.name + "_" + str(i) not in results: + label = "" + else: + label = f' [label="{ShardSpec(arg.generate_output(mod))}->{ShardSpec(results[op.name+"_"+str(i)])}"]' + res += f" {arg.name} -> {op.name}{label};\n" res += "}" with open(dot_file, "w", 
encoding="utf-8") as f: f.write(res) @@ -476,7 +496,8 @@ def construct_z3_problem(self): self.cost = sum(comm_costs) + sum(reshard_costs) self.goal += input_constraints - def calculate_new_cost(self, mod, results): + def calculate_new_cost(self, mod): + results = {d.name(): mod[d] for d in mod.decls()} max_cost = 0 table = [] for name, op in self.z3_graph.items(): @@ -537,9 +558,10 @@ def calculate_new_cost(self, mod, results): ) return max_cost - def generate_schedule_sequence(self, mod, results): + def generate_schedule_sequence(self, mod): print() print("Best solution:") + results = {d.name(): mod[d] for d in mod.decls()} for name, op in self.z3_graph.items(): if not isinstance(op, MatmulOp): continue @@ -608,12 +630,12 @@ def solve(self, inputs, max_iter=100): mod = sol.model() total_cost = mod.evaluate(self.cost) logger.info(mod, ranks=0) - # Get the results - results = {d.name(): mod[d] for d in mod.decls()} # 7. Calculate new cost from the results - max_cost = self.calculate_new_cost(mod, results) + max_cost = self.calculate_new_cost(mod) assert max_cost == total_cost.as_long() sol.pop() # 8. Generate sharding sequence - self.generate_schedule_sequence(mod, results) + self.generate_schedule_sequence(mod) + self.dump_z3_graph(mod, "z3_graph_sharded.dot") + results = {d.name(): mod[d] for d in mod.decls()} return results, max_cost From 9a326efc69dc1787d9ff128939a503fe9b2777d1 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Tue, 23 May 2023 04:02:01 +0000 Subject: [PATCH 17/36] Add attn test --- tests/test_autoshard.py | 35 ++++++++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/tests/test_autoshard.py b/tests/test_autoshard.py index 72a2df14..dfb41df1 100644 --- a/tests/test_autoshard.py +++ b/tests/test_autoshard.py @@ -12,6 +12,7 @@ import slapo from slapo.logger import get_logger +from slapo.sharding import Solver logger = get_logger(__name__) @@ -43,8 +44,6 @@ def test_mlp(): sch.trace() assert isinstance(sch.mod, fx.GraphModule) - from slapo.sharding import Solver - sol = Solver(sch.mod, p=p) results, max_cost = sol.solve([torch.randn(bs, seq_len, hidden_size)]) # fc1: SRxRR->SR @@ -53,8 +52,38 @@ def test_mlp(): assert results["fc1_1"] == 0 assert results["fc2_0"] == 2 assert results["fc2_1"] == 0 - assert max_cost == seq_len * hidden_size / p + assert max_cost == (bs * seq_len * hidden_size / p + 1) + + +def test_attn(): + from transformers import BertLMHeadModel, AutoConfig + import inspect + + config = AutoConfig.from_pretrained("bert-large-uncased") + with slapo.init_empty_weights(): + model = BertLMHeadModel(config) + logger.info(config, ranks=0) + + sch = slapo.create_schedule(model) + input_names = ["hidden_states"] + i = 0 + subsch = sch[f"bert.encoder.layer.{i}"] + sig = inspect.signature(subsch.mod.forward) + concrete_args = { + p.name: p.default for p in sig.parameters.values() if p.name not in input_names + } + subsch.trace( + recursive=False, + flatten=True, + tracer="pytorch", + concrete_args=concrete_args, + ) + logger.info(subsch.mod.graph, ranks=0) + + sol = Solver(subsch.mod, p=p) + results, max_cost = sol.solve([torch.randn(bs, seq_len, hidden_size)]) if __name__ == "__main__": test_mlp() + test_attn() From 0b1de583977668a84292a1fbe56dfa868298f83d Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Tue, 23 May 2023 05:32:04 +0000 Subject: [PATCH 18/36] Fix ViewOp --- slapo/sharding/solver.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/slapo/sharding/solver.py 
b/slapo/sharding/solver.py index a0480656..214d3c78 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -156,6 +156,13 @@ def calculate_comm_cost_z3(self): class ViewOp(FxOp): # TODO: verify the behavior of general view function + # (bs,seq,d) -> (bs,seq,h,d//h) + def __init__(self, node, z3_graph, p): + super().__init__(node) + self.z3_graph = z3_graph + self.num_devices = p + self.prev_op = self.z3_graph[self.node.args[0].name] + def generate_input_z3(self): self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2)) format_constraints = [z3.ULE(self.z3_inputs[0], 3)] @@ -166,8 +173,17 @@ def generate_output_z3(self): def calculate_comm_cost_z3(self): # `view` can redistribute the dimensions, thus can be used to - # convert to any spec without communication - return 0 + # convert to most of the specs without communication, + # but be careful about RS->RR, which requires an all-gather + result = z3.If( + z3.And( + self.prev_op.generate_output_z3() == ShardSpec("RS").id, + self.z3_inputs[0] == ShardSpec("RR").id, + ), + self.out_size * 1 / self.num_devices, + 0, + ) + return result class PermuteOp(FxOp): @@ -395,7 +411,7 @@ def construct_z3_graph(self): elif node.op == "call_method": # pylint: disable=redefined-variable-type if node.target == "view": - new_op = ViewOp(node) + new_op = ViewOp(node, self.z3_graph, self.num_devices) elif node.target == "permute": new_op = PermuteOp(node, self.z3_graph) elif node.target == "transpose": From ef2c565ad9d08b76ce32f1ca6eaaa37aebf21a9b Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Tue, 23 May 2023 06:49:26 +0000 Subject: [PATCH 19/36] Fix test and output --- slapo/sharding/solver.py | 4 ++-- tests/test_autoshard.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 214d3c78..715e8554 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -589,9 +589,9 @@ def generate_schedule_sequence(self, mod): else: continue if op.node.op == "call_module": - print(f'sch["{op.node.target}"].shard("weight", dim={dim})') + print(f'sch["{op.node.target}"].shard("weight", axis={dim})') if dim == 0: - print(f'sch["{op.node.target}"].shard("bias", dim={dim})') + print(f'sch["{op.node.target}"].shard("bias", axis={dim})') if ( results[f"{name}_0"] == ShardSpec("RS").id and results[f"{name}_1"] == ShardSpec("SR").id diff --git a/tests/test_autoshard.py b/tests/test_autoshard.py index dfb41df1..30cf113a 100644 --- a/tests/test_autoshard.py +++ b/tests/test_autoshard.py @@ -81,7 +81,8 @@ def test_attn(): logger.info(subsch.mod.graph, ranks=0) sol = Solver(subsch.mod, p=p) - results, max_cost = sol.solve([torch.randn(bs, seq_len, hidden_size)]) + _, max_cost = sol.solve([torch.randn(bs, seq_len, hidden_size)]) + assert max_cost == 3 * (bs * seq_len * hidden_size / p) + 2 if __name__ == "__main__": From 87591402cab74533a1768a30a850ddfa88eede77 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Tue, 23 May 2023 06:52:36 +0000 Subject: [PATCH 20/36] Fix pylint --- slapo/sharding/solver.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 715e8554..04bf3bc1 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -192,7 +192,7 @@ def __init__(self, node, z3_graph): self.z3_graph = z3_graph permute_idx = list(node.args[1:]) self.output_map = {} - for in_spec in ["RR", "RS", "SR"]: + for in_spec in ("RR", "RS", "SR"): spec = "R" * (len(permute_idx) - 2) + in_spec 
out_spec = spec[-2:] self.output_map[in_spec] = out_spec @@ -243,7 +243,7 @@ def calculate_comm_cost_z3(self): class MatmulOp(FxOp): - def __init__(self, node, mod=None, is_linear=False): + def __init__(self, node): super().__init__(node) self.output_map = {"RR": "RS", "RS": "RR", "SR": "SR"} self.comm_cost_map = { # map from input spec to comm cost @@ -374,7 +374,7 @@ def calculate_reshard_cost_z3(self, prev, curr, shape): for in_spec, target_map in self.reshard_cost_map.items(): tmp = 1e12 # invalid for out_spec, val in target_map.items(): - if in_spec == "RR" and out_spec in ["RS", "SR"]: + if in_spec == "RR" and out_spec in {"RS", "SR"}: cost = 1 # add penalty for splitting cost else: cost = int(val * shape) @@ -393,11 +393,7 @@ def construct_z3_graph(self): elif node.op == "call_module": mod = self.named_modules[node.target] if isinstance(mod, nn.Linear): - new_op = MatmulOp( - node, - mod=mod, - is_linear=True, - ) + new_op = MatmulOp(node) elif type(mod) in fx_op_map: new_op = fx_op_map[type(mod)](node) else: @@ -466,11 +462,10 @@ def dump_z3_graph(self, mod=None, dot_file="z3_graph.dot"): for i, arg in enumerate(op.args): if results is None: label = "" + elif op.name + "_" + str(i) not in results: + label = "" else: - if op.name + "_" + str(i) not in results: - label = "" - else: - label = f' [label="{ShardSpec(arg.generate_output(mod))}->{ShardSpec(results[op.name+"_"+str(i)])}"]' + label = f' [label="{ShardSpec(arg.generate_output(mod))}->{ShardSpec(results[op.name+"_"+str(i)])}"]' res += f" {arg.name} -> {op.name}{label};\n" res += "}" with open(dot_file, "w", encoding="utf-8") as f: From 285e4a7c2c5241b4b9e56352e3c66dfb3db52579 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Fri, 26 May 2023 21:26:23 +0000 Subject: [PATCH 21/36] Update dependency --- ci/install_test_pkgs.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/install_test_pkgs.sh b/ci/install_test_pkgs.sh index 8fbbf0c2..14c59b12 100644 --- a/ci/install_test_pkgs.sh +++ b/ci/install_test_pkgs.sh @@ -5,3 +5,4 @@ python3 -m pip install black==22.10.0 python3 -m pip install transformers==4.25.1 --no-deps python3 -m pip install pylint==2.14.0 astroid==2.11.6 mock==4.0.3 +python3 -m pip install z3-solver tabulate From 49945c669f0bde212e232a9cc41aaeb62249c5df Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Sat, 27 May 2023 00:31:11 +0000 Subject: [PATCH 22/36] Add test_bert --- tests/autoshard/test_bert.py | 284 +++++++++++++++++++++++++++++++++++ 1 file changed, 284 insertions(+) create mode 100644 tests/autoshard/test_bert.py diff --git a/tests/autoshard/test_bert.py b/tests/autoshard/test_bert.py new file mode 100644 index 00000000..d1b03cb9 --- /dev/null +++ b/tests/autoshard/test_bert.py @@ -0,0 +1,284 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import os +import copy +import inspect +import operator +import argparse + +import torch +import torch.distributed as dist +from transformers import BertLMHeadModel, AutoConfig + +import slapo +from slapo.logger import get_logger + +logger = get_logger(__name__) + +# Config for verification +bs = 8 +seq_len = 512 + + +def perf_model(mod, input_tensor): + """Measure the performance of a mod with certain resharding schemes""" + # warmup + for _ in range(5): + mod(input_tensor) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + for _ in range(10): + mod(input_tensor) + end_event.record() + torch.cuda.synchronize() + if dist.get_rank() == 0: + print(f"{start_event.elapsed_time(end_event) / 10:.3f} ms") + + +def trace_and_find_view(sch): + input_names = ["hidden_states"] + sig = inspect.signature(sch.mod.forward) + concrete_args = { + p.name: p.default for p in sig.parameters.values() if p.name not in input_names + } + sch.trace( + recursive=False, flatten=True, tracer="pytorch", concrete_args=concrete_args + ) + ops = sch.find_node(lambda node: node.op == "call_method" and node.target == "view") + assert len(ops) == 4 # q,k,v,context_layer + return ops + + +def fix_attention_mask_shape_megatron(sch): + ops = trace_and_find_view(sch) + + def new_view(tensor, args): + if len(args) == 4: # q,k,v + out = tensor.view(args[0], args[1], args[2] // sch.world_size, args[3]) + else: # context_layer + out = tensor.view(args[0], args[1], args[2] // sch.world_size) + return out + + for op in ops: + sch.replace(new_view, op) + + +def scheme_megatron(model, input_ids, config): + sch = slapo.create_schedule(model) + + with slapo.Verify(sch, [input_ids]): + for i in range(config.num_hidden_layers): + # shard attention + subsch = sch[f"bert.encoder.layer.{i}.attention.self"] + subsch["query"].shard("weight", axis=0) + subsch["query"].shard("bias", axis=0) + subsch["key"].shard("weight", axis=0) + subsch["key"].shard("bias", axis=0) + subsch["value"].shard("weight", axis=0) + subsch["value"].shard("bias", axis=0) + fix_attention_mask_shape_megatron(subsch) + subsch = sch[f"bert.encoder.layer.{i}.attention.output"] + subsch["dense"].shard("weight", axis=1) # replace + subsch["dense"].sync("fwd_post", sync_op_or_fn="all_reduce") # replace + # shard MLP + subsch = sch[f"bert.encoder.layer.{i}"] + subsch["intermediate.dense"].shard("weight", axis=0) + subsch["intermediate.dense"].shard("bias", axis=0) + subsch["output.dense"].shard("weight", axis=1) + subsch["output.dense"].sync("fwd_post", sync_op_or_fn="all_reduce") + + return sch + + +def scheme_weight_stationary(model, input_ids, config): + sch = slapo.create_schedule(model) + + from slapo.sharding.reshard_ops import ( + reshard_RR_to_SR, + reshard_SR_to_RR, + reshard_RS_to_RR, + ) + + def fix_attention_mask_shape(sch): + ops = trace_and_find_view(sch) + + def new_view_kv(tensor, args): + return tensor.view(args[0], args[1], args[2], args[3] // sch.world_size) + + sch.replace(new_view_kv, ops[0]) # key + sch.replace(new_view_kv, ops[1]) # value + + def reshard_and_add(dropout, hidden_states): + """Replace the add operator with reshard_and_add""" + reshard_hidden_states = reshard_RR_to_SR(hidden_states, sch.group) + return dropout + reshard_hidden_states + + def new_matmul(lhs, rhs): + return torch.matmul(lhs, reshard_SR_to_RR(rhs, sch.group)) + + def new_matmul_1(lhs, rhs): + return torch.matmul(lhs, reshard_RS_to_RR(rhs, sch.group)) + + 
with slapo.Verify(sch, [input_ids], enable=False): + for i in range(config.num_hidden_layers): + # attention + subsch = sch[f"bert.encoder.layer.{i}.attention.self"] + subsch["key"].shard("weight", axis=0) + subsch["key"].shard("bias", axis=0) + subsch["value"].shard("weight", axis=0) + subsch["value"].shard("bias", axis=0) + subsch["query"].sync(mode="fwd_pre", sync_op_or_fn="RR->SR") + fix_attention_mask_shape(subsch) + ops = subsch.find_node( + lambda node: node.op == "call_function" and node.target == torch.matmul + ) + assert len(ops) == 2 + subsch.replace(new_matmul, ops[0]) + subsch.replace(new_matmul_1, ops[1]) + # residual add + subsch = sch[f"bert.encoder.layer.{i}.attention.output"] + subsch.trace(recursive=False, flatten=False, tracer="pytorch") + ops = subsch.find_node( + lambda node: node.op == "call_function" and node.target == operator.add + ) + subsch.replace(reshard_and_add, ops[0]) + # MLP + sch[f"bert.encoder.layer.{i}.output.LayerNorm"].sync( + mode="fwd_post", sync_op_or_fn="SR->RR" + ) + + return sch + + +def scheme_activation_stationary(model, input_ids, config): + sch = slapo.create_schedule(model) + with slapo.Verify(sch, [input_ids]): + for i in range(config.num_hidden_layers): + # shard attention + subsch = sch[f"bert.encoder.layer.{i}.attention.self"] + subsch["query"].shard("weight", axis=0) + subsch["query"].shard("bias", axis=0) + subsch["key"].shard("weight", axis=0) + subsch["key"].shard("bias", axis=0) + subsch["value"].shard("weight", axis=0) + subsch["value"].shard("bias", axis=0) + fix_attention_mask_shape_megatron(subsch) + subsch = sch[f"bert.encoder.layer.{i}.attention.output"] + # shape here: [4096, 256](RS). Need to matmul with [1024, 1024] (without shard) + subsch["dense"].sync("fwd_pre", sync_op_or_fn="RS->RR") + subsch["dense"].shard("weight", axis=0) + subsch["dense"].shard("bias", axis=0) + subsch["dense"].sync("fwd_post", sync_op_or_fn="RS->RR") + # shard MLP + subsch = sch[f"bert.encoder.layer.{i}"] + subsch["intermediate.dense"].shard("weight", axis=0) + subsch["intermediate.dense"].shard("bias", axis=0) + subsch["intermediate.dense"].sync("fwd_post", sync_op_or_fn="RS->RR") + subsch["output.dense"].shard("weight", axis=0) + subsch["output.dense"].shard("bias", axis=0) + subsch["output.dense"].sync("fwd_post", sync_op_or_fn="RS->RR") + + return sch + + +def scheme_activation_sharding(model, input_ids, config): + sch = slapo.create_schedule(model) + + from slapo.sharding.reshard_ops import reshard_RR_to_SR + + def reshard_and_add(dropout, hidden_states): + """Replace the add operator with reshard_and_add""" + reshard_hidden_states = reshard_RR_to_SR(hidden_states, sch.group) + return dropout + reshard_hidden_states + + with slapo.Verify(sch, [input_ids]): + for i in range(config.num_hidden_layers): + # shard attention + subsch = sch[f"bert.encoder.layer.{i}.attention.self"] + subsch["query"].shard("weight", axis=0) + subsch["query"].shard("bias", axis=0) + subsch["key"].shard("weight", axis=0) + subsch["key"].shard("bias", axis=0) + subsch["value"].shard("weight", axis=0) + subsch["value"].shard("bias", axis=0) + fix_attention_mask_shape_megatron(subsch) + subsch = sch[f"bert.encoder.layer.{i}.attention.output"] + + subsch.trace(recursive=False, flatten=False, tracer="pytorch") + ops = subsch.find_node( + lambda node: node.op == "call_function" and node.target == operator.add + ) + subsch.replace(reshard_and_add, ops[0]) + + # shape here: RS + subsch["dense"].sync( + "fwd_pre", sync_op_or_fn="RS->SR" + ) # LayerNorm will crash for SR x RR = 
SR + # shard MLP + subsch = sch[f"bert.encoder.layer.{i}"] + subsch["output.LayerNorm"].sync("fwd_post", sync_op_or_fn="SR->RR") + + return sch + + +def test_schemes(init_dist): + torch.cuda.set_device(dist.get_rank()) + device = torch.cuda.current_device() + + config = AutoConfig.from_pretrained("bert-large-uncased") + with slapo.init_empty_weights(): + model = BertLMHeadModel(config) + + schs = [] + input_ids = torch.ones(bs, seq_len, dtype=torch.long, device=device) + # 1. Slapo-Megatron + # RR x RS = RS, RS x SR = RR + schs.append(scheme_megatron(copy.deepcopy(model), input_ids, config)) + # 2. Weight-Stationary + # RR->RS x RR = RS, RS x RR = RS->RR + schs.append(scheme_weight_stationary(copy.deepcopy(model), input_ids, config)) + # 3. Activation-Stationary + # RR x RS = RS + schs.append(scheme_activation_stationary(copy.deepcopy(model), input_ids, config)) + # 4. Activation Sharding. SR x RR = SR + schs.append(scheme_activation_sharding(copy.deepcopy(model), input_ids, config)) + return schs + + +if __name__ == "__main__": + # Create parser + parser = argparse.ArgumentParser(description="Resharding schemes on BERT") + # Add arguments + parser.add_argument("--bs", type=int, help="Batch size", default=1) + parser.add_argument("--seq", type=int, help="Sequence length", default=512) + # Parse the arguments + args = parser.parse_args() + + bs = args.bs + seq_len = args.seq + + dist.init_process_group("nccl", world_size=int(os.environ["WORLD_SIZE"])) + + logger.info( + "Number of GPUs: %d, bs=%d, seq_len=%d; Model: BERT-large", + dist.get_world_size(), + bs, + seq_len, + ranks=0, + ) + + schs = test_schemes(None) + + input_ids = torch.ones( + bs, seq_len, dtype=torch.long, device=f"cuda:{dist.get_rank()}" + ) + for i, sch in enumerate(schs): + mod, _ = slapo.build(sch, init_weights=sch.mod._init_weights) + mod.to(f"cuda:{dist.get_rank()}") + perf_model(mod, input_ids) + del mod + torch.cuda.empty_cache() From 59e0b9bfceed677ced709f79aa9e12dde542ee9c Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Sun, 28 May 2023 19:32:54 +0000 Subject: [PATCH 23/36] Fix viewop --- slapo/sharding/solver.py | 46 +++++++++++++++++++++++++++++++++++----- tests/test_autoshard.py | 2 +- 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 04bf3bc1..09a71efe 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -155,8 +155,40 @@ def calculate_comm_cost_z3(self): class ViewOp(FxOp): + """ # TODO: verify the behavior of general view function - # (bs,seq,d) -> (bs,seq,h,d//h) + Only certain view functions can be sharded without communication. + Currently only reshaping the *last* dimension is supported. + + Consider the view function in Transformer: + (bs,seq,d) -> (bs,seq,h,d//h) + + We have the following communication matrix: + (The src spec only considers the last dimension) + src\dst RR RS SR + R 0 0 0 + S 1/p 1/p 0 + + S->RR requires an all-gather to retrieve all the data. + S->RS also requires an all-gather before the view function to ensure the data + is correct. To illustrate this case, consider h=2, p=2, d=8, d//h=4, and + the original data is [0 1 2 3 4 5 6 7], and the expected result is + [[0 1 2 3], [4 5 6 7]]. 
If we want to shard it into RS spec on two devices, + the data should be as follows: + Device 1 | Device 2 + 0 1 | 2 3 + 4 5 | 6 7 + But if we directly view from the source sharded spec shown below, + Device 1 | Device 2 + 0 1 2 3 | 4 5 6 7 + and reshape it to (h,d//h//p), we get + Device 1 | Device 2 + 0 1 | 4 5 + 2 3 | 6 7 + Thus, the data is incorrect. + To avoid this, we need to all-gather the data first, and then reshape it. + """ + def __init__(self, node, z3_graph, p): super().__init__(node) self.z3_graph = z3_graph @@ -172,16 +204,20 @@ def generate_output_z3(self): return self.z3_inputs[0] def calculate_comm_cost_z3(self): - # `view` can redistribute the dimensions, thus can be used to - # convert to most of the specs without communication, - # but be careful about RS->RR, which requires an all-gather result = z3.If( z3.And( self.prev_op.generate_output_z3() == ShardSpec("RS").id, self.z3_inputs[0] == ShardSpec("RR").id, ), self.out_size * 1 / self.num_devices, - 0, + z3.If( + z3.And( + self.prev_op.generate_output_z3() == ShardSpec("RS").id, + self.z3_inputs[0] == ShardSpec("RS").id, + ), + self.out_size * 1 / self.num_devices, + 0, + ), ) return result diff --git a/tests/test_autoshard.py b/tests/test_autoshard.py index 30cf113a..8fb24f50 100644 --- a/tests/test_autoshard.py +++ b/tests/test_autoshard.py @@ -82,7 +82,7 @@ def test_attn(): sol = Solver(subsch.mod, p=p) _, max_cost = sol.solve([torch.randn(bs, seq_len, hidden_size)]) - assert max_cost == 3 * (bs * seq_len * hidden_size / p) + 2 + assert max_cost == 3 * (bs * seq_len * hidden_size / p) + 4 if __name__ == "__main__": From e8d126ac37ad6777b60cc43b7cf4ec36fd0f47f5 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Sun, 28 May 2023 19:50:17 +0000 Subject: [PATCH 24/36] Add seq_par --- tests/autoshard/test_bert.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/tests/autoshard/test_bert.py b/tests/autoshard/test_bert.py index d1b03cb9..68231066 100644 --- a/tests/autoshard/test_bert.py +++ b/tests/autoshard/test_bert.py @@ -94,7 +94,7 @@ def scheme_megatron(model, input_ids, config): return sch -def scheme_weight_stationary(model, input_ids, config): +def scheme_sequence_parallel(model, input_ids, config): sch = slapo.create_schedule(model) from slapo.sharding.reshard_ops import ( @@ -103,36 +103,25 @@ def scheme_weight_stationary(model, input_ids, config): reshard_RS_to_RR, ) - def fix_attention_mask_shape(sch): - ops = trace_and_find_view(sch) - - def new_view_kv(tensor, args): - return tensor.view(args[0], args[1], args[2], args[3] // sch.world_size) - - sch.replace(new_view_kv, ops[0]) # key - sch.replace(new_view_kv, ops[1]) # value - def reshard_and_add(dropout, hidden_states): """Replace the add operator with reshard_and_add""" reshard_hidden_states = reshard_RR_to_SR(hidden_states, sch.group) return dropout + reshard_hidden_states def new_matmul(lhs, rhs): - return torch.matmul(lhs, reshard_SR_to_RR(rhs, sch.group)) + return torch.matmul(lhs, reshard_RS_to_RR(rhs, sch.group)) def new_matmul_1(lhs, rhs): - return torch.matmul(lhs, reshard_RS_to_RR(rhs, sch.group)) + return torch.matmul(lhs, reshard_SR_to_RR(rhs, sch.group)) - with slapo.Verify(sch, [input_ids], enable=False): + with slapo.Verify(sch, [input_ids], eval_mode=True, enable=True): for i in range(config.num_hidden_layers): # attention subsch = sch[f"bert.encoder.layer.{i}.attention.self"] - subsch["key"].shard("weight", axis=0) - subsch["key"].shard("bias", axis=0) - subsch["value"].shard("weight", axis=0) 
- subsch["value"].shard("bias", axis=0) subsch["query"].sync(mode="fwd_pre", sync_op_or_fn="RR->SR") - fix_attention_mask_shape(subsch) + subsch["key"].sync(mode="fwd_pre", sync_op_or_fn="RR->SR") + subsch["value"].sync(mode="fwd_pre", sync_op_or_fn="RR->SR") + trace_and_find_view(subsch) ops = subsch.find_node( lambda node: node.op == "call_function" and node.target == torch.matmul ) @@ -145,6 +134,7 @@ def new_matmul_1(lhs, rhs): ops = subsch.find_node( lambda node: node.op == "call_function" and node.target == operator.add ) + assert len(ops) == 1 subsch.replace(reshard_and_add, ops[0]) # MLP sch[f"bert.encoder.layer.{i}.output.LayerNorm"].sync( @@ -238,9 +228,9 @@ def test_schemes(init_dist): # 1. Slapo-Megatron # RR x RS = RS, RS x SR = RR schs.append(scheme_megatron(copy.deepcopy(model), input_ids, config)) - # 2. Weight-Stationary + # 2. Sequence-Parallel # RR->RS x RR = RS, RS x RR = RS->RR - schs.append(scheme_weight_stationary(copy.deepcopy(model), input_ids, config)) + schs.append(scheme_sequence_parallel(copy.deepcopy(model), input_ids, config)) # 3. Activation-Stationary # RR x RS = RS schs.append(scheme_activation_stationary(copy.deepcopy(model), input_ids, config)) From 6324524e6396d04e216de09931a871a3f10c39a6 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Sun, 28 May 2023 20:16:44 +0000 Subject: [PATCH 25/36] Efficient seq_par --- tests/autoshard/test_bert.py | 37 ++++++++++-------------------------- 1 file changed, 10 insertions(+), 27 deletions(-) diff --git a/tests/autoshard/test_bert.py b/tests/autoshard/test_bert.py index 68231066..810bf72e 100644 --- a/tests/autoshard/test_bert.py +++ b/tests/autoshard/test_bert.py @@ -24,19 +24,20 @@ def perf_model(mod, input_tensor): """Measure the performance of a mod with certain resharding schemes""" # warmup - for _ in range(5): + for _ in range(10): mod(input_tensor) start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() - for _ in range(10): + iters = 40 + for _ in range(iters): mod(input_tensor) end_event.record() torch.cuda.synchronize() if dist.get_rank() == 0: - print(f"{start_event.elapsed_time(end_event) / 10:.3f} ms") + print(f"{start_event.elapsed_time(end_event) / iters:.3f} ms") def trace_and_find_view(sch): @@ -98,16 +99,10 @@ def scheme_sequence_parallel(model, input_ids, config): sch = slapo.create_schedule(model) from slapo.sharding.reshard_ops import ( - reshard_RR_to_SR, reshard_SR_to_RR, reshard_RS_to_RR, ) - def reshard_and_add(dropout, hidden_states): - """Replace the add operator with reshard_and_add""" - reshard_hidden_states = reshard_RR_to_SR(hidden_states, sch.group) - return dropout + reshard_hidden_states - def new_matmul(lhs, rhs): return torch.matmul(lhs, reshard_RS_to_RR(rhs, sch.group)) @@ -115,12 +110,9 @@ def new_matmul_1(lhs, rhs): return torch.matmul(lhs, reshard_SR_to_RR(rhs, sch.group)) with slapo.Verify(sch, [input_ids], eval_mode=True, enable=True): + sch["bert.embeddings.LayerNorm"].sync(mode="fwd_post", sync_op_or_fn="RR->SR") for i in range(config.num_hidden_layers): - # attention subsch = sch[f"bert.encoder.layer.{i}.attention.self"] - subsch["query"].sync(mode="fwd_pre", sync_op_or_fn="RR->SR") - subsch["key"].sync(mode="fwd_pre", sync_op_or_fn="RR->SR") - subsch["value"].sync(mode="fwd_pre", sync_op_or_fn="RR->SR") trace_and_find_view(subsch) ops = subsch.find_node( lambda node: node.op == "call_function" and node.target == torch.matmul @@ -128,18 +120,9 @@ def new_matmul_1(lhs, rhs): assert len(ops) == 2 
subsch.replace(new_matmul, ops[0]) subsch.replace(new_matmul_1, ops[1]) - # residual add - subsch = sch[f"bert.encoder.layer.{i}.attention.output"] - subsch.trace(recursive=False, flatten=False, tracer="pytorch") - ops = subsch.find_node( - lambda node: node.op == "call_function" and node.target == operator.add - ) - assert len(ops) == 1 - subsch.replace(reshard_and_add, ops[0]) - # MLP - sch[f"bert.encoder.layer.{i}.output.LayerNorm"].sync( - mode="fwd_post", sync_op_or_fn="SR->RR" - ) + sch[f"bert.encoder.layer.{config.num_hidden_layers - 1}.output.LayerNorm"].sync( + mode="fwd_post", sync_op_or_fn="SR->RR" + ) return sch @@ -243,7 +226,7 @@ def test_schemes(init_dist): # Create parser parser = argparse.ArgumentParser(description="Resharding schemes on BERT") # Add arguments - parser.add_argument("--bs", type=int, help="Batch size", default=1) + parser.add_argument("--bs", type=int, help="Batch size", default=8) parser.add_argument("--seq", type=int, help="Sequence length", default=512) # Parse the arguments args = parser.parse_args() @@ -269,6 +252,6 @@ def test_schemes(init_dist): for i, sch in enumerate(schs): mod, _ = slapo.build(sch, init_weights=sch.mod._init_weights) mod.to(f"cuda:{dist.get_rank()}") + torch.cuda.empty_cache() perf_model(mod, input_ids) del mod - torch.cuda.empty_cache() From 14483c1626871b9ff728e6a6c8e3ed2922df7dfc Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Sun, 28 May 2023 21:23:47 +0000 Subject: [PATCH 26/36] Add comments --- slapo/sharding/solver.py | 7 +++++++ tests/test_autoshard.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 09a71efe..546bf5e6 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -197,6 +197,13 @@ def __init__(self, node, z3_graph, p): def generate_input_z3(self): self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2)) + # TODO: Need to consider when we support higher dimensions + # compute_constraints = [ + # z3.Implies( + # self.prev_op.generate_output_z3() == ShardSpec("SR").id, + # self.z3_inputs[0] == ShardSpec("RR").id, + # ), + # ] format_constraints = [z3.ULE(self.z3_inputs[0], 3)] return self.z3_inputs, format_constraints diff --git a/tests/test_autoshard.py b/tests/test_autoshard.py index 8fb24f50..c355f6ea 100644 --- a/tests/test_autoshard.py +++ b/tests/test_autoshard.py @@ -19,7 +19,7 @@ # Config for verification p = 8 bs = 8 -seq_len = 1024 +seq_len = 512 hidden_size = 1024 From b61f536fd648d61e80d7f20bf3ae1f74621c2b3b Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 29 May 2023 03:39:25 +0000 Subject: [PATCH 27/36] Move file --- tests/{ => autoshard}/test_autoshard.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{ => autoshard}/test_autoshard.py (100%) diff --git a/tests/test_autoshard.py b/tests/autoshard/test_autoshard.py similarity index 100% rename from tests/test_autoshard.py rename to tests/autoshard/test_autoshard.py From ff024a1cb6b4bb8fe5d82a898bd927e3ed2c1800 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 29 May 2023 03:53:09 +0000 Subject: [PATCH 28/36] Fix device --- slapo/sharding/solver.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 546bf5e6..9c0dedcd 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -375,7 +375,8 @@ def inference_shape(self, inputs): sp = ShapeProp(self.gm) # Tackle the case of meta device device = next(self.gm.named_parameters())[1].device - inputs = [inp.to(device) 
for inp in inputs] + inputs = [inp.to("meta") for inp in inputs] + self.gm = self.gm.to(device) sp.propagate(*inputs) def dump_fx_node(self): From 9a55da02b180fa1d021b8d592bf229c415e3390f Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 29 May 2023 03:58:32 +0000 Subject: [PATCH 29/36] Fix TensorMetadata --- slapo/sharding/solver.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 9c0dedcd..d7f753e5 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -10,7 +10,7 @@ from torch import nn from torch import fx import torch.nn.functional as F -from torch.fx.passes.shape_prop import ShapeProp +from torch.fx.passes.shape_prop import ShapeProp, TensorMetadata import z3 from tabulate import tabulate @@ -392,7 +392,7 @@ def dump_fx_node(self): target = type(self.named_modules[node.target]) else: target = node.target - if isinstance(data, tuple): + if not isinstance(data, TensorMetadata): continue res.append( [node.name, node.op, target, list(data.shape), data.dtype] From 83c60534e07b6c802ea2918324330da90e0f2aec Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 29 May 2023 05:11:47 +0000 Subject: [PATCH 30/36] Support more ops --- slapo/sharding/solver.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index d7f753e5..39303458 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -349,8 +349,14 @@ def calculate_comm_cost_z3(self): torch.matmul: MatmulOp, F.relu: ElementwiseOp, F.gelu: ElementwiseOp, + torch.tensor: PlaceholderOp, + torch.where: ElementwiseOp, + torch.pow: ElementwiseOp, + torch.tanh: ElementwiseOp, operator.truediv: ElementwiseOp, + operator.getitem: ElementwiseOp, operator.add: BinaryOp, + operator.mul: BinaryOp, } @@ -456,16 +462,18 @@ def construct_z3_graph(self): new_op = PermuteOp(node, self.z3_graph) elif node.target == "transpose": new_op = TransposeOp(node, self.z3_graph) - elif node.target == "contiguous": + elif node.target in ["contiguous", "to"]: new_op = ElementwiseOp(node) else: raise RuntimeError(f"Unsupported method: {node.target}") + elif node.op == "get_attr": # extra buffers + new_op = PlaceholderOp(node) else: # output continue # construct edges if not (node.op == "call_method" and node.target == "view"): for arg in node.args: - if not isinstance(arg, fx.Node): + if not isinstance(arg, fx.Node) or arg.name not in self.z3_graph: continue new_op.add_arg(self.z3_graph[arg.name]) self.z3_graph[arg.name].add_user(new_op) From 4bce113dd5dc85fa2974c9e54785508fa6e1d1d1 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Mon, 29 May 2023 20:43:45 +0000 Subject: [PATCH 31/36] Add gpt support --- tests/autoshard/test_gpt.py | 273 ++++++++++++++++++++++++++++++++++++ 1 file changed, 273 insertions(+) create mode 100644 tests/autoshard/test_gpt.py diff --git a/tests/autoshard/test_gpt.py b/tests/autoshard/test_gpt.py new file mode 100644 index 00000000..13607931 --- /dev/null +++ b/tests/autoshard/test_gpt.py @@ -0,0 +1,273 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import os +import copy +import inspect +import operator +import argparse + +import torch +from torch import fx +import torch.distributed as dist +from transformers import GPTNeoModel, AutoConfig + +import slapo +from slapo.logger import get_logger + +logger = get_logger(__name__) + +# Config for verification +bs = 8 +seq_len = 1024 + + +def perf_model(mod, input_tensor): + """Measure the performance of a mod with certain resharding schemes""" + # warmup + for _ in range(10): + mod(input_tensor) + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + iters = 40 + for _ in range(iters): + mod(input_tensor) + end_event.record() + torch.cuda.synchronize() + if dist.get_rank() == 0: + print(f"{start_event.elapsed_time(end_event) / iters:.3f} ms") + + +def trace_and_find_view(sch, config): + input_names = ["hidden_states"] + sig = inspect.signature(sch.mod.forward) + concrete_args = { + p.name: p.default for p in sig.parameters.values() if p.name not in input_names + } + sch.trace( + recursive=False, + flatten=True, + tracer="huggingface", + concrete_args=concrete_args, + config=config, + ) + ops = sch.find_node( + lambda node: node.op == "call_method" + and node.target == "view" + and ( + (node.args[0].op == "call_module" and "proj" in node.args[0].target) + or ( + len(node.args) > 1 + and isinstance(node.args[1], fx.Node) + and node.args[1].op == "call_function" + and node.args[1].target == operator.add + ) + ) + ) + assert len(ops) == 4 # q,k,v,context_layer + return ops + + +def fix_attention_mask_shape_megatron(sch, config): + ops = trace_and_find_view(sch, config) + + def new_view(tensor, args): + if len(args) == 4: # q,k,v + out = tensor.view(args[0], args[1], args[2] // sch.world_size, args[3]) + else: # context_layer + out = tensor.view(args[0], args[1], args[2] // sch.world_size) + return out + + for op in ops: + sch.replace(new_view, op) + + +def scheme_megatron(model, input_ids, config): + sch = slapo.create_schedule(model) + + with slapo.Verify(sch, [input_ids], enable=True): + for i in range(config.num_hidden_layers): + # shard attention + subsch = sch[f"h.{i}.attn.attention"] + # no bias for GPTNeo + subsch["q_proj"].shard("weight", axis=0) + subsch["k_proj"].shard("weight", axis=0) + subsch["v_proj"].shard("weight", axis=0) + subsch["out_proj"].shard("weight", axis=1) + subsch["out_proj"].sync("fwd_post", sync_op_or_fn="all_reduce") + fix_attention_mask_shape_megatron(subsch, config) + # shard MLP + subsch = sch[f"h.{i}.mlp"] + subsch["c_fc"].shard("weight", axis=0) + subsch["c_fc"].shard("bias", axis=0) + subsch["c_proj"].shard("weight", axis=1) + subsch["c_proj"].sync("fwd_post", sync_op_or_fn="all_reduce") + + return sch + + +def scheme_sequence_parallel(model, input_ids, config): + sch = slapo.create_schedule(model) + + from slapo.sharding.reshard_ops import ( + reshard_SR_to_RR, + reshard_RS_to_RR, + ) + + def new_matmul(lhs, rhs): + return torch.matmul(lhs, reshard_RS_to_RR(rhs, sch.group)) + + def new_matmul_1(lhs, rhs): + return torch.matmul(lhs, reshard_SR_to_RR(rhs, sch.group)) + + with slapo.Verify(sch, [input_ids], eval_mode=True, enable=True): + sch["bert.embeddings.LayerNorm"].sync(mode="fwd_post", sync_op_or_fn="RR->SR") + for i in range(config.num_hidden_layers): + subsch = sch[f"bert.encoder.layer.{i}.attention.self"] + trace_and_find_view(subsch) + ops = subsch.find_node( + lambda node: node.op == "call_function" and node.target == torch.matmul 
+ ) + assert len(ops) == 2 + subsch.replace(new_matmul, ops[0]) + subsch.replace(new_matmul_1, ops[1]) + sch[f"bert.encoder.layer.{config.num_hidden_layers - 1}.output.LayerNorm"].sync( + mode="fwd_post", sync_op_or_fn="SR->RR" + ) + + return sch + + +def scheme_activation_stationary(model, input_ids, config): + sch = slapo.create_schedule(model) + with slapo.Verify(sch, [input_ids]): + for i in range(config.num_hidden_layers): + # shard attention + subsch = sch[f"bert.encoder.layer.{i}.attention.self"] + subsch["query"].shard("weight", axis=0) + subsch["query"].shard("bias", axis=0) + subsch["key"].shard("weight", axis=0) + subsch["key"].shard("bias", axis=0) + subsch["value"].shard("weight", axis=0) + subsch["value"].shard("bias", axis=0) + fix_attention_mask_shape_megatron(subsch) + subsch = sch[f"bert.encoder.layer.{i}.attention.output"] + # shape here: [4096, 256](RS). Need to matmul with [1024, 1024] (without shard) + subsch["dense"].sync("fwd_pre", sync_op_or_fn="RS->RR") + subsch["dense"].shard("weight", axis=0) + subsch["dense"].shard("bias", axis=0) + subsch["dense"].sync("fwd_post", sync_op_or_fn="RS->RR") + # shard MLP + subsch = sch[f"bert.encoder.layer.{i}"] + subsch["intermediate.dense"].shard("weight", axis=0) + subsch["intermediate.dense"].shard("bias", axis=0) + subsch["intermediate.dense"].sync("fwd_post", sync_op_or_fn="RS->RR") + subsch["output.dense"].shard("weight", axis=0) + subsch["output.dense"].shard("bias", axis=0) + subsch["output.dense"].sync("fwd_post", sync_op_or_fn="RS->RR") + + return sch + + +def scheme_activation_sharding(model, input_ids, config): + sch = slapo.create_schedule(model) + + from slapo.sharding.reshard_ops import reshard_RR_to_SR + + def reshard_and_add(dropout, hidden_states): + """Replace the add operator with reshard_and_add""" + reshard_hidden_states = reshard_RR_to_SR(hidden_states, sch.group) + return dropout + reshard_hidden_states + + with slapo.Verify(sch, [input_ids]): + for i in range(config.num_hidden_layers): + # shard attention + subsch = sch[f"bert.encoder.layer.{i}.attention.self"] + subsch["query"].shard("weight", axis=0) + subsch["query"].shard("bias", axis=0) + subsch["key"].shard("weight", axis=0) + subsch["key"].shard("bias", axis=0) + subsch["value"].shard("weight", axis=0) + subsch["value"].shard("bias", axis=0) + fix_attention_mask_shape_megatron(subsch) + subsch = sch[f"bert.encoder.layer.{i}.attention.output"] + + subsch.trace(recursive=False, flatten=False, tracer="pytorch") + ops = subsch.find_node( + lambda node: node.op == "call_function" and node.target == operator.add + ) + subsch.replace(reshard_and_add, ops[0]) + + # shape here: RS + subsch["dense"].sync( + "fwd_pre", sync_op_or_fn="RS->SR" + ) # LayerNorm will crash for SR x RR = SR + # shard MLP + subsch = sch[f"bert.encoder.layer.{i}"] + subsch["output.LayerNorm"].sync("fwd_post", sync_op_or_fn="SR->RR") + + return sch + + +def test_schemes(init_dist): + torch.cuda.set_device(dist.get_rank()) + device = torch.cuda.current_device() + + config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-1.3B") + config.use_cache = False + with slapo.init_empty_weights(): + model = GPTNeoModel(config) + + schs = [] + input_ids = torch.ones(bs, seq_len, dtype=torch.long, device=device) + # 1. Slapo-Megatron + # RR x RS = RS, RS x SR = RR + schs.append(scheme_megatron(copy.deepcopy(model), input_ids, config)) + sys.exit() + # 2. Sequence-Parallel + # RR->RS x RR = RS, RS x RR = RS->RR + schs.append(scheme_sequence_parallel(copy.deepcopy(model), input_ids, config)) + # 3. 
Activation-Stationary + # RR x RS = RS + schs.append(scheme_activation_stationary(copy.deepcopy(model), input_ids, config)) + # 4. Activation Sharding. SR x RR = SR + schs.append(scheme_activation_sharding(copy.deepcopy(model), input_ids, config)) + return schs + + +if __name__ == "__main__": + # Create parser + parser = argparse.ArgumentParser(description="Resharding schemes on GPTNeo") + # Add arguments + parser.add_argument("--bs", type=int, help="Batch size", default=2) + parser.add_argument("--seq", type=int, help="Sequence length", default=1024) + # Parse the arguments + args = parser.parse_args() + + bs = args.bs + seq_len = args.seq + + dist.init_process_group("nccl", world_size=int(os.environ["WORLD_SIZE"])) + + logger.info( + "Number of GPUs: %d, bs=%d, seq_len=%d; Model: GPTNeo", + dist.get_world_size(), + bs, + seq_len, + ranks=0, + ) + + schs = test_schemes(None) + + input_ids = torch.ones( + bs, seq_len, dtype=torch.long, device=f"cuda:{dist.get_rank()}" + ) + for i, sch in enumerate(schs): + mod, _ = slapo.build(sch, init_weights=sch.mod._init_weights) + mod.to(f"cuda:{dist.get_rank()}") + torch.cuda.empty_cache() + perf_model(mod, input_ids) + del mod From f93309896724747ca8cae2c1a255f41c9138568f Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Tue, 30 May 2023 02:55:55 +0000 Subject: [PATCH 32/36] Support seq par for GPT --- tests/autoshard/test_gpt.py | 49 +++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/tests/autoshard/test_gpt.py b/tests/autoshard/test_gpt.py index 13607931..049905d3 100644 --- a/tests/autoshard/test_gpt.py +++ b/tests/autoshard/test_gpt.py @@ -8,6 +8,7 @@ import argparse import torch +from torch import nn from torch import fx import torch.distributed as dist from transformers import GPTNeoModel, AutoConfig @@ -88,7 +89,8 @@ def new_view(tensor, args): def scheme_megatron(model, input_ids, config): sch = slapo.create_schedule(model) - with slapo.Verify(sch, [input_ids], enable=True): + enable = True if input_ids.shape[0] == 1 else False + with slapo.Verify(sch, [input_ids], enable=enable): for i in range(config.num_hidden_layers): # shard attention subsch = sch[f"h.{i}.attn.attention"] @@ -123,20 +125,42 @@ def new_matmul(lhs, rhs): def new_matmul_1(lhs, rhs): return torch.matmul(lhs, reshard_SR_to_RR(rhs, sch.group)) - with slapo.Verify(sch, [input_ids], eval_mode=True, enable=True): - sch["bert.embeddings.LayerNorm"].sync(mode="fwd_post", sync_op_or_fn="RR->SR") + class NewMask(nn.Module): + def forward(self, query, key, bias): + query_length, key_length = ( + query.size(-2) * sch.world_size, + key.size(-2) * sch.world_size, + ) + size_per_chunk = query_length // sch.world_size + start_idx = key_length - query_length + size_per_chunk * sch.rank + end_idx = start_idx + size_per_chunk + causal_mask = bias[:, :, start_idx:end_idx, :key_length] + return causal_mask + + enable = True if input_ids.shape[0] == 1 else False + with slapo.Verify(sch, [input_ids], eval_mode=True, enable=enable): + sch["drop"].sync(mode="fwd_post", sync_op_or_fn="RR->SR") for i in range(config.num_hidden_layers): - subsch = sch[f"bert.encoder.layer.{i}.attention.self"] - trace_and_find_view(subsch) + subsch = sch[f"h.{i}.attn.attention"] + trace_and_find_view(subsch, config) ops = subsch.find_node( lambda node: node.op == "call_function" and node.target == torch.matmul ) assert len(ops) == 2 subsch.replace(new_matmul, ops[0]) subsch.replace(new_matmul_1, ops[1]) - sch[f"bert.encoder.layer.{config.num_hidden_layers - 
1}.output.LayerNorm"].sync( - mode="fwd_post", sync_op_or_fn="SR->RR" - ) + + # Need to shard the tril matrix (causal mask) + def pattern(query, key, bias): + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = bias[ + :, :, key_length - query_length : key_length, :key_length + ] + return causal_mask + + ops = subsch.find(pattern) + subsch.replace(NewMask(), target_ops=[ops[-1]]) + sch[f"ln_f"].sync(mode="fwd_post", sync_op_or_fn="SR->RR") return sch @@ -226,15 +250,14 @@ def test_schemes(init_dist): # 1. Slapo-Megatron # RR x RS = RS, RS x SR = RR schs.append(scheme_megatron(copy.deepcopy(model), input_ids, config)) - sys.exit() # 2. Sequence-Parallel # RR->RS x RR = RS, RS x RR = RS->RR schs.append(scheme_sequence_parallel(copy.deepcopy(model), input_ids, config)) # 3. Activation-Stationary # RR x RS = RS - schs.append(scheme_activation_stationary(copy.deepcopy(model), input_ids, config)) - # 4. Activation Sharding. SR x RR = SR - schs.append(scheme_activation_sharding(copy.deepcopy(model), input_ids, config)) + # schs.append(scheme_activation_stationary(copy.deepcopy(model), input_ids, config)) + # # 4. Activation Sharding. SR x RR = SR + # schs.append(scheme_activation_sharding(copy.deepcopy(model), input_ids, config)) return schs @@ -242,7 +265,7 @@ def test_schemes(init_dist): # Create parser parser = argparse.ArgumentParser(description="Resharding schemes on GPTNeo") # Add arguments - parser.add_argument("--bs", type=int, help="Batch size", default=2) + parser.add_argument("--bs", type=int, help="Batch size", default=4) parser.add_argument("--seq", type=int, help="Sequence length", default=1024) # Parse the arguments args = parser.parse_args() From 5a66148e17a9b6aa4134e78585b3966b345db7f8 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Tue, 30 May 2023 02:56:28 +0000 Subject: [PATCH 33/36] Add TODO --- slapo/sharding/solver.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 39303458..c7d04d01 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -156,7 +156,8 @@ def calculate_comm_cost_z3(self): class ViewOp(FxOp): """ - # TODO: verify the behavior of general view function + # TODO: 1. Verify the behavior of general view function + # 2. Support merging two dimensions Only certain view functions can be sharded without communication. Currently only reshaping the *last* dimension is supported. 
@@ -350,6 +351,7 @@ def calculate_comm_cost_z3(self): F.relu: ElementwiseOp, F.gelu: ElementwiseOp, torch.tensor: PlaceholderOp, + # FIXME: three operands, need to ensure specs are the same torch.where: ElementwiseOp, torch.pow: ElementwiseOp, torch.tanh: ElementwiseOp, From 5ccac8b3e6906fbc928e481792fa5ec07e2b4a55 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Tue, 30 May 2023 03:22:21 +0000 Subject: [PATCH 34/36] Add gpt_attn unit test --- tests/autoshard/test_autoshard.py | 43 ++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/tests/autoshard/test_autoshard.py b/tests/autoshard/test_autoshard.py index c355f6ea..3a4b9733 100644 --- a/tests/autoshard/test_autoshard.py +++ b/tests/autoshard/test_autoshard.py @@ -55,7 +55,7 @@ def test_mlp(): assert max_cost == (bs * seq_len * hidden_size / p + 1) -def test_attn(): +def test_bert_attn(): from transformers import BertLMHeadModel, AutoConfig import inspect @@ -80,11 +80,46 @@ def test_attn(): ) logger.info(subsch.mod.graph, ranks=0) + seq_len = 512 sol = Solver(subsch.mod, p=p) - _, max_cost = sol.solve([torch.randn(bs, seq_len, hidden_size)]) - assert max_cost == 3 * (bs * seq_len * hidden_size / p) + 4 + _, max_cost = sol.solve([torch.randn(bs, seq_len, config.hidden_size)]) + assert max_cost == 3 * (bs * seq_len * config.hidden_size / p) + 4 + + +def test_gpt_attn(): + from transformers import GPTNeoModel, AutoConfig + import inspect + + config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-1.3B") + # config.use_cache = False + with slapo.init_empty_weights(): + model = GPTNeoModel(config) + logger.info(config, ranks=0) + + sch = slapo.create_schedule(model) + input_names = ["hidden_states"] + i = 0 + subsch = sch[f"h.{i}"] + sig = inspect.signature(subsch.mod.forward) + concrete_args = { + p.name: p.default for p in sig.parameters.values() if p.name not in input_names + } + subsch.trace( + recursive=False, + flatten=True, + tracer="huggingface", + concrete_args=concrete_args, + config=config, + ) + logger.info(subsch.mod.graph, ranks=0) + + seq_len = 1024 + sol = Solver(subsch.mod, p=p) + _, max_cost = sol.solve([torch.randn(bs, seq_len, config.hidden_size)]) + assert max_cost == 3 * (bs * seq_len * config.hidden_size // p) + 3 if __name__ == "__main__": test_mlp() - test_attn() + test_bert_attn() + test_gpt_attn() From 8d92d72bc5db74d628924fa8c94789b458957141 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Tue, 30 May 2023 03:22:36 +0000 Subject: [PATCH 35/36] Rename --- tests/autoshard/{test_autoshard.py => test_solver.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/autoshard/{test_autoshard.py => test_solver.py} (100%) diff --git a/tests/autoshard/test_autoshard.py b/tests/autoshard/test_solver.py similarity index 100% rename from tests/autoshard/test_autoshard.py rename to tests/autoshard/test_solver.py From 1f165a24c5f9f42070fce3ef2db988d395e2608e Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Tue, 30 May 2023 05:56:31 +0000 Subject: [PATCH 36/36] Update bs --- tests/autoshard/test_bert.py | 15 ++++++++++----- tests/autoshard/test_gpt.py | 7 ++++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/autoshard/test_bert.py b/tests/autoshard/test_bert.py index 810bf72e..7aafd81a 100644 --- a/tests/autoshard/test_bert.py +++ b/tests/autoshard/test_bert.py @@ -17,13 +17,14 @@ logger = get_logger(__name__) # Config for verification -bs = 8 +bs = 4 seq_len = 512 def perf_model(mod, input_tensor): """Measure the performance of a mod with certain resharding 
schemes""" # warmup + mod.eval() for _ in range(10): mod(input_tensor) @@ -71,7 +72,8 @@ def new_view(tensor, args): def scheme_megatron(model, input_ids, config): sch = slapo.create_schedule(model) - with slapo.Verify(sch, [input_ids]): + enable = True if input_ids.shape[0] <= 4 else False + with slapo.Verify(sch, [input_ids], enable=enable): for i in range(config.num_hidden_layers): # shard attention subsch = sch[f"bert.encoder.layer.{i}.attention.self"] @@ -109,7 +111,8 @@ def new_matmul(lhs, rhs): def new_matmul_1(lhs, rhs): return torch.matmul(lhs, reshard_SR_to_RR(rhs, sch.group)) - with slapo.Verify(sch, [input_ids], eval_mode=True, enable=True): + enable = True if input_ids.shape[0] <= 4 else False + with slapo.Verify(sch, [input_ids], enable=enable): sch["bert.embeddings.LayerNorm"].sync(mode="fwd_post", sync_op_or_fn="RR->SR") for i in range(config.num_hidden_layers): subsch = sch[f"bert.encoder.layer.{i}.attention.self"] @@ -129,7 +132,8 @@ def new_matmul_1(lhs, rhs): def scheme_activation_stationary(model, input_ids, config): sch = slapo.create_schedule(model) - with slapo.Verify(sch, [input_ids]): + enable = True if input_ids.shape[0] <= 4 else False + with slapo.Verify(sch, [input_ids], enable=enable): for i in range(config.num_hidden_layers): # shard attention subsch = sch[f"bert.encoder.layer.{i}.attention.self"] @@ -168,7 +172,8 @@ def reshard_and_add(dropout, hidden_states): reshard_hidden_states = reshard_RR_to_SR(hidden_states, sch.group) return dropout + reshard_hidden_states - with slapo.Verify(sch, [input_ids]): + enable = True if input_ids.shape[0] <= 4 else False + with slapo.Verify(sch, [input_ids], enable=enable): for i in range(config.num_hidden_layers): # shard attention subsch = sch[f"bert.encoder.layer.{i}.attention.self"] diff --git a/tests/autoshard/test_gpt.py b/tests/autoshard/test_gpt.py index 049905d3..c3f8dce2 100644 --- a/tests/autoshard/test_gpt.py +++ b/tests/autoshard/test_gpt.py @@ -19,13 +19,15 @@ logger = get_logger(__name__) # Config for verification -bs = 8 +bs = 4 seq_len = 1024 def perf_model(mod, input_tensor): """Measure the performance of a mod with certain resharding schemes""" # warmup + mod.eval() + # mod.to(torch.float16) for _ in range(10): mod(input_tensor) @@ -283,11 +285,10 @@ def test_schemes(init_dist): ranks=0, ) - schs = test_schemes(None) - input_ids = torch.ones( bs, seq_len, dtype=torch.long, device=f"cuda:{dist.get_rank()}" ) + schs = test_schemes(None) for i, sch in enumerate(schs): mod, _ = slapo.build(sch, init_weights=sch.mod._init_weights) mod.to(f"cuda:{dist.get_rank()}")