diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 579616351..9e7c3c67c 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -997,7 +997,7 @@ def Cast(name, x, ascend_dtype, device=None): @staticmethod - def CastCpu(name, x, ascend_dtype, device='cpu'): + def CastCpu(name, x, ascend_dtype): cast_op = OP(name, "Cast") cast_op.set_input("x", x) cast_op.set_attr_int("dst_type", get_ascend_dtype_num(ascend_dtype)) diff --git a/dicp/dicp/vendor/AscendGraph/opset_convert.py b/dicp/dicp/vendor/AscendGraph/opset_convert.py index 1e4dd7e0a..34d8d0ec9 100644 --- a/dicp/dicp/vendor/AscendGraph/opset_convert.py +++ b/dicp/dicp/vendor/AscendGraph/opset_convert.py @@ -38,11 +38,9 @@ def transform(self, gm: torch.fx.graph_module): if n.op != 'call_function': continue if type(n.target) == CastCpu: - if len(n.args) == 3 and n.args[2] is not None and n.args[2] == torch.device(type='cpu'): - self.cpu_tensor.append(n.name) - elif type(n.target) == IdentityInp: - if len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names: - self.assign_args.append((n.name, input_names.index(str(n.args[1])))) + self.cpu_tensor.append(n.name) + elif type(n.target) == IdentityInp and len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names: + self.assign_args.append((n.name, input_names.index(str(n.args[1])))) for n in gm.graph.nodes: if n.op == 'call_function':