diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 9e7c3c67c..579616351 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -997,7 +997,7 @@ def Cast(name, x, ascend_dtype, device=None): @staticmethod - def CastCpu(name, x, ascend_dtype): + def CastCpu(name, x, ascend_dtype, device='cpu'): cast_op = OP(name, "Cast") cast_op.set_input("x", x) cast_op.set_attr_int("dst_type", get_ascend_dtype_num(ascend_dtype)) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index a9cfa885e..a7b3f64fb 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -215,7 +215,7 @@ def add_scalar(self, x, y): def _to_copy(self, x, dtype=None, layout=torch.strided, device=None): if dtype: if device == torch.device(type='cpu'): - return self.get_proxy(ascend_op.CastCpu, (x, get_ascend_dtype(dtype))) + return self.get_proxy(ascend_op.CastCpu, (x, get_ascend_dtype(dtype), device)) else: return self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(dtype))) else: diff --git a/dicp/dicp/vendor/AscendGraph/opset_convert.py b/dicp/dicp/vendor/AscendGraph/opset_convert.py index 34d8d0ec9..1e4dd7e0a 100644 --- a/dicp/dicp/vendor/AscendGraph/opset_convert.py +++ b/dicp/dicp/vendor/AscendGraph/opset_convert.py @@ -38,9 +38,11 @@ def transform(self, gm: torch.fx.graph_module): if n.op != 'call_function': continue if type(n.target) == CastCpu: - self.cpu_tensor.append(n.name) - elif type(n.target) == IdentityInp and len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names: - self.assign_args.append((n.name, input_names.index(str(n.args[1])))) + if len(n.args) == 3 and n.args[2] is not None and n.args[2] == torch.device(type='cpu'): + self.cpu_tensor.append(n.name) + elif type(n.target) == IdentityInp: + if len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names: + self.assign_args.append((n.name, input_names.index(str(n.args[1])))) for n in gm.graph.nodes: if n.op == 'call_function':