# Patch summary (explanatory preamble; everything from the "diff --git" header
# onward is the original patch text, unchanged).
#
# Target file: slapo/sharding/solver.py
#
# Hunk 1 (class ViewOp): adds a constructor `__init__(self, node, z3_graph, p)`.
#   It calls the base-class constructor with `node`, stores `z3_graph`, stores
#   `p` as `self.num_devices`, and looks up the producer op as
#   `self.z3_graph[self.node.args[0].name]`, saving it in `self.prev_op`.
#   The added comment gives `(bs,seq,d) -> (bs,seq,h,d//h)` as the intended
#   reshape; the code itself does not check shapes.
#
# Hunk 2 (ViewOp.calculate_comm_cost_z3): replaces the constant `return 0`
#   with a `z3.If` expression. The cost is `self.out_size * 1 / self.num_devices`
#   when BOTH conditions hold:
#     - `self.prev_op.generate_output_z3()` equals `ShardSpec("RS").id`, and
#     - `self.z3_inputs[0]` equals `ShardSpec("RR").id`;
#   otherwise the cost is 0. Per the patch's own comment, this models the
#   RS->RR case, which requires an all-gather.
#   NOTE(review): `out_size` and `ShardSpec` are defined outside this patch;
#   the units of the cost term are not visible here -- confirm.
#   NOTE(review): this calls `prev_op.generate_output_z3()` at cost-calculation
#   time; presumably that call has no side effects -- verify for every op type
#   that can precede a `view`.
#
# Hunk 3 (construct_z3_graph): updates the `view` branch to build the op as
#   `ViewOp(node, self.z3_graph, self.num_devices)`, matching the new
#   constructor signature.
#
# NOTE(review): the patch below is stored on a single physical line (original
# line breaks and indentation are missing); it must be re-wrapped before it
# can be applied with `git apply` / `patch`.
diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index a0480656..214d3c78 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -156,6 +156,13 @@ def calculate_comm_cost_z3(self): class ViewOp(FxOp): # TODO: verify the behavior of general view function + # (bs,seq,d) -> (bs,seq,h,d//h) + def __init__(self, node, z3_graph, p): + super().__init__(node) + self.z3_graph = z3_graph + self.num_devices = p + self.prev_op = self.z3_graph[self.node.args[0].name] + def generate_input_z3(self): self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2)) format_constraints = [z3.ULE(self.z3_inputs[0], 3)] @@ -166,8 +173,17 @@ def generate_output_z3(self): def calculate_comm_cost_z3(self): # `view` can redistribute the dimensions, thus can be used to - # convert to any spec without communication - return 0 + # convert to most of the specs without communication, + # but be careful about RS->RR, which requires an all-gather + result = z3.If( + z3.And( + self.prev_op.generate_output_z3() == ShardSpec("RS").id, + self.z3_inputs[0] == ShardSpec("RR").id, + ), + self.out_size * 1 / self.num_devices, + 0, + ) + return result class PermuteOp(FxOp): @@ -395,7 +411,7 @@ def construct_z3_graph(self): elif node.op == "call_method": # pylint: disable=redefined-variable-type if node.target == "view": - new_op = ViewOp(node) + new_op = ViewOp(node, self.z3_graph, self.num_devices) elif node.target == "permute": new_op = PermuteOp(node, self.z3_graph) elif node.target == "transpose":