From 87591402cab74533a1768a30a850ddfa88eede77 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Tue, 23 May 2023 06:52:36 +0000 Subject: [PATCH] Fix pylint --- slapo/sharding/solver.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/slapo/sharding/solver.py b/slapo/sharding/solver.py index 715e8554..04bf3bc1 100644 --- a/slapo/sharding/solver.py +++ b/slapo/sharding/solver.py @@ -192,7 +192,7 @@ def __init__(self, node, z3_graph): self.z3_graph = z3_graph permute_idx = list(node.args[1:]) self.output_map = {} - for in_spec in ["RR", "RS", "SR"]: + for in_spec in ("RR", "RS", "SR"): spec = "R" * (len(permute_idx) - 2) + in_spec out_spec = spec[-2:] self.output_map[in_spec] = out_spec @@ -243,7 +243,7 @@ def calculate_comm_cost_z3(self): class MatmulOp(FxOp): - def __init__(self, node, mod=None, is_linear=False): + def __init__(self, node): super().__init__(node) self.output_map = {"RR": "RS", "RS": "RR", "SR": "SR"} self.comm_cost_map = { # map from input spec to comm cost @@ -374,7 +374,7 @@ def calculate_reshard_cost_z3(self, prev, curr, shape): for in_spec, target_map in self.reshard_cost_map.items(): tmp = 1e12 # invalid for out_spec, val in target_map.items(): - if in_spec == "RR" and out_spec in ["RS", "SR"]: + if in_spec == "RR" and out_spec in {"RS", "SR"}: cost = 1 # add penalty for splitting cost else: cost = int(val * shape) @@ -393,11 +393,7 @@ def construct_z3_graph(self): elif node.op == "call_module": mod = self.named_modules[node.target] if isinstance(mod, nn.Linear): - new_op = MatmulOp( - node, - mod=mod, - is_linear=True, - ) + new_op = MatmulOp(node) elif type(mod) in fx_op_map: new_op = fx_op_map[type(mod)](node) else: @@ -466,11 +462,10 @@ def dump_z3_graph(self, mod=None, dot_file="z3_graph.dot"): for i, arg in enumerate(op.args): if results is None: label = "" + elif op.name + "_" + str(i) not in results: + label = "" else: - if op.name + "_" + str(i) not in results: - label = "" - else: - label = f' [label="{ShardSpec(arg.generate_output(mod))}->{ShardSpec(results[op.name+"_"+str(i)])}"]' + label = f' [label="{ShardSpec(arg.generate_output(mod))}->{ShardSpec(results[op.name+"_"+str(i)])}"]' res += f" {arg.name} -> {op.name}{label};\n" res += "}" with open(dot_file, "w", encoding="utf-8") as f: