Change ascend_op naming policy.
pdx1989 committed Nov 14, 2023
1 parent e13878b · commit 56e052b
Showing 4 changed files with 45 additions and 17 deletions.
10 changes: 10 additions & 0 deletions dicp/dicp/vendor/AscendGraph/ascend_op.py
@@ -245,11 +245,21 @@ def __init__(self):
super().__init__("Cast")


class CastCpu(Operator):
def __init__(self):
super().__init__("CastCpu")


class Identity(Operator):
def __init__(self):
super().__init__("Identity")


class IdentityInp(Operator):
def __init__(self):
super().__init__("IdentityInp")


class IdentityN(Operator):
def __init__(self):
super().__init__("IdentityN")
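For context, each class above is a thin Operator subclass that only records an op name; the new CastCpu and IdentityInp classes exist so later passes can dispatch on the Python type of a node's target instead of inspecting its arguments. Below is a minimal sketch with a hypothetical stand-in Operator base, not the real dicp base class:

# Hypothetical minimal Operator base, used only to illustrate the pattern above.
class Operator:
    def __init__(self, name):
        self.name = name

class Cast(Operator):
    def __init__(self):
        super().__init__("Cast")

class CastCpu(Operator):
    def __init__(self):
        super().__init__("CastCpu")

# A pass can now tell a CPU-destined cast from a regular one by class alone,
# instead of checking whether a device argument equals torch.device('cpu').
op = CastCpu()
print(type(op) is CastCpu, op.name)  # True CastCpu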
15 changes: 15 additions & 0 deletions dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -876,6 +876,12 @@ def Identity(name, input, index):
op.set_input("x", input)
return op.to_node()

@staticmethod
def IdentityInp(name, input, dst):
op = OP(name, "Identity")
op.set_input("x", input)
return op.to_node()

@staticmethod
def Exp(name, x):
op = OP(name, "Exp")
@@ -989,6 +995,15 @@ def Cast(name, x, ascend_dtype, device=None):
cast_op.set_attr_int("dst_type", get_ascend_dtype_num(ascend_dtype))
return cast_op.to_node()


@staticmethod
def CastCpu(name, x, ascend_dtype, device='cpu'):
cast_op = OP(name, "Cast")
cast_op.set_input("x", x)
cast_op.set_attr_int("dst_type", get_ascend_dtype_num(ascend_dtype))
return cast_op.to_node()


@staticmethod
def Const(name, x, dtype, dims=None, format="ND"):
if not isinstance(x, list):
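Both Cast and CastCpu above emit the same underlying "Cast" IR node; the split exists so the graph-level pass (see opset_convert.py below) can recognize CPU-destined casts by class. A rough sketch of the builder pattern they use follows, with a simplified stand-in for the codegen's OP helper rather than its real implementation, assuming it just accumulates inputs and attributes into a node description:

# Simplified stand-in for the codegen's OP helper, for illustration only.
class OP:
    def __init__(self, name, op_type):
        self._node = {"name": name, "op_type": op_type, "inputs": {}, "attrs": {}}

    def set_input(self, key, value):
        self._node["inputs"][key] = value

    def set_attr_int(self, key, value):
        self._node["attrs"][key] = int(value)

    def to_node(self):
        return self._node

# CastCpu builds the same "Cast" node as Cast; only the Python-level entry differs.
cast = OP("cast_0", "Cast")
cast.set_input("x", "arg0_1")
cast.set_attr_int("dst_type", 0)  # dtype enum value is illustrative, not Ascend's real mapping
print(cast.to_node())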
31 changes: 17 additions & 14 deletions dicp/dicp/vendor/AscendGraph/conversion.py
@@ -214,7 +214,10 @@ def add_scalar(self, x, y):
@register_conversion(torch.ops.aten._to_copy.default)
def _to_copy(self, x, dtype=None, layout=torch.strided, device=None):
if dtype:
return self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(dtype), device))
if device == torch.device(type='cpu'):
return self.get_proxy(ascend_op.CastCpu, (x, get_ascend_dtype(dtype)))
else:
return self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(dtype)))
else:
return self.get_proxy(ascend_op.Identity, (x, None))

@@ -411,20 +414,20 @@ def where(self, condition, x1, x2):
@register_conversion(aten.arange.default)
def arange(self, end, start=0, step=1, dtype=None, device='xpu', layout=None, pin_memory=False):
out_dtype = fx_traceback.get_current_meta()['val'].dtype
assert isinstance(start, torch.fx.proxy.Proxy) or isinstance(start, int)
assert isinstance(end, torch.fx.proxy.Proxy) or isinstance(end, int)
assert isinstance(step, torch.fx.proxy.Proxy) or isinstance(step, int)
if isinstance(start, int):
start = self.get_proxy(ascend_op.Const, (int(start), out_dtype))
elif start.node.meta['val'] != out_dtype:
assert isinstance(start, torch.fx.proxy.Proxy) or type(start) in [int, float]
assert isinstance(end, torch.fx.proxy.Proxy) or type(end) in [int, float]
assert isinstance(step, torch.fx.proxy.Proxy) or type(step) in [int, float]

if not isinstance(start, torch.fx.proxy.Proxy): # scalar const
start = self.get_proxy(ascend_op.Const, (start, out_dtype))
elif start.node.meta['val'] != out_dtype: # align tensor dtype
start = self.get_proxy(ascend_op.Cast, (start, get_ascend_dtype(out_dtype)), {})
if isinstance(end, int):
end = self.get_proxy(ascend_op.Const, (int(end), out_dtype))
if not isinstance(end, torch.fx.proxy.Proxy):
end = self.get_proxy(ascend_op.Const, (end, out_dtype))
elif end.node.meta['val'] != out_dtype:
end = self.get_proxy(ascend_op.Cast, (end, get_ascend_dtype(out_dtype)), {})
if isinstance(step, int):
step = self.get_proxy(ascend_op.Const, (int(step), out_dtype))
if not isinstance(step, torch.fx.proxy.Proxy):
step = self.get_proxy(ascend_op.Const, (step, out_dtype))
elif step.node.meta['val'] != out_dtype:
step = self.get_proxy(ascend_op.Cast, (step, get_ascend_dtype(out_dtype)), {})
return self.get_proxy(ascend_op.Range, (end, start, step))
@@ -916,11 +919,11 @@ def clone(self, a, memory_format=torch.contiguous_format):

@register_conversion(torch.ops.aten.copy_)
def copy_(self, dst, src):
return self.get_proxy(ascend_op.Identity, (src, dst))
return self.get_proxy(ascend_op.IdentityInp, (src, dst))

@register_conversion(torch.ops.aten.copy)
def copy(self, dst, src):
return self.get_proxy(ascend_op.Identity, (src, dst))
return self.get_proxy(ascend_op.Identity, (src, None))

@register_conversion(torch.ops.aten.unsqueeze)
def unsqueeze(self, x, dim):
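The key behavioural change above is in _to_copy: a cast whose destination device is the CPU now lowers to CastCpu, every other cast stays on Cast, and a dtype-preserving copy still lowers to Identity. A self-contained sketch of that branching, with string results standing in for the real self.get_proxy(ascend_op.…) calls:

import torch

def lower_to_copy(dtype=None, device=None):
    """Illustrative stand-in for the new _to_copy dispatch."""
    if dtype:
        if device == torch.device("cpu"):
            return "CastCpu"   # cast whose result must land on the host
        return "Cast"          # ordinary on-device cast
    return "Identity"          # no dtype change: plain copy

assert lower_to_copy(dtype=torch.float16, device=torch.device("cpu")) == "CastCpu"
assert lower_to_copy(dtype=torch.float16) == "Cast"
assert lower_to_copy() == "Identity"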
6 changes: 3 additions & 3 deletions dicp/dicp/vendor/AscendGraph/opset_convert.py
@@ -1,6 +1,6 @@
import torch
from dicp.dynamo_bridge.op_transformer import BackendPatternMatcherTransformer
from dicp.vendor.AscendGraph.ascend_op import MatMul, Cast, Identity
from dicp.vendor.AscendGraph.ascend_op import MatMul, CastCpu, IdentityInp
from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer
from dicp.vendor.AscendGraph.pattern_replacement import (
ascend_pattern_matcher,
@@ -37,10 +37,10 @@ def transform(self, gm: torch.fx.graph_module):
for n in gm.graph.nodes:
if n.op != 'call_function':
continue
if type(n.target) == Cast:
if type(n.target) == CastCpu:
if len(n.args) == 3 and n.args[2] is not None and n.args[2] == torch.device(type='cpu'):
self.cpu_tensor.append(n.name)
elif type(n.target) == Identity:
elif type(n.target) == IdentityInp:
if len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names:
self.assign_args.append((n.name, input_names.index(str(n.args[1]))))

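The transform above relies on call_function targets being instances of the ascend_op classes, so type(n.target) checks select exactly the nodes introduced by the new lowering. A self-contained illustration follows; FakeOp and the hand-built graph are hypothetical stand-ins, since dicp builds these nodes through its own tracer:

import torch
import torch.fx as fx

class FakeOp:
    # Stand-in for dicp's Operator subclasses: callable, with a __name__
    # so torch.fx can derive node names from it.
    def __init__(self, name):
        self.name = self.__name__ = name

    def __call__(self, *args, **kwargs):
        return args[0]

class CastCpu(FakeOp):
    def __init__(self):
        super().__init__("CastCpu")

class IdentityInp(FakeOp):
    def __init__(self):
        super().__init__("IdentityInp")

# Hand-built toy graph: x -> CastCpu -> IdentityInp(cast, x)
g = fx.Graph()
x = g.placeholder("x")
cast = g.call_function(CastCpu(), (x, "FLOAT16", torch.device("cpu")))
out = g.call_function(IdentityInp(), (cast, x))
g.output(out)

cpu_tensor, assign_args = [], []
input_names = ["x"]
for n in g.nodes:
    if n.op != 'call_function':
        continue
    if type(n.target) == CastCpu:
        if len(n.args) == 3 and n.args[2] is not None and n.args[2] == torch.device("cpu"):
            cpu_tensor.append(n.name)
    elif type(n.target) == IdentityInp:
        if len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names:
            assign_args.append((n.name, input_names.index(str(n.args[1]))))

print(cpu_tensor)   # e.g. ['cast_cpu']          (names depend on fx auto-naming)
print(assign_args)  # e.g. [('identity_inp', 0)]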
