Skip to content

Commit

Permalink
Fix ViewOp
Browse files Browse the repository at this point in the history
  • Loading branch information
chhzh123 committed May 26, 2023
1 parent 9a326ef commit 0b1de58
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions slapo/sharding/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ def calculate_comm_cost_z3(self):

class ViewOp(FxOp):
# TODO: verify the behavior of general view function
# (bs,seq,d) -> (bs,seq,h,d//h)
def __init__(self, node, z3_graph, p):
super().__init__(node)
self.z3_graph = z3_graph
self.num_devices = p
self.prev_op = self.z3_graph[self.node.args[0].name]

def generate_input_z3(self):
self.z3_inputs.append(z3.BitVec(f"{self.name}_0", 2))
format_constraints = [z3.ULE(self.z3_inputs[0], 3)]
Expand All @@ -166,8 +173,17 @@ def generate_output_z3(self):

def calculate_comm_cost_z3(self):
# `view` can redistribute the dimensions, thus can be used to
# convert to any spec without communication
return 0
# convert to most of the specs without communication,
# but be careful about RS->RR, which requires an all-gather
result = z3.If(
z3.And(
self.prev_op.generate_output_z3() == ShardSpec("RS").id,
self.z3_inputs[0] == ShardSpec("RR").id,
),
self.out_size * 1 / self.num_devices,
0,
)
return result


class PermuteOp(FxOp):
Expand Down Expand Up @@ -395,7 +411,7 @@ def construct_z3_graph(self):
elif node.op == "call_method":
# pylint: disable=redefined-variable-type
if node.target == "view":
new_op = ViewOp(node)
new_op = ViewOp(node, self.z3_graph, self.num_devices)
elif node.target == "permute":
new_op = PermuteOp(node, self.z3_graph)
elif node.target == "transpose":
Expand Down

0 comments on commit 0b1de58

Please sign in to comment.