Partly revert back.
pdx1989 committed Nov 14, 2023
1 parent deb8c4f commit cf43da6
Showing 3 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -997,7 +997,7 @@ def Cast(name, x, ascend_dtype, device=None):
 
 
     @staticmethod
-    def CastCpu(name, x, ascend_dtype):
+    def CastCpu(name, x, ascend_dtype, device='cpu'):
         cast_op = OP(name, "Cast")
         cast_op.set_input("x", x)
         cast_op.set_attr_int("dst_type", get_ascend_dtype_num(ascend_dtype))
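A note on this hunk: the new device='cpu' parameter is not used when the Cast op is emitted; it exists so the device passed by _to_copy (next file) ends up in the FX node's arguments, where the transform pass in opset_convert.py can read it back. The sketch below is a hypothetical stand-in, not the real OP/codegen classes, and only illustrates why the parameter gets a default: call sites that omit the device keep working.

    # Hypothetical stand-in for the CastCpu builder (a plain dict instead of OP("Cast")).
    def cast_cpu(name, x, ascend_dtype, device='cpu'):
        # 'device' is accepted but unused here; its value only needs to appear
        # in the FX node args so a later pass can inspect it.
        return {"name": name, "op": "Cast", "x": x, "dst_type": ascend_dtype}

    print(cast_cpu("cast0", "buf0", "FLOAT"))                 # old-style call, no device
    print(cast_cpu("cast0", "buf0", "FLOAT", device="cpu"))   # new-style call from _to_copy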
2 changes: 1 addition & 1 deletion dicp/dicp/vendor/AscendGraph/conversion.py
@@ -215,7 +215,7 @@ def add_scalar(self, x, y):
     def _to_copy(self, x, dtype=None, layout=torch.strided, device=None):
         if dtype:
             if device == torch.device(type='cpu'):
-                return self.get_proxy(ascend_op.CastCpu, (x, get_ascend_dtype(dtype)))
+                return self.get_proxy(ascend_op.CastCpu, (x, get_ascend_dtype(dtype), device))
             else:
                 return self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(dtype)))
         else:
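The lowering rule this hunk implements, in isolation: a dtype cast whose target device is the CPU is routed to CastCpu and now forwards the device as an extra argument, while every other cast still goes through the plain Cast op. A self-contained sketch of just that branch (pick_cast_op and its return values are illustrative, not the dicp API):

    import torch

    def pick_cast_op(ascend_dtype, device):
        # Mirrors the branch in _to_copy: CPU-targeted casts carry the device along.
        if device == torch.device('cpu'):
            return ('CastCpu', (ascend_dtype, device))
        return ('Cast', (ascend_dtype,))

    print(pick_cast_op('FLOAT', torch.device('cpu')))   # ('CastCpu', ('FLOAT', device(type='cpu')))
    print(pick_cast_op('FLOAT', None))                  # ('Cast', ('FLOAT',))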
8 changes: 5 additions & 3 deletions dicp/dicp/vendor/AscendGraph/opset_convert.py
@@ -38,9 +38,11 @@ def transform(self, gm: torch.fx.graph_module):
             if n.op != 'call_function':
                 continue
             if type(n.target) == CastCpu:
-                self.cpu_tensor.append(n.name)
-            elif type(n.target) == IdentityInp and len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names:
-                self.assign_args.append((n.name, input_names.index(str(n.args[1]))))
+                if len(n.args) == 3 and n.args[2] is not None and n.args[2] == torch.device(type='cpu'):
+                    self.cpu_tensor.append(n.name)
+            elif type(n.target) == IdentityInp:
+                if len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names:
+                    self.assign_args.append((n.name, input_names.index(str(n.args[1]))))
 
         for n in gm.graph.nodes:
             if n.op == 'call_function':
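Taken together with the other two hunks, the effect is that a CastCpu node is only recorded in self.cpu_tensor when its third FX argument is the CPU device, which is exactly the argument _to_copy now forwards; the IdentityInp branch is merely split into a nested if, with its condition unchanged. A small standalone sketch of the new CastCpu check (is_cpu_cast is an illustrative helper, not part of dicp):

    import torch

    def is_cpu_cast(args):
        # args mirrors n.args of a call_function node whose target is CastCpu.
        return len(args) == 3 and args[2] is not None and args[2] == torch.device('cpu')

    print(is_cpu_cast(('x', 'FLOAT', torch.device('cpu'))))   # True  -> node name appended to cpu_tensor
    print(is_cpu_cast(('x', 'FLOAT')))                        # False -> node is skipped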

0 comments on commit cf43da6
