From edc5b6b9255fff769b22d4e865f061118cf195ef Mon Sep 17 00:00:00 2001
From: zhaochaoxing
Date: Mon, 13 Nov 2023 13:44:49 +0000
Subject: [PATCH] refactor dynamic_shape

---
 dicp/dicp/vendor/AscendGraph/conversion.py | 135 ++++++---------------
 1 file changed, 39 insertions(+), 96 deletions(-)

diff --git a/dicp/dicp/vendor/AscendGraph/conversion.py b/dicp/dicp/vendor/AscendGraph/conversion.py
index 1d4214375..5fed2ed9b 100644
--- a/dicp/dicp/vendor/AscendGraph/conversion.py
+++ b/dicp/dicp/vendor/AscendGraph/conversion.py
@@ -10,6 +10,7 @@
 )
 import numpy as np
 import torch.fx.traceback as fx_traceback
+from torch._subclasses import FakeTensor
 import dicp.vendor.AscendGraph.ascend_op as ascend_op
 from dicp.vendor.AscendGraph.codegen.utils import (
     symint_in_shape,
@@ -95,14 +96,29 @@ def generate_sym_int(elem):
         # concat all ops
         return self.get_proxy(ascend_op.ConcatD, (x_names, 0))

+    def get_shape_proxy(self, shape):
+        if isinstance(shape, torch.fx.proxy.Proxy) or isinstance(shape, FakeTensor):
+            return shape
+        elif isinstance(shape, list) and symint_in_shape(shape):
+            return self.process_dynamic_shape(shape)
+        else:
+            return self.get_proxy(
+                ascend_op.Const, (shape, torch.int32, [len(shape)]))
+
+    def get_param_proxy(self, param, type, target_shape):
+        if not isinstance(param, torch.fx.proxy.Proxy) and not isinstance(param, FakeTensor):
+            param = param if isinstance(param, list) else [param]
+            param = self.get_proxy(
+                ascend_op.Const, (param, type, [len(param)]))
+        shape_op = self.get_shape_proxy(target_shape)
+        param = self.get_proxy(ascend_op.BroadcastTo, (param, shape_op))
+        return param
+
     def mul_scalar(self, x, y):
         out_dtype = fx_traceback.get_current_meta()['val'].dtype
         const_dtype = torch.float32 if out_dtype == torch.float16 else out_dtype
-        y_op = self.get_proxy(ascend_op.Const, ([y], const_dtype, [1]))
         y_shape = list(x.node.meta['val'].shape)
-        if symint_in_shape(y_shape):
-            y_shape_op = self.process_dynamic_shape(y_shape)
-            y_op = self.get_proxy(ascend_op.BroadcastTo, (y_op, y_shape_op))
+        y_op = self.get_param_proxy(y, const_dtype, y_shape)
         if out_dtype == torch.float16:
             y_op = self.get_proxy(ascend_op.Cast, (y_op, "FLOAT16"))
         return self.get_proxy(ascend_op.Mul, (x, y_op))
@@ -140,13 +156,9 @@ def mul(self, x, y):
         y_dtype = y.node.meta['val'].dtype
         # handling with broadcasting cases
         if np.prod(x_shape) < np.prod(y_shape):
-            if symint_in_shape(y_shape):
-                y_shape_op = self.process_dynamic_shape(y_shape)
-                x = self.get_proxy(ascend_op.BroadcastTo, (x, y_shape_op))
+            x = self.get_param_proxy(x, None, y_shape)
         elif np.prod(x_shape) > np.prod(y_shape):
-            if symint_in_shape(x_shape):
-                x_shape_op = self.process_dynamic_shape(x_shape)
-                y = self.get_proxy(ascend_op.BroadcastTo, (y, x_shape_op))
+            y = self.get_param_proxy(y, None, x_shape)
         if x_dtype != out_dtype:
             x = self.get_proxy(
                 ascend_op.Cast, (x, get_ascend_dtype(out_dtype)), {})
@@ -237,11 +249,8 @@ def div(self, x, y):
         assert y != 0
         out_dtype = fx_traceback.get_current_meta()['val'].dtype
         const_dtype = torch.float32 if out_dtype == torch.float16 else out_dtype
-        y_op = self.get_proxy(ascend_op.Const, ([y], const_dtype, []))
         y_shape = list(x.node.meta['val'].shape)
-        if symint_in_shape(y_shape):
-            y_shape_op = self.process_dynamic_shape(y_shape)
-            y_op = self.get_proxy(ascend_op.BroadcastTo, (y_op, y_shape_op))
+        y_op = self.get_param_proxy(y, const_dtype, y_shape)
         if out_dtype == torch.float16:
             y_op = self.get_proxy(ascend_op.Cast, (y_op, "FLOAT16"), {})
         return self.get_proxy(ascend_op.Div, (x, y_op), {})
@@ -258,17 +267,8 @@ def slice(self, x, dim=0, start=None, end=None, step=1):
         assert start >= 0 and start < x_shape[dim]
         offset = [0] * len(x_shape)
         offset[dim] = start
-        if symint_in_shape(offset):
-            offset = self.process_dynamic_shape(offset)
-        else:
-            offset = self.get_proxy(
-                ascend_op.Const, (offset, torch.int32, [len(offset)]))
-        size = None
-        if symint_in_shape(y_shape):
-            size = self.process_dynamic_shape(y_shape)
-        else:
-            size = self.get_proxy(
-                ascend_op.Const, (y_shape, torch.int32, [len(y_shape)]))
+        offset = self.get_shape_proxy(offset)
+        size = self.get_shape_proxy(y_shape)
         return self.get_proxy(ascend_op.Slice, (x, offset, size))

     @register_conversion(aten.bernoulli.p)
@@ -319,23 +319,10 @@ def select(self, x, dim, index):
                 size.append(v - offset[i])
             else:
                 size.append(end - offset[i])
-
-        if symint_in_shape(offset):
-            offset = self.process_dynamic_shape(offset)
-        else:
-            offset = self.get_proxy(
-                ascend_op.Const, (offset, torch.int32, [len(offset)]))
-        if symint_in_shape(size):
-            size = self.process_dynamic_shape(size)
-        else:
-            size = self.get_proxy(
-                ascend_op.Const, (size, torch.int32, [len(size)]))
+        offset = self.get_shape_proxy(offset)
+        size = self.get_shape_proxy(size)
         slice = self.get_proxy(ascend_op.Slice, (x, offset, size))
-        if symint_in_shape(y_shape):
-            y_shape = self.process_dynamic_shape(y_shape)
-        else:
-            y_shape = self.get_proxy(
-                ascend_op.Const, (y_shape, torch.int32, [len(y_shape)]))
+        y_shape = self.get_shape_proxy(y_shape)
         return self.get_proxy(ascend_op.Reshape, (slice, y_shape))

     @register_conversion(_operator.add)
@@ -382,10 +369,7 @@ def view(self, x, size):
                 raise RuntimeError(
                     "cannot handle with both negative and symint!")
             shape = real_shape
-        if symint_in_shape(shape):
-            shape = self.process_dynamic_shape(shape)
-        else:
-            shape = self.get_proxy(ascend_op.Const, (shape, torch.int32))
+        shape = self.get_shape_proxy(shape)
         if x.node.meta["val"].dtype == torch.complex64:
             real = self.get_proxy(ascend_op.Identity, (x, 0))
             imag = self.get_proxy(ascend_op.Identity, (x, 1))
@@ -434,13 +418,7 @@ def eq(self, a, b):
         if not isinstance(b, torch.fx.proxy.Proxy):
             assert isinstance(b, int)
             b_shape = list(a.node.meta['val'].shape)
-            scalar_op = self.get_proxy(ascend_op.Const, (b, torch.int64))
-            if symint_in_shape(b_shape):
-                b_shape = self.process_dynamic_shape(b_shape)
-            else:
-                b_shape = self.get_proxy(
-                    ascend_op.Const, (b_shape, torch.int32))
-            b = self.get_proxy(ascend_op.BroadcastTo, (scalar_op, b_shape))
+            b = self.get_param_proxy(b, torch.int64, b_shape)
         return self.get_proxy(ascend_op.Equal, (a, b))

     @register_conversion([aten.lt.Scalar, aten.lt.Tensor])
@@ -448,14 +426,8 @@ def lt(self, x, y):
         if not isinstance(y, torch.fx.proxy.Proxy):
             x_dtype = x.node.meta['val'].dtype
             const_dtype = torch.float32 if x_dtype == torch.float16 else x_dtype
-            scalar_op = self.get_proxy(ascend_op.Const, (y, const_dtype))
             y_shape = list(x.node.meta['val'].shape)
-            if symint_in_shape(y_shape):
-                y_shape_op = self.process_dynamic_shape(y_shape)
-            else:
-                y_shape_op = self.get_proxy(
-                    ascend_op.Const, (y_shape, torch.int32))
-            y = self.get_proxy(ascend_op.BroadcastTo, (scalar_op, y_shape_op))
+            y = self.get_param_proxy(y, const_dtype, y_shape)
             if x_dtype == torch.float16:
                 y = self.get_proxy(ascend_op.Cast, (y, "FLOAT16"))
         return self.get_proxy(ascend_op.Less, (x, y))
@@ -466,13 +438,8 @@ def masked_fill(self, x, mask, value):
         const_dtype = torch.float32 if x_dtype == torch.float16 else x_dtype
         if str(value) == "-inf":
             value = -3.4028235e38
-        scalar_op = self.get_proxy(ascend_op.Const, (value, const_dtype))
         mask_shape = list(mask.node.meta['val'].shape)
-        if symint_in_shape(mask_shape):
-            value = self.process_dynamic_shape(mask_shape)
-        else:
-            value = self.get_proxy(ascend_op.Const, (mask_shape, torch.int32))
-        value = self.get_proxy(ascend_op.BroadcastTo, (scalar_op, value))
+        value = self.get_param_proxy(value, const_dtype, mask_shape)
         if x_dtype == torch.float16:
             value = self.get_proxy(ascend_op.Cast, (value, "FLOAT16"))
         return self.get_proxy(ascend_op.MaskedFill, (x, mask, value))
@@ -482,12 +449,7 @@ def scatter(self, var, dim, index, value):
         assert isinstance(dim, int)
         index_shape = list(index.node.meta['val'].shape)
         if isinstance(value, torch.fx.proxy.Proxy):
-            preprocess = None
-            if symint_in_shape(index_shape):
-                preprocess = self.process_dynamic_shape(index_shape)
-            else:
-                preprocess = self.get_proxy(
-                    ascend_op.Const, (index_shape, torch.int32))
+            preprocess = self.get_shape_proxy(index_shape)
             value = self.get_proxy(ascend_op.Reshape, (value, preprocess))
         else:
             out_dtype = fx_traceback.get_current_meta()['val'].dtype
@@ -558,11 +520,7 @@ def full(self, dims, value, dtype=torch.float32, layout=torch.strided,
             dim.node, 'meta') else dim for dim in dims]
         if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'):
             value = value.node.meta['val']
-        if symint_in_shape(dims):
-            dims = self.process_dynamic_shape(dims)
-        else:
-            dims = self.get_proxy(
-                ascend_op.Const, (dims, torch.int32, [len(dims)]))
+        dims = self.get_shape_proxy(dims)
         value = self.get_proxy(ascend_op.Const, ([value], torch_dtype, []))
         return self.get_proxy(ascend_op.Fill, (dims, value))

@@ -739,11 +697,8 @@ def expand(self, x, shape):
             x = self.get_proxy(ascend_op.Cast, (x, "INT32"))
         shape = [dim.meta['val'] if hasattr(
             dim, 'meta') else dim for dim in shape]
-        if isinstance(shape, list) and symint_in_shape(shape):
-            preprocess_shape = self.process_dynamic_shape(shape)
-            return self.get_proxy(ascend_op.Expand, (x, preprocess_shape))
-        else:
-            return self.get_proxy(ascend_op.ExpandD, (x, shape))
+        shape = self.get_shape_proxy(shape)
+        return self.get_proxy(ascend_op.ExpandD, (x, shape))

     @register_conversion(torch.ops.aten.slice_backward.default)
     def slice_backward(self, grad, input_shape, dim, start, end, step):
@@ -795,12 +750,7 @@ def maximum(self, a, b):
         a_shape = list(a.node.meta['val'].shape)
         b_shape = list(b.node.meta['val'].shape)
         if np.prod(b_shape) < np.prod(a_shape):
-            if symint_in_shape(a_shape):
-                a_shape = self.process_dynamic_shape(a_shape)
-            else:
-                a_shape = self.get_proxy(
-                    ascend_op.Const, (a_shape, torch.int32, [len(a_shape)]))
-            b = self.get_proxy(ascend_op.BroadcastTo, (b, a_shape))
+            b = self.get_param_proxy(b, None, a_shape)
         if a.node.meta['val'].dtype == torch.float16:
             b = self.get_proxy(ascend_op.Cast, (b, "FLOAT16"))
         return self.get_proxy(ascend_op.Maximum, (a, b))
@@ -813,11 +763,7 @@ def common_process_scalar(self, x, y):
             need_cast = True
             y = self.get_proxy(ascend_op.Const, (y, x_dtype))
         y_shape = list(x.node.meta['val'].shape)
-        if symint_in_shape(y_shape):
-            shape_preprocess = self.process_dynamic_shape(y_shape)
-        else:
-            shape_preprocess = self.get_proxy(
-                ascend_op.Const, (y_shape, torch.int32))
+        shape_preprocess = self.get_shape_proxy(y_shape)
         y = self.get_proxy(ascend_op.BroadcastTo, (y, shape_preprocess))
         if need_cast:
             y = self.get_proxy(ascend_op.Cast, (y, "FLOAT16"))
@@ -844,10 +790,7 @@ def transpose(self, input, dim0, dim1):
         perm = [num for num in range(rank)]
         perm[dim0] = dim1
         perm[dim1] = dim0
-        if symint_in_shape(perm):
-            perm = self.process_dynamic_shape(perm)
-        else:
-            perm = self.get_proxy(ascend_op.Const, (perm, torch.int32))
+        perm = self.get_shape_proxy(perm)
         return self.get_proxy(ascend_op.Transpose, (input, perm))

     @register_conversion(torch.ops.aten.convolution)
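
Note for reviewers (not part of the patch): every hunk above funnels shape and scalar lowering through the two new helpers, get_shape_proxy and get_param_proxy, instead of repeating the symint_in_shape branch in each converter. The following is a minimal, self-contained sketch of that dispatch rule. The FakeProxy class and the process_dynamic_shape / const_op / broadcast_to functions are hypothetical stand-ins for torch.fx proxies and the real ascend_op graph builders, used only so the example runs outside DICP.

# Illustrative sketch only: stand-ins, not the real ascend_op / DICP API.
import torch


class FakeProxy:
    """Hypothetical stand-in for torch.fx.proxy.Proxy / FakeTensor values."""

    def __init__(self, label):
        self.label = label

    def __repr__(self):
        return f"Proxy({self.label})"


def symint_in_shape(shape):
    # Same idea as the helper imported in conversion.py: any non-int entry
    # (e.g. a torch.SymInt) marks the shape as dynamic.
    return any(not isinstance(dim, int) for dim in shape)


def process_dynamic_shape(shape):
    # Stand-in for the per-dimension lowering of symbolic shapes.
    return FakeProxy(f"dynamic{shape}")


def const_op(values, dtype, dims):
    # Stand-in for ascend_op.Const.
    return FakeProxy(f"Const({values}, {dtype}, {dims})")


def broadcast_to(param, shape_proxy):
    # Stand-in for ascend_op.BroadcastTo.
    return FakeProxy(f"BroadcastTo({param}, {shape_proxy})")


def get_shape_proxy(shape):
    # Proxies (and FakeTensors in the real code) pass through unchanged;
    # shapes containing SymInts are lowered dynamically; plain int lists
    # become a 1-D int32 Const.
    if isinstance(shape, FakeProxy):
        return shape
    if isinstance(shape, list) and symint_in_shape(shape):
        return process_dynamic_shape(shape)
    return const_op(shape, torch.int32, [len(shape)])


def get_param_proxy(param, dtype, target_shape):
    # Scalars and lists are first materialized as a Const, then every input
    # is broadcast to the (possibly dynamic) target shape.
    if not isinstance(param, FakeProxy):
        param = param if isinstance(param, list) else [param]
        param = const_op(param, dtype, [len(param)])
    return broadcast_to(param, get_shape_proxy(target_shape))


if __name__ == "__main__":
    # Static shape -> Const; scalar parameter -> Const + BroadcastTo.
    print(get_shape_proxy([2, 3, 4]))
    print(get_param_proxy(1.0, torch.float32, [2, 3, 4]))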