From 2429afb78c1dea7b9c1dd68c24f2bbc3d6476deb Mon Sep 17 00:00:00 2001 From: pdx1989 Date: Thu, 16 Nov 2023 11:17:24 +0800 Subject: [PATCH] [dicp][ascend] Fix hf llama-inference precision. (#414) --- dicp/dicp/dynamo_bridge/op_transformer.py | 10 + dicp/dicp/vendor/AscendGraph/ascend_op.py | 14 +- .../dicp/vendor/AscendGraph/codegen/ascend.py | 211 ++++++++++++++---- .../AscendGraph/codegen/load_and_run.py | 26 +-- dicp/dicp/vendor/AscendGraph/conversion.py | 97 +++++--- dicp/dicp/vendor/AscendGraph/opset_convert.py | 60 ++++- 6 files changed, 324 insertions(+), 94 deletions(-) diff --git a/dicp/dicp/dynamo_bridge/op_transformer.py b/dicp/dicp/dynamo_bridge/op_transformer.py index 23d506348..a9166849f 100644 --- a/dicp/dicp/dynamo_bridge/op_transformer.py +++ b/dicp/dicp/dynamo_bridge/op_transformer.py @@ -6,6 +6,7 @@ from torch.fx.proxy import Proxy from typing import Any, Dict, Tuple from dicp.dynamo_bridge.compile_fx import is_torch_210 +from dicp.vendor.AscendGraph.codegen.utils import symint_in_shape class OpSetTransformer: @@ -24,6 +25,7 @@ def __init__(self, module, conversions): super().__init__(module) self._conversions = conversions self.sym_to_inputs = {} + self.sym_in_args = {} def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Proxy: proxy = super().placeholder(target, args, kwargs) @@ -31,6 +33,14 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict fake_tensor = proxy.node.meta['val'] if isinstance(fake_tensor, torch.SymInt): self.sym_to_inputs[fake_tensor.node.str()] = proxy + elif symint_in_shape(fake_tensor.shape): + # mention symint position in args + # dynamic shape feature + for idx, dim in enumerate(fake_tensor.shape): + if isinstance(dim, torch.SymInt): + st = dim.node.str() + if not st in self.sym_in_args: + self.sym_in_args[st] = (proxy, idx) return proxy def get_proxy(self, target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] = {}): diff --git a/dicp/dicp/vendor/AscendGraph/ascend_op.py b/dicp/dicp/vendor/AscendGraph/ascend_op.py index 18f6ebb34..e7f0f11ca 100644 --- a/dicp/dicp/vendor/AscendGraph/ascend_op.py +++ b/dicp/dicp/vendor/AscendGraph/ascend_op.py @@ -40,7 +40,7 @@ def __init__(self): super().__init__("Range") -class CumSum(Operator): +class Cumsum(Operator): def __init__(self): super().__init__("Cumsum") @@ -170,7 +170,7 @@ def __init__(self): super().__init__("TopK") -class ScatterElement(Operator): +class ScatterElements(Operator): def __init__(self): super().__init__("ScatterElements") @@ -245,11 +245,21 @@ def __init__(self): super().__init__("Cast") +class CastToCpu(Operator): + def __init__(self): + super().__init__("CastToCpu") + + class Identity(Operator): def __init__(self): super().__init__("Identity") +class IdentityInp(Operator): + def __init__(self): + super().__init__("IdentityInp") + + class IdentityN(Operator): def __init__(self): super().__init__("IdentityN") diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py index 51c3a23ca..e68d2d523 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py @@ -16,8 +16,6 @@ precision_check = bool(os.environ.get("DICP_ASCEND_PRECISION_CHECK", False)) -sym_to_inputs = {} - def get_graph_id(): global graph_id @@ -62,14 +60,22 @@ def __init__(self, graph, aten_graph=None, folder=None, graph_key=None): self.graph_output_names = [] self.build_options = [] - global sym_to_inputs - sym_to_inputs = {} + self.folder = folder + self.graph_key = graph_key + + self.sym_to_inputs = {} + self.sym_in_args = {} + + # for modified args return + self.assign_args = [] + self.cpu_tensor = [] super().__init__(graph) def placeholder(self, name, target, args, kwargs): self.args_dict[name] = name self.input_args.append(self.cur_node) + fake_tensor = self.cur_node.meta['val'] format = "NCHW" @@ -79,8 +85,16 @@ def placeholder(self, name, target, args, kwargs): dims = [1] data_type = "INT32" format = "ND" - sym_to_inputs[fake_tensor.node.str()] = name + self.sym_to_inputs[fake_tensor.node.str()] = name elif symint_in_shape(fake_tensor.shape): + # mention symint position in args + # dynamic shape feature + for idx, dim in enumerate(fake_tensor.shape): + if isinstance(dim, torch.SymInt): + st = dim.node.str() + if not st in self.sym_in_args: + self.sym_in_args[st] = (name, idx) + # deal with dynamic shape -1 shape = [-1 if isinstance(elem, torch.SymInt) else elem for elem in fake_tensor.shape] @@ -114,6 +128,12 @@ def call_function(self, name, target, args, kwargs): if name not in self.args_dict.keys(): self.args_dict[name] = name + if hasattr(self.cur_node, 'meta'): + if 'prop' in self.cur_node.meta and 'cpu_tensor' in self.cur_node.meta['prop']: + self.cpu_tensor.append(self.cur_node.meta['prop']['cpu_tensor']) + if 'prop' in self.cur_node.meta and 'assign_args' in self.cur_node.meta['prop']: + self.assign_args.append(self.cur_node.meta['prop']['assign_args']) + _, args_list = AscendOverrides.gen_args( self.args_dict[name], self.args_dict, args) real_op = process_name(name, target) @@ -156,7 +176,7 @@ def codegen(self): return self.generate_code() def parse_outputs(self): - symint_inputs = sym_to_inputs.values() + symint_inputs = self.sym_to_inputs.values() for node in self.output_args: if isinstance(node, torch.fx.node.Node): name = self.args_dict[node.name] @@ -170,6 +190,9 @@ def parse_outputs(self): else: self.py_output_names.append(str(node)) + if len(self.assign_args) > 0: + self.graph_output_names.extend(list(zip(*self.assign_args))[0]) + def gen_import_code(self): self.import_code.splice( """ @@ -193,30 +216,71 @@ def check_tensor(a, b, atol=5e-2, rtol=1e-2): return self.import_code.getvalue() def process_sym_name(self, st): + # dynamic shape feature if st.isdigit(): return st elif '+' in st: sp = st.split('+') + if len(sp) > 2: + sp = [sp[0], '+'.join(sp[1:])] assert (len(sp) == 2) sp = [elem.strip() for elem in sp] - return sym_to_inputs[sp[0]] + '+' + sp[1] + if sp[0].isdigit(): + (sp[1], sp[0]) = (sp[0], sp[1]) + if sp[0] in self.sym_in_args: + arg, idx = self.sym_in_args[sp[0]] + return "{}.shape[{}]".format(arg, idx) + '+' + sp[1] + if sp[0] in self.sym_to_inputs.keys(): + return self.sym_to_inputs[sp[0]] + '+' + sp[1] + else: + return self.process_sym_name(sp[0]) + '+' + sp[1] elif '-' in st: sp = st.split('-') + if len(sp) > 2: + sp = [sp[0], '-'.join(sp[1:])] assert (len(sp) == 2) sp = [elem.strip() for elem in sp] - return sym_to_inputs[sp[0]] + '-' + sp[1] + if sp[0] in self.sym_in_args: + arg, idx = self.sym_in_args[sp[0]] + return "{}.shape[{}]".format(arg, idx) + '-' + sp[1] + if sp[0] in self.sym_to_inputs.keys(): + return self.sym_to_inputs[sp[0]] + '-' + sp[1] + else: + return self.process_sym_name(sp[0]) + '-' + sp[1] + elif '*' in st: + sp = st.split('*') + if len(sp) > 2: + sp = [sp[0], '*'.join(sp[1:])] + assert (len(sp) == 2) + sp = [elem.strip() for elem in sp] + if sp[0].isdigit(): + (sp[1], sp[0]) = (sp[0], sp[1]) + if sp[0] in self.sym_in_args: + arg, idx = self.sym_in_args[sp[0]] + return "{}.shape[{}]".format(arg, idx) + '*' + sp[1] + if sp[0] in self.sym_to_inputs.keys(): + return self.sym_to_inputs[sp[0]] + '*' + sp[1] + else: + return self.process_sym_name(sp[0]) + '*' + sp[1] else: - return sym_to_inputs[st] + if st in self.sym_in_args: + arg, idx = self.sym_in_args[st] + return "{}.shape[{}]".format(arg, idx) + return self.sym_to_inputs[st] def gen_call_func(self): # TODO check scalar input call_body = IndentedBuffer() self.args = [self.args_dict[x.name] for x in self.input_args] + shape_symint = [value[0] for value in self.sym_in_args.values()] - if len(self.dynamic_inputs) > 0: - args = ['_' if arg not in sym_to_inputs.values() - else arg for arg in self.args] + # dynamic shape feature + if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: + args = ['_' if not arg in shape_symint and not arg in self.sym_to_inputs.values() else arg for arg in self.args] call_body.writeline(f"({','.join(args)}) = args") + + # generate input dims + if len(self.dynamic_inputs) > 0: dim_len = 0 for shape in self.actual_shape: dim_len += len(shape) @@ -229,32 +293,63 @@ def gen_call_func(self): ":[" + ','.join(map(str, elem)) + '],' dims = dims[:-1] + '}' call_body.writeline(dims) - - shape_str = '''output_shape = [''' + else: + call_body.writeline(f'''dims = None''') + + # generate output shapes + # dynamic shape feature + extra_stride_str = '' + extra_storage_offset_str = '' + if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: + shape_str = f'''output_shape = [''' for elem in self.output_args: if hasattr(elem, 'meta'): elem = elem.meta['val'] - elem = list(elem.shape) - if len(elem) == 0: + if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool): + shape_str += '[1],' + continue + shape = list(elem.shape) + if len(shape) == 0: raise RuntimeError("Error handling empty output_shape") - elem = [self.process_sym_name(str(dim)) for dim in elem] - shape_str += "[" + ','.join(map(str, elem)) + '],' - shape_str = shape_str[:-1] + ''']''' + shape = [self.process_sym_name(str(dim)) for dim in shape] + shape_str += "[" + ','.join(map(str, shape)) + "]," + + # process output_shape with modified args + for elem in self.assign_args: + shape = list(self.input_args[elem[1]].meta['val'].shape) + if len(shape) == 0: + raise RuntimeError("Error handling empty output_shape") + shape = [self.process_sym_name(str(dim)) for dim in shape] + shape_str += "[" + ','.join(map(str, shape)) + "]," + stride = list(self.input_args[elem[1]].meta['val'].stride()) + if len(stride) == 0: + raise RuntimeError("Error handling empty output_stride") + stride = [self.process_sym_name(str(dim)) for dim in stride] + extra_stride_str += '[' + ','.join(map(str, stride)) + '],' + extra_storage_offset_str += str(self.input_args[elem[1]].meta['val'].storage_offset()) + ',' + shape_str = shape_str[:-1] + f''']''' call_body.writeline(shape_str) else: - call_body.writeline('''dims = None''') call_body.writeline('''output_shape = None''') - + + # add stride & storage_offset info out_stride_str = '''out_stride = [''' out_storage_offset_str = '''out_storage_offset = [''' for elem in self.output_args: if hasattr(elem, 'meta'): elem = elem.meta['val'] + if isinstance(elem, torch.SymInt) or isinstance(elem, torch.SymBool): + out_stride_str += '[1],' + out_storage_offset_str += '0,' + continue stride = list(elem.stride()) if len(stride) == 0: raise RuntimeError("Error handling empty output_stride") + stride = [self.process_sym_name(str(dim)) for dim in stride] out_stride_str += '[' + ','.join(map(str, stride)) + '],' out_storage_offset_str += str(elem.storage_offset()) + ',' + out_stride_str += extra_stride_str + out_storage_offset_str += extra_storage_offset_str out_stride_str = out_stride_str[:-1] + ']' out_storage_offset_str = out_storage_offset_str[:-1] + ']' call_body.writeline(out_stride_str) @@ -266,40 +361,50 @@ def gen_call_func(self): for idx in range(len(args)): if isinstance(args[idx], int): args[idx] = torch.tensor(args[idx], device=dipu_device_str, dtype=torch.int32) + if isinstance(args[idx], torch.Tensor): + tmp_arg = args[idx].clone() + with torch.no_grad(): + args[idx].copy_(tmp_arg) + del tmp_arg """, strip=True) call_body.writeline(f"({','.join(self.args)}) = args") call_str = ['output_tensor = kernel_cpp_0(args, dims, output_shape, out_stride, out_storage_offset)'] if precision_check and self.aten_graph is not None: - # 1. export aten graph to disk - def get_unique_str(): - uuid_str = str(uuid.uuid1()) - return 'm' + str(uuid.uuid3(uuid.NAMESPACE_DNS, uuid_str + str(os.getpid()))) - module_path = get_unique_str().replace('-', '') - folder_path = "/tmp/dicp_debug/aten_modules/" + module_path - os.system(f"mkdir -p {folder_path}") - self.aten_graph.to_folder(folder_path) - - # 2. import aten graph - call_str.append(f'from aten_modules.{module_path} import FxModule') - call_str.append('aten_call = FxModule()') - call_str.append('aten_output = aten_call(*args)') + # import aten graph + call_str.append(f"import sys") + call_str.append(f"if '{self.folder}' not in sys.path:") + call_str.append(f" sys.path.insert(0, '{self.folder}')") + call_str.append(f"from {self.graph_key[:4]} import {self.graph_key} as graph_module") + call_str.append(f"aten_call = graph_module()") + + call_str.append('aten_args = list(map(lambda x: x.to("cpu"), args))') + call_str.append('for idx in modified:') + call_str.append(' aten_args[idx] = aten_args[idx].item()') + call_str.append('aten_output = aten_call(*aten_args)') for i, name in enumerate(self.graph_output_names): if name not in self.symint_outputs: - call_str.append(f'{name} = output_tensor[{i}]') + if name in self.cpu_tensor: + call_str.append(f'{name} = output_tensor[{i}].cpu()') + else: + call_str.append(f'{name} = output_tensor[{i}]') else: call_str.extend([f'del {name}', f'{name} = int(output_tensor[{i}])']) + + # dealing with modified args passing back + output_convert = [f'args[{name[1]}].copy_({name[0]})' for name in self.assign_args] + call_str.extend(output_convert) + if precision_check: for i, name in enumerate(self.py_output_names): if name != 'None' and name not in self.args and name not in self.symint_outputs: call_str.append(f"{name}_cpu = aten_output[{i}]") - call_str.append(f"check_tensor({name}, {name}_cpu)") + call_str.append(f"check_tensor({name}.cpu(), {name}_cpu)") call_body.writelines(call_str) - del_args = [ - 'del ' + x for x in self.args if x not in self.py_output_names] + del_args = [f'del {x}' for x in self.args if x not in self.py_output_names] call_body.writelines(del_args) call_body.writeline("args.clear()") call_body.writeline(f"return ({', '.join(self.py_output_names)})") @@ -384,7 +489,7 @@ def remove_symint(self, cur): def gen_graph_json(self): self.parse_outputs() self.gen_build_options() - has_dynamic_shape = False if len(sym_to_inputs) == 0 else True + has_dynamic_shape = False if len(self.sym_in_args) == 0 and len(self.sym_to_inputs) == 0 else True graph = { "name": "graph", "input_names": self.graph_input_names, @@ -672,10 +777,10 @@ def DivNoNan(name, x1, x2): @staticmethod def Select(name, cond, x1, x2): - op = OP(name, "Select") + op = OP(name, "SelectV2") op.set_input("condition", cond) - op.set_input("x1", x1) - op.set_input("x2", x2) + op.set_input("then", x1) + op.set_input("else", x2) return op.to_node() @staticmethod @@ -765,12 +870,18 @@ def Squeeze(name, x, dim): @staticmethod def Identity(name, input, index): op = OP(name, "Identity") - if index is not None: + if index is not None and isinstance(index, int): op.set_input_with_index("x", input, index) else: op.set_input("x", input) return op.to_node() + @staticmethod + def IdentityInp(name, input, dst=None): + op = OP(name, "Identity") + op.set_input("x", input) + return op.to_node() + @staticmethod def Exp(name, x): op = OP(name, "Exp") @@ -878,7 +989,14 @@ def Fill(name, dims, value): return op.to_node() @staticmethod - def Cast(name, x, ascend_dtype): + def Cast(name, x, ascend_dtype, device=None): + cast_op = OP(name, "Cast") + cast_op.set_input("x", x) + cast_op.set_attr_int("dst_type", get_ascend_dtype_num(ascend_dtype)) + return cast_op.to_node() + + @staticmethod + def CastToCpu(name, x, ascend_dtype, device=None): cast_op = OP(name, "Cast") cast_op.set_input("x", x) cast_op.set_attr_int("dst_type", get_ascend_dtype_num(ascend_dtype)) @@ -904,7 +1022,7 @@ def BroadcastTo(name, x, shape): return broadcast_op.to_node() @staticmethod - def Empty(name, shape, dtype, layout, device): + def Empty(name, shape, dtype, layout=torch.strided, device='cpu'): dtype = get_ascend_dtype_num(get_ascend_dtype(dtype)) op = OP(name, "Empty") op.set_input("shape", shape) @@ -923,6 +1041,7 @@ def Sort(name, x, dim, descending): op.set_input("x", x) op.set_attr_int("axis", dim) op.set_attr_bool("descending", descending) + op.set_attr_int("_keep_dtype", 1) return op.to_node() @staticmethod @@ -933,6 +1052,7 @@ def TopK(name, x, k, dim, largest, sorted): op.set_attr_int("dim", dim) op.set_attr_bool("largest", largest) op.set_attr_bool("sorted", sorted) + op.set_attr_int("_keep_dtype", 1) return op.to_node() @staticmethod @@ -960,6 +1080,7 @@ def BatchMatMul(name, x1, x2, adj_x1: bool, adj_x2: bool): op.set_attr_bool("adj_x1", adj_x1) op.set_input("x2", x2) op.set_attr_bool("adj_x2", adj_x2) + op.set_attr_int("_keep_dtype", 1) return op.to_node() @staticmethod diff --git a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py index f702503bd..b57516e38 100644 --- a/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py +++ b/dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py @@ -283,30 +283,32 @@ def _prepare_input(self, images, dims): check_ret("acl.mdl.set_dataset_tensor_desc", ret) assert (dataset == self.input_dataset) - def _prepare_output(self, output_tensor, out_stride, out_storage_offset): + def _prepare_output(self, output_tensor, output_shape, out_stride, out_storage_offset): for i in range(self.num_outputs): item = torch.empty( self.output_dims[i], dtype=self.output_dtypes[i], device=dipu_device_str) - item = item.as_strided( - self.output_dims[i], out_stride[i], out_storage_offset[i]) + # TODO! add case judgement for stride info + # item = item.as_strided( + # self.output_dims[i], out_stride[i], out_storage_offset[i]) output_tensor.append(item) ret = acl.update_data_buffer( self.output_data_buffers[i], item.data_ptr(), self.output_size[i]) check_ret("acl.update_data_buffer", ret) - def _prepare_dynamic_output(self, output_tensor, out_stride, out_storage_offset): + def _prepare_dynamic_output(self, output_tensor, output_shape, out_stride, out_storage_offset): for i in range(self.num_outputs): tot_size = 1 - for elem in self.output_shape[i]: + for elem in output_shape[i]: tot_size *= elem dtype = acl.mdl.get_output_data_type(self.model_desc, i) tot_size *= acl.data_type_size(dtype) - self.output_dims[i] = self.output_shape[i] + self.output_dims[i] = output_shape[i] self.output_size[i] = tot_size item = torch.empty( self.output_dims[i], dtype=self.output_dtypes[i], device=dipu_device_str) - item = item.as_strided( - self.output_dims[i], out_stride[i], out_storage_offset[i]) + # TODO! add case judgement for stride info + # item = item.as_strided( + # self.output_dims[i], out_stride[i], out_storage_offset[i]) output_tensor.append(item) ret = acl.update_data_buffer( @@ -314,17 +316,15 @@ def _prepare_dynamic_output(self, output_tensor, out_stride, out_storage_offset) check_ret("acl.update_data_buffer", ret) def run(self, images, dims=None, output_shape=None, out_stride=None, out_storage_offset=None): - self.output_shape = output_shape assert len(images) > 0 input = [x.to(dipu_device_str) if isinstance(x, torch.Tensor) and x.device.type != dipu_device_str else x for x in images] self._prepare_input(input, dims) output = [] - if dims is not None: - assert self.output_shape is not None - self._prepare_dynamic_output(output) + if output_shape: + self._prepare_dynamic_output(output, output_shape, out_stride, out_storage_offset) else: - self._prepare_output(output, out_stride, out_storage_offset) + self._prepare_output(output, output_shape, out_stride, out_storage_offset) self.forward() self._destroy_databuffer() return output diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py index 76e356b2f..72e14f28b 100644 --- a/dicp/dicp/vendor/AscendGraph/conversion.py +++ b/dicp/dicp/vendor/AscendGraph/conversion.py @@ -63,19 +63,38 @@ def generate_digits_op(shapes): def generate_sym_int(elem): elem = elem.node.str() elems = elem.strip().split(' ') + + arg = None + # dynamic shape feature + if elems[0] in self.sym_in_args: + arg, idx = self.sym_in_args[elems[0]] + shape = self.get_proxy(ascend_op.Shape, (arg,)) + axis = self.get_proxy( + ascend_op.Const, ([0], torch.int32, [1])) + indice = self.get_proxy( + ascend_op.Const, ([idx], torch.int32, [1])) + gather = self.get_proxy( + ascend_op.GatherV2, (shape, indice, axis)) + if len(elems) > 1: assert len(elems) == 3 assert elems[2].isdigit() assert elems[1] == '+' or elems[1] == '-' const_op = self.get_proxy( ascend_op.Const, ([int(elems[2])], torch.int32, [1])) - args = (self.sym_to_inputs[elems[0]], const_op) + if arg is not None: + args = (gather, const_op) + else: + args = (self.sym_to_inputs[elems[0]], const_op) if elems[1] == '+': x_names.append(self.get_proxy(ascend_op.Add, args)) else: x_names.append(self.get_proxy(ascend_op.Sub, args)) else: - x_names.append(self.sym_to_inputs[elems[0]]) + if arg is not None: + x_names.append(gather) + else: + x_names.append(self.sym_to_inputs[elems[0]]) dims = [] for elem in shape: @@ -186,16 +205,19 @@ def add(self, x, y, alpha: Optional[Number] = 1): if y_dtype != out_dtype: y = self.get_proxy( ascend_op.Cast, (y, get_ascend_dtype(out_dtype)), {}) - return self.get_proxy(ascend_op.Add, (x, y), {}) + return self.get_proxy(ascend_op.AddV2, (x, y), {}) @register_conversion(torch.ops.aten.add.Scalar) def add_scalar(self, x, y): return self.add(x, y) @register_conversion(torch.ops.aten._to_copy.default) - def _to_copy(self, x, dtype=None, layout=torch.strided, device='cpu'): + def _to_copy(self, x, dtype=None, layout=torch.strided, device=None): if dtype: - return self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(dtype))) + if device == torch.device(type='cpu'): + return self.get_proxy(ascend_op.CastToCpu, (x, get_ascend_dtype(dtype))) + else: + return self.get_proxy(ascend_op.Cast, (x, get_ascend_dtype(dtype))) else: return self.get_proxy(ascend_op.Identity, (x, None)) @@ -240,7 +262,7 @@ def inge(self, x, y): if not isinstance(y, torch.fx.proxy.Proxy): assert isinstance(y, int) y = self.get_proxy(ascend_op.Const, ([y], torch.int32, [])) - return self.get_proxy(ascend_op.GreaterEqual, x, y) + return self.get_proxy(ascend_op.GreaterEqual, (x, y)) @register_conversion(aten.div) def div(self, x, y): @@ -391,22 +413,23 @@ def where(self, condition, x1, x2): @register_conversion(aten.arange.default) def arange(self, end, start=0, step=1, dtype=None, device='xpu', layout=None, pin_memory=False): - assert isinstance(start, str) or isinstance(start, int) - assert isinstance(end, str) or isinstance(end, int) - assert isinstance(step, str) or isinstance(step, int) - assert dtype is None or dtype == torch.int64 - if isinstance(start, str) and start.isdigit(): - start = int(start) - if isinstance(end, str) and end.isdigit(): - end = int(end) - if isinstance(step, str) and step.isdigit(): - step = int(step) - if isinstance(start, int): - start = self.get_proxy(ascend_op.Const, (int(start), torch.int64)) - if isinstance(end, int): - end = self.get_proxy(ascend_op.Const, (int(end), torch.int64)) - if isinstance(step, int): - step = self.get_proxy(ascend_op.Const, (int(step), torch.int64)) + out_dtype = fx_traceback.get_current_meta()['val'].dtype + assert isinstance(start, torch.fx.proxy.Proxy) or type(start) in [int, float] + assert isinstance(end, torch.fx.proxy.Proxy) or type(end) in [int, float] + assert isinstance(step, torch.fx.proxy.Proxy) or type(step) in [int, float] + + if not isinstance(start, torch.fx.proxy.Proxy): # scalar const + start = self.get_proxy(ascend_op.Const, (start, out_dtype)) + elif start.node.meta['val'] != out_dtype: # align tensor dtype + start = self.get_proxy(ascend_op.Cast, (start, get_ascend_dtype(out_dtype)), {}) + if not isinstance(end, torch.fx.proxy.Proxy): + end = self.get_proxy(ascend_op.Const, (end, out_dtype)) + elif end.node.meta['val'] != out_dtype: + end = self.get_proxy(ascend_op.Cast, (end, get_ascend_dtype(out_dtype)), {}) + if not isinstance(step, torch.fx.proxy.Proxy): + step = self.get_proxy(ascend_op.Const, (step, out_dtype)) + elif step.node.meta['val'] != out_dtype: + step = self.get_proxy(ascend_op.Cast, (step, get_ascend_dtype(out_dtype)), {}) return self.get_proxy(ascend_op.Range, (end, start, step)) @register_conversion(aten.arange.start) @@ -437,7 +460,7 @@ def masked_fill(self, x, mask, value): x_dtype = x.node.meta['val'].dtype const_dtype = torch.float32 if x_dtype == torch.float16 else x_dtype if str(value) == "-inf": - value = -3.4028235e38 + value = -3.4028234663852886e+38 mask_shape = list(mask.node.meta['val'].shape) value = self.get_param_proxy(value, const_dtype, mask_shape) if x_dtype == torch.float16: @@ -695,10 +718,13 @@ def expand(self, x, shape): return self.get_proxy(ascend_op.Identity, (x, None)) if x.node.meta['val'].dtype == torch.int64: x = self.get_proxy(ascend_op.Cast, (x, "INT32")) - shape = [dim.meta['val'] if hasattr( - dim, 'meta') else dim for dim in shape] - shape = self.get_shape_proxy(shape) - return self.get_proxy(ascend_op.ExpandD, (x, shape)) + shape = [dim.node.meta['val'] if hasattr( + dim, 'node') else dim for dim in shape] + if isinstance(shape, list) and symint_in_shape(shape): + preprocess_shape = self.process_dynamic_shape(shape) + return self.get_proxy(ascend_op.Expand, (x, preprocess_shape)) + else: + return self.get_proxy(ascend_op.ExpandD, (x, shape)) @register_conversion(torch.ops.aten.slice_backward.default) def slice_backward(self, grad, input_shape, dim, start, end, step): @@ -813,20 +839,29 @@ def convolution(self, input, weight, bias, stride, padding, @register_conversion(_operator.mul) def inmul(self, x, y): assert (not isinstance(y, torch.fx.proxy.Proxy)) - y = self.get_proxy(ascend_op.Const, ([y], torch.float32, [])) + y = self.get_proxy(ascend_op.Const, ([y], torch.int32, [])) return self.get_proxy(ascend_op.Mul, (x, y)) @register_conversion(torch.ops.aten.sym_size) def symsize(self, x, dim): dim = [dim] if not isinstance(dim, list) else dim + shape = self.get_proxy(ascend_op.Shape, (x,)) axis = self.get_proxy(ascend_op.Const, ([0], torch.int32, [1])) indices = self.get_proxy( ascend_op.Const, (dim, torch.int32, [len(dim)])) - return self.get_proxy(ascend_op.GatherV2, (x, indices, axis)) + return self.get_proxy(ascend_op.GatherV2, (shape, indices, axis)) @register_conversion(torch.ops.aten.mm.default) def mm(self, x, y): - return self.get_proxy(ascend_op.MatMul, (x, y, False, False)) + # TODO! MatMul not support fp32 input + # for higher precision in some cases + if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0: + x = self.get_proxy(ascend_op.Unsqueeze, (x, [0])) + y = self.get_proxy(ascend_op.Unsqueeze, (y, [0])) + mm = self.get_proxy(ascend_op.BatchMatMul, (x, y, False, False)) + return self.get_proxy(ascend_op.Squeeze, (mm, [0])) + else: + return self.get_proxy(ascend_op.MatMul, (x, y, False, False)) @register_conversion(aten.bmm.default) def bmm(self, x, y): @@ -884,7 +919,7 @@ def clone(self, a, memory_format=torch.contiguous_format): @register_conversion(torch.ops.aten.copy_) def copy_(self, dst, src): - return self.get_proxy(ascend_op.Identity, (src, None)) + return self.get_proxy(ascend_op.IdentityInp, (src, dst)) @register_conversion(torch.ops.aten.copy) def copy(self, dst, src): diff --git a/dicp/dicp/vendor/AscendGraph/opset_convert.py b/dicp/dicp/vendor/AscendGraph/opset_convert.py index c42c8b9d4..9913055e7 100644 --- a/dicp/dicp/vendor/AscendGraph/opset_convert.py +++ b/dicp/dicp/vendor/AscendGraph/opset_convert.py @@ -1,6 +1,6 @@ import torch from dicp.dynamo_bridge.op_transformer import BackendPatternMatcherTransformer -from dicp.vendor.AscendGraph.ascend_op import MatMul +from dicp.vendor.AscendGraph.ascend_op import MatMul, CastToCpu, IdentityInp from dicp.vendor.AscendGraph.conversion import AtenToAscendTransformer from dicp.vendor.AscendGraph.pattern_replacement import ( ascend_pattern_matcher, @@ -22,13 +22,67 @@ def transform(self, gm: torch.fx.graph_module): return gm +class OutputMarkPass: + def __init__(self): + self.assign_args = [] + self.cpu_tensor = [] + + def transform(self, gm: torch.fx.graph_module): + # dynamic shape feature + input_names = [] + for n in gm.graph.nodes: + if n.op == 'placeholder': + input_names.append(n.name) + + for n in gm.graph.nodes: + if n.op != 'call_function': + continue + if type(n.target) == CastToCpu: + self.cpu_tensor.append(n.name) + elif type(n.target) == IdentityInp: + if len(n.args) == 2 and n.args[1] is not None and str(n.args[1]) in input_names: + self.assign_args.append((n.name, input_names.index(str(n.args[1])))) + else: + raise RuntimeError("Op inner copy_ error!") + + for n in gm.graph.nodes: + if n.op == 'call_function': + prop = {} + if n.name in self.cpu_tensor: + prop.update({'cpu_tensor' : n.name}) + if len(self.assign_args) > 0 and n.name in list(zip(*self.assign_args))[0]: + idx = list(zip(*self.assign_args))[0].index(n.name) + prop.update({'assign_args' : (self.assign_args[idx][0], self.assign_args[idx][1])}) + n.meta['prop'] = prop + return gm + + +def symint_in_inputs(nodes): + # dynamic shape feature + for node in nodes: + if node.op == 'placeholder': + if hasattr(node, 'meta'): + node = node.meta['val'] + if isinstance(node, torch.SymInt): + return True + if hasattr(node, 'shape'): + for dim in node.shape: + if isinstance(dim, torch.SymInt): + return True + return False + def ascendgraph_opset_convert( gm: torch.fx.GraphModule, ): gm = BackendPatternMatcherTransformer( ascend_pattern_matcher, aten_patterns_cls_list).transform(gm) gm = AtenToAscendTransformer(gm).transform() - gm = BackendPatternMatcherTransformer( - ascend_pattern_matcher, ascend_patterns_cls_list).transform(gm) + + # For bug in pytorch + # Avoid for dynamic shape + if not symint_in_inputs(list(gm.graph.nodes)): + gm = BackendPatternMatcherTransformer( + ascend_pattern_matcher, ascend_patterns_cls_list).transform(gm) + gm = OutputMarkPass().transform(gm) # gm = ArgsTransDataPass().transform(gm) return gm