diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py
index ac00be03b..6b4ae3c21 100644
--- a/dicp/dicp/vendor/AscendGraph/ascend_op.py
+++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -245,11 +245,21 @@ def __init__(self):
         super().__init__("Cast")
 
 
+class CastCpu(Operator):
+    def __init__(self):
+        super().__init__("CastCpu")
+
+
 class Identity(Operator):
     def __init__(self):
         super().__init__("Identity")
 
 
+class IdentityInp(Operator):
+    def __init__(self):
+        super().__init__("IdentityInp")
+
+
 class IdentityN(Operator):
     def __init__(self):
         super().__init__("IdentityN")
diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
index c526faf80..579616351 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
+++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -876,6 +876,12 @@ def Identity(name, input, index):
         op.set_input("x", input)
         return op.to_node()
 
+    @staticmethod
+    def IdentityInp(name, input, dst):
+        op = OP(name, "Identity")
+        op.set_input("x", input)
+        return op.to_node()
+
     @staticmethod
     def Exp(name, x):
         op = OP(name, "Exp")
@@ -989,6 +995,15 @@ def Cast(name, x, ascend_dtype, device=None):
         cast_op.set_attr_int("dst_type", get_ascend_dtype_num(ascend_dtype))
         return cast_op.to_node()
 
+
+    @staticmethod
+    def CastCpu(name, x, ascend_dtype, device='cpu'):
+        cast_op = OP(name, "Cast")
+        cast_op.set_input("x", x)
+        cast_op.set_attr_int("dst_type", get_ascend_dtype_num(ascend_dtype))
+        return cast_op.to_node()
+
+
     @staticmethod
     def Const(name, x, dtype, dims=None, format="ND"):
         if not isinstance(x, list):
diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py
index 726aecb65..a9cfa885e 100644
--- a/dicp/dicp/vendor/AscendGraph/conversion.py
+++ b/dicp/dicp/vendor/AscendGraph/conversion.py
@@ -214,7 +214,10 @@ def add_scalar(self, x, y):
     @register_conversion(torch.ops.aten._to_copy.default)
     def _to_copy(self, x, dtype=None, layout=torch.strided, device=None):
         if dtype:
-            return self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(dtype), device))
+            if device == torch.device(type='cpu'):
+                return self.get_proxy(ascend_op.CastCpu, (x, get_ascend_dtype(dtype)))
+            else:
+                return self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(dtype)))
         else:
             return self.get_proxy(ascend_op.Identity, (x, None))
 
@@ -411,20 +414,20 @@ def where(self, condition, x1, x2):
     @register_conversion(aten.arange.default)
     def arange(self, end, start=0, step=1, dtype=None, device='xpu', layout=None, pin_memory=False):
         out_dtype = fx_traceback.get_current_meta()['val'].dtype
-        assert isinstance(start, torch.fx.proxy.Proxy) or isinstance(start, int)
-        assert isinstance(end, torch.fx.proxy.Proxy) or isinstance(end, int)
-        assert isinstance(step, torch.fx.proxy.Proxy) or isinstance(step, int)
-
-        if isinstance(start, int):
-            start = self.get_proxy(ascend_op.Const, (int(start), out_dtype))
-        elif start.node.meta['val'] != out_dtype:
+        assert isinstance(start, torch.fx.proxy.Proxy) or type(start) in [int, float]
+        assert isinstance(end, torch.fx.proxy.Proxy) or type(end) in [int, float]
+        assert isinstance(step, torch.fx.proxy.Proxy) or type(step) in [int, float]
+
+        if not isinstance(start, torch.fx.proxy.Proxy):  # scalar const
+            start = self.get_proxy(ascend_op.Const, (start, out_dtype))
+        elif start.node.meta['val'] != out_dtype:  # align tensor dtype
             start = self.get_proxy(ascend_op.Cast, (start, get_ascend_dtype(out_dtype)), {})
-        if isinstance(end, int):
-            end = self.get_proxy(ascend_op.Const, (int(end), out_dtype))
+        if not isinstance(end, torch.fx.proxy.Proxy):
+            end = self.get_proxy(ascend_op.Const, (end, out_dtype))
         elif end.node.meta['val'] != out_dtype:
             end = self.get_proxy(ascend_op.Cast, (end, get_ascend_dtype(out_dtype)), {})
-        if isinstance(step, int):
-            step = self.get_proxy(ascend_op.Const, (int(step), out_dtype))
+        if not isinstance(step, torch.fx.proxy.Proxy):
+            step = self.get_proxy(ascend_op.Const, (step, out_dtype))
         elif step.node.meta['val'] != out_dtype:
             step = self.get_proxy(ascend_op.Cast, (step, get_ascend_dtype(out_dtype)), {})
         return self.get_proxy(ascend_op.Range, (end, start, step))
@@ -916,11 +919,11 @@ def clone(self, a, memory_format=torch.contiguous_format):
 
     @register_conversion(torch.ops.aten.copy_)
     def copy_(self, dst, src):
-        return self.get_proxy(ascend_op.Identity, (src, dst))
+        return self.get_proxy(ascend_op.IdentityInp, (src, dst))
 
     @register_conversion(torch.ops.aten.copy)
     def copy(self, dst, src):
-        return self.get_proxy(ascend_op.Identity, (src, dst))
+        return self.get_proxy(ascend_op.Identity, (src, None))
 
     @register_conversion(torch.ops.aten.unsqueeze)
     def unsqueeze(self, x, dim):
diff --git a/dicp/dicp/vendor/AscendGraph/opset_convert.py b/dicp/dicp/vendor/AscendGraph/opset_convert.py
index 732701f8f..1e4dd7e0a 100644
--- a/dicp/dicp/vendor/AscendGraph/opset_convert.py
+++ b/dicp/dicp/vendor/AscendGraph/opset_convert.py
@@ -1,6 +1,6 @@
 import torch
 from dicp.dynamo_bridge.op_transformer import BackendPatternMatcherTransformer
-from dicp.vendor.AscendGraph.ascend_op import MatMul, Cast, Identity
+from dicp.vendor.AscendGraph.ascend_op import MatMul, CastCpu, IdentityInp
 from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer
 from dicp.vendor.AscendGraph.pattern_replacement import (
     ascend_pattern_matcher,
@@ -37,10 +37,10 @@ def transform(self, gm: torch.fx.graph_module):
         for n in gm.graph.nodes:
             if n.op != 'call_function':
                 continue
-            if type(n.target) == Cast:
+            if type(n.target) == CastCpu:
                 if len(n.args) == 3 and n.args[2] is not None and n.args[2] == torch.device(type='cpu'):
                     self.cpu_tensor.append(n.name)
-            elif type(n.target) == Identity:
+            elif type(n.target) == IdentityInp:
                 if len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names:
                     self.assign_args.append((n.name, input_names.index(str(n.args[1]))))
 
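
The conversion.py change above makes _to_copy dispatch on the destination device: a cast whose target device is the CPU is lowered to the new CastCpu operator so the later opset_convert pass can record its output as a CPU-resident tensor, while other casts keep using Cast, and copy_ versus copy are split into IdentityInp versus a plain Identity. The snippet below is a minimal, standalone sketch of that routing rule only; it does not use dicp's real classes, and the helper name pick_ascend_op is hypothetical, chosen just for illustration.

import torch


def pick_ascend_op(aten_target, dtype=None, device=None):
    # Sketch of the routing introduced by this patch (assumption: simplified;
    # real dicp builds FX proxies via self.get_proxy instead of returning names).
    if aten_target == "aten.copy_":
        # In-place copy: keep dst so opset_convert can map the result
        # back onto the corresponding graph input.
        return "IdentityInp"
    if aten_target == "aten.copy":
        # Out-of-place copy: dst is irrelevant, the conversion passes None.
        return "Identity"
    if aten_target == "aten._to_copy":
        if dtype is None:
            return "Identity"  # no dtype change, plain passthrough
        if device is not None and torch.device(device) == torch.device("cpu"):
            return "CastCpu"   # result lives on the host, tagged as a cpu tensor
        return "Cast"          # ordinary on-device dtype cast
    raise NotImplementedError(aten_target)


if __name__ == "__main__":
    assert pick_ascend_op("aten._to_copy", torch.float32, "cpu") == "CastCpu"
    assert pick_ascend_op("aten._to_copy", torch.float16, "cuda") == "Cast"
    assert pick_ascend_op("aten.copy_") == "IdentityInp"

Giving CastCpu and IdentityInp their own Operator subclasses lets the transform pass in opset_convert.py identify these nodes by target type alone, instead of inspecting the argument list of every Cast or Identity node in the graph.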