Skip to content

Commit

Permalink
A small fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
pdx1989 committed Nov 14, 2023
1 parent 56e052b commit deb8c4f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion dicp/dicp/vendor/AscendGraph/codegen/ascend.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,7 +997,7 @@ def Cast(name, x, ascend_dtype, device=None):


@staticmethod
def CastCpu(name, x, ascend_dtype, device='cpu'):
def CastCpu(name, x, ascend_dtype):
cast_op = OP(name, "Cast")
cast_op.set_input("x", x)
cast_op.set_attr_int("dst_type", get_ascend_dtype_num(ascend_dtype))
Expand Down
8 changes: 3 additions & 5 deletions dicp/dicp/vendor/AscendGraph/opset_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,9 @@ def transform(self, gm: torch.fx.graph_module):
if n.op != 'call_function':
continue
if type(n.target) == CastCpu:
if len(n.args) == 3 and n.args[2] is not None and n.args[2] == torch.device(type='cpu'):
self.cpu_tensor.append(n.name)
elif type(n.target) == IdentityInp:
if len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names:
self.assign_args.append((n.name, input_names.index(str(n.args[1]))))
self.cpu_tensor.append(n.name)
elif type(n.target) == IdentityInp and len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names:
self.assign_args.append((n.name, input_names.index(str(n.args[1]))))

for n in gm.graph.nodes:
if n.op == 'call_function':
Expand Down

0 comments on commit deb8c4f

Please sign in to comment.