diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 579616351..8297d22db 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -877,7 +877,7 @@ def Identity(name, input, index): return op.to_node() @staticmethod - def IdentityInp(name, input, dst): + def IdentityInp(name, input, dst=None): op = OP(name, "Identity") op.set_input("x", input) return op.to_node() @@ -997,7 +997,7 @@ def Cast(name, x, ascend_dtype, device=None): @staticmethod - def CastCpu(name, x, ascend_dtype, device='cpu'): + def CastCpu(name, x, ascend_dtype, device=None): cast_op = OP(name, "Cast") cast_op.set_input("x", x) cast_op.set_attr_int("dst_type", get_ascend_dtype_num(ascend_dtype)) diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index a7b3f64fb..7ac81ad92 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -923,7 +923,7 @@ def copy_(self, dst, src): @register_conversion(torch.ops.aten.copy) def copy(self, dst, src): - return self.get_proxy(ascend_op.Identity, (src, None)) + return self.get_proxy(ascend_op.IdentityInp, (src, dst)) @register_conversion(torch.ops.aten.unsqueeze) def unsqueeze(self, x, dim):