Skip to content

Commit

Permalink
Fix pylint
Browse files Browse the repository at this point in the history
  • Loading branch information
chhzh123 committed May 26, 2023
1 parent ef2c565 commit 8759140
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions slapo/sharding/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def __init__(self, node, z3_graph):
self.z3_graph = z3_graph
permute_idx = list(node.args[1:])
self.output_map = {}
for in_spec in ["RR", "RS", "SR"]:
for in_spec in ("RR", "RS", "SR"):
spec = "R" * (len(permute_idx) - 2) + in_spec
out_spec = spec[-2:]
self.output_map[in_spec] = out_spec
Expand Down Expand Up @@ -243,7 +243,7 @@ def calculate_comm_cost_z3(self):


class MatmulOp(FxOp):
def __init__(self, node, mod=None, is_linear=False):
def __init__(self, node):
super().__init__(node)
self.output_map = {"RR": "RS", "RS": "RR", "SR": "SR"}
self.comm_cost_map = { # map from input spec to comm cost
Expand Down Expand Up @@ -374,7 +374,7 @@ def calculate_reshard_cost_z3(self, prev, curr, shape):
for in_spec, target_map in self.reshard_cost_map.items():
tmp = 1e12 # invalid
for out_spec, val in target_map.items():
if in_spec == "RR" and out_spec in ["RS", "SR"]:
if in_spec == "RR" and out_spec in {"RS", "SR"}:
cost = 1 # add penalty for splitting cost
else:
cost = int(val * shape)
Expand All @@ -393,11 +393,7 @@ def construct_z3_graph(self):
elif node.op == "call_module":
mod = self.named_modules[node.target]
if isinstance(mod, nn.Linear):
new_op = MatmulOp(
node,
mod=mod,
is_linear=True,
)
new_op = MatmulOp(node)
elif type(mod) in fx_op_map:
new_op = fx_op_map[type(mod)](node)
else:
Expand Down Expand Up @@ -466,11 +462,10 @@ def dump_z3_graph(self, mod=None, dot_file="z3_graph.dot"):
for i, arg in enumerate(op.args):
if results is None:
label = ""
elif op.name + "_" + str(i) not in results:
label = ""
else:
if op.name + "_" + str(i) not in results:
label = ""
else:
label = f' [label="{ShardSpec(arg.generate_output(mod))}->{ShardSpec(results[op.name+"_"+str(i)])}"]'
label = f' [label="{ShardSpec(arg.generate_output(mod))}->{ShardSpec(results[op.name+"_"+str(i)])}"]'
res += f" {arg.name} -> {op.name}{label};\n"
res += "}"
with open(dot_file, "w", encoding="utf-8") as f:
Expand Down

0 comments on commit 8759140

Please sign in to comment.