Commit

Merge pull request #418 from DeepLink-org/zcx/dynamic_shape
refact dynamic_shape
jinminxi104 authored Nov 14, 2023
2 parents 6d835b3 + edc5b6b commit 30c4709
Showing 1 changed file with 39 additions and 96 deletions.
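The refactor removes the symint_in_shape / process_dynamic_shape branching that was repeated at each call site in this file and routes it through two new helpers, get_shape_proxy and get_param_proxy, added in the first hunk of the diff below. As a quick orientation, here is a toy, self-contained sketch of the dispatch rule get_shape_proxy implements; the stand-in class, the string "s0" used to mark a symbolic dimension, and the tuple return values are illustrative only, not the real torch.fx or dicp types:

# Toy model of the dispatch in get_shape_proxy; stand-ins only, not dicp code.
class ProxyStandIn:
    # Stands in for torch.fx.proxy.Proxy / FakeTensor (already a graph value).
    pass

def symint_in_shape_standin(shape):
    # Stands in for codegen.utils.symint_in_shape; here a str like "s0" marks a SymInt.
    return any(isinstance(dim, str) for dim in shape)

def get_shape_proxy_sketch(shape):
    if isinstance(shape, ProxyStandIn):
        return shape                                      # pass through: already a graph value
    elif isinstance(shape, list) and symint_in_shape_standin(shape):
        return ("process_dynamic_shape", shape)           # dynamic dims: build shape ops
    else:
        return ("Const", shape, "int32", [len(shape)])    # fully static: a single Const op

print(get_shape_proxy_sketch([2, 3, 4]))      # static path -> Const
print(get_shape_proxy_sketch([2, "s0", 4]))   # dynamic path -> process_dynamic_shape

get_param_proxy builds on the same idea: when given a plain Python scalar or list it first wraps it in a Const op and then broadcasts it (BroadcastTo) to the shape obtained via get_shape_proxy, so converters such as mul_scalar, div, eq, lt, masked_fill and maximum collapse to a single call.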
135 changes: 39 additions & 96 deletions dicp/dicp/vendor/AscendGraph/conversion.py
@@ -10,6 +10,7 @@
 )
 import numpy as np
 import torch.fx.traceback as fx_traceback
+from torch._subclasses import FakeTensor
 import dicp.vendor.AscendGraph.ascend_op as ascend_op
 from dicp.vendor.AscendGraph.codegen.utils import (
     symint_in_shape,
@@ -95,14 +96,29 @@ def generate_sym_int(elem):
         # concat all ops
         return self.get_proxy(ascend_op.ConcatD, (x_names, 0))

+    def get_shape_proxy(self, shape):
+        if isinstance(shape, torch.fx.proxy.Proxy) or isinstance(shape, FakeTensor):
+            return shape
+        elif isinstance(shape, list) and symint_in_shape(shape):
+            return self.process_dynamic_shape(shape)
+        else:
+            return self.get_proxy(
+                ascend_op.Const, (shape, torch.int32, [len(shape)]))
+
+    def get_param_proxy(self, param, type, target_shape):
+        if not isinstance(param, torch.fx.proxy.Proxy) and not isinstance(param, FakeTensor):
+            param = param if isinstance(param, list) else [param]
+            param = self.get_proxy(
+                ascend_op.Const, (param, type, [len(param)]))
+        shape_op = self.get_shape_proxy(target_shape)
+        param = self.get_proxy(ascend_op.BroadcastTo, (param, shape_op))
+        return param
+
     def mul_scalar(self, x, y):
         out_dtype = fx_traceback.get_current_meta()['val'].dtype
         const_dtype = torch.float32 if out_dtype == torch.float16 else out_dtype
-        y_op = self.get_proxy(ascend_op.Const, ([y], const_dtype, [1]))
         y_shape = list(x.node.meta['val'].shape)
-        if symint_in_shape(y_shape):
-            y_shape_op = self.process_dynamic_shape(y_shape)
-            y_op = self.get_proxy(ascend_op.BroadcastTo, (y_op, y_shape_op))
+        y_op = self.get_param_proxy(y, const_dtype, y_shape)
         if out_dtype == torch.float16:
             y_op = self.get_proxy(ascend_op.Cast, (y_op, "FLOAT16"))
         return self.get_proxy(ascend_op.Mul, (x, y_op))
@@ -140,13 +156,9 @@ def mul(self, x, y):
         y_dtype = y.node.meta['val'].dtype
         # handling with broadcasting cases
         if np.prod(x_shape) < np.prod(y_shape):
-            if symint_in_shape(y_shape):
-                y_shape_op = self.process_dynamic_shape(y_shape)
-                x = self.get_proxy(ascend_op.BroadcastTo, (x, y_shape_op))
+            x = self.get_param_proxy(x, None, y_shape)
         elif np.prod(x_shape) > np.prod(y_shape):
-            if symint_in_shape(x_shape):
-                x_shape_op = self.process_dynamic_shape(x_shape)
-                y = self.get_proxy(ascend_op.BroadcastTo, (y, x_shape_op))
+            y = self.get_param_proxy(y, None, x_shape)
         if x_dtype != out_dtype:
             x = self.get_proxy(
                 ascend_op.Cast, (x, get_ascend_dtype(out_dtype)), {})
@@ -237,11 +249,8 @@ def div(self, x, y):
         assert y != 0
         out_dtype = fx_traceback.get_current_meta()['val'].dtype
         const_dtype = torch.float32 if out_dtype == torch.float16 else out_dtype
-        y_op = self.get_proxy(ascend_op.Const, ([y], const_dtype, []))
         y_shape = list(x.node.meta['val'].shape)
-        if symint_in_shape(y_shape):
-            y_shape_op = self.process_dynamic_shape(y_shape)
-            y_op = self.get_proxy(ascend_op.BroadcastTo, (y_op, y_shape_op))
+        y_op = self.get_param_proxy(y, const_dtype, y_shape)
         if out_dtype == torch.float16:
             y_op = self.get_proxy(ascend_op.Cast, (y_op, "FLOAT16"), {})
         return self.get_proxy(ascend_op.Div, (x, y_op), {})
@@ -258,17 +267,8 @@ def slice(self, x, dim=0, start=None, end=None, step=1):
         assert start >= 0 and start < x_shape[dim]
         offset = [0] * len(x_shape)
         offset[dim] = start
-        if symint_in_shape(offset):
-            offset = self.process_dynamic_shape(offset)
-        else:
-            offset = self.get_proxy(
-                ascend_op.Const, (offset, torch.int32, [len(offset)]))
-        size = None
-        if symint_in_shape(y_shape):
-            size = self.process_dynamic_shape(y_shape)
-        else:
-            size = self.get_proxy(
-                ascend_op.Const, (y_shape, torch.int32, [len(y_shape)]))
+        offset = self.get_shape_proxy(offset)
+        size = self.get_shape_proxy(y_shape)
         return self.get_proxy(ascend_op.Slice, (x, offset, size))

     @register_conversion(aten.bernoulli.p)
@@ -319,23 +319,10 @@ def select(self, x, dim, index):
                 size.append(v - offset[i])
             else:
                 size.append(end - offset[i])
-
-        if symint_in_shape(offset):
-            offset = self.process_dynamic_shape(offset)
-        else:
-            offset = self.get_proxy(
-                ascend_op.Const, (offset, torch.int32, [len(offset)]))
-        if symint_in_shape(size):
-            size = self.process_dynamic_shape(size)
-        else:
-            size = self.get_proxy(
-                ascend_op.Const, (size, torch.int32, [len(size)]))
+        offset = self.get_shape_proxy(offset)
+        size = self.get_shape_proxy(size)
         slice = self.get_proxy(ascend_op.Slice, (x, offset, size))
-        if symint_in_shape(y_shape):
-            y_shape = self.process_dynamic_shape(y_shape)
-        else:
-            y_shape = self.get_proxy(
-                ascend_op.Const, (y_shape, torch.int32, [len(y_shape)]))
+        y_shape = self.get_shape_proxy(y_shape)
         return self.get_proxy(ascend_op.Reshape, (slice, y_shape))

     @register_conversion(_operator.add)
@@ -382,10 +369,7 @@ def view(self, x, size):
                 raise RuntimeError(
                     "cannot handle with both negative and symint!")
             shape = real_shape
-        if symint_in_shape(shape):
-            shape = self.process_dynamic_shape(shape)
-        else:
-            shape = self.get_proxy(ascend_op.Const, (shape, torch.int32))
+        shape = self.get_shape_proxy(shape)
         if x.node.meta["val"].dtype == torch.complex64:
             real = self.get_proxy(ascend_op.Identity, (x, 0))
             imag = self.get_proxy(ascend_op.Identity, (x, 1))
@@ -434,28 +418,16 @@ def eq(self, a, b):
         if not isinstance(b, torch.fx.proxy.Proxy):
             assert isinstance(b, int)
             b_shape = list(a.node.meta['val'].shape)
-            scalar_op = self.get_proxy(ascend_op.Const, (b, torch.int64))
-            if symint_in_shape(b_shape):
-                b_shape = self.process_dynamic_shape(b_shape)
-            else:
-                b_shape = self.get_proxy(
-                    ascend_op.Const, (b_shape, torch.int32))
-            b = self.get_proxy(ascend_op.BroadcastTo, (scalar_op, b_shape))
+            b = self.get_param_proxy(b, torch.int64, b_shape)
         return self.get_proxy(ascend_op.Equal, (a, b))

     @register_conversion([aten.lt.Scalar, aten.lt.Tensor])
     def lt(self, x, y):
         if not isinstance(y, torch.fx.proxy.Proxy):
             x_dtype = x.node.meta['val'].dtype
             const_dtype = torch.float32 if x_dtype == torch.float16 else x_dtype
-            scalar_op = self.get_proxy(ascend_op.Const, (y, const_dtype))
             y_shape = list(x.node.meta['val'].shape)
-            if symint_in_shape(y_shape):
-                y_shape_op = self.process_dynamic_shape(y_shape)
-            else:
-                y_shape_op = self.get_proxy(
-                    ascend_op.Const, (y_shape, torch.int32))
-            y = self.get_proxy(ascend_op.BroadcastTo, (scalar_op, y_shape_op))
+            y = self.get_param_proxy(y, const_dtype, y_shape)
             if x_dtype == torch.float16:
                 y = self.get_proxy(ascend_op.Cast, (y, "FLOAT16"))
         return self.get_proxy(ascend_op.Less, (x, y))
@@ -466,13 +438,8 @@ def masked_fill(self, x, mask, value):
         const_dtype = torch.float32 if x_dtype == torch.float16 else x_dtype
         if str(value) == "-inf":
             value = -3.4028235e38
-        scalar_op = self.get_proxy(ascend_op.Const, (value, const_dtype))
         mask_shape = list(mask.node.meta['val'].shape)
-        if symint_in_shape(mask_shape):
-            value = self.process_dynamic_shape(mask_shape)
-        else:
-            value = self.get_proxy(ascend_op.Const, (mask_shape, torch.int32))
-        value = self.get_proxy(ascend_op.BroadcastTo, (scalar_op, value))
+        value = self.get_param_proxy(value, const_dtype, mask_shape)
         if x_dtype == torch.float16:
             value = self.get_proxy(ascend_op.Cast, (value, "FLOAT16"))
         return self.get_proxy(ascend_op.MaskedFill, (x, mask, value))
@@ -482,12 +449,7 @@ def scatter(self, var, dim, index, value):
         assert isinstance(dim, int)
         index_shape = list(index.node.meta['val'].shape)
         if isinstance(value, torch.fx.proxy.Proxy):
-            preprocess = None
-            if symint_in_shape(index_shape):
-                preprocess = self.process_dynamic_shape(index_shape)
-            else:
-                preprocess = self.get_proxy(
-                    ascend_op.Const, (index_shape, torch.int32))
+            preprocess = self.get_shape_proxy(index_shape)
             value = self.get_proxy(ascend_op.Reshape, (value, preprocess))
         else:
             out_dtype = fx_traceback.get_current_meta()['val'].dtype
@@ -558,11 +520,7 @@ def full(self, dims, value, dtype=torch.float32, layout=torch.strided,
             dim.node, 'meta') else dim for dim in dims]
         if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'):
             value = value.node.meta['val']
-        if symint_in_shape(dims):
-            dims = self.process_dynamic_shape(dims)
-        else:
-            dims = self.get_proxy(
-                ascend_op.Const, (dims, torch.int32, [len(dims)]))
+        dims = self.get_shape_proxy(dims)
         value = self.get_proxy(ascend_op.Const, ([value], torch_dtype, []))
         return self.get_proxy(ascend_op.Fill, (dims, value))
@@ -739,11 +697,8 @@ def expand(self, x, shape):
             x = self.get_proxy(ascend_op.Cast, (x, "INT32"))
         shape = [dim.meta['val'] if hasattr(
             dim, 'meta') else dim for dim in shape]
-        if isinstance(shape, list) and symint_in_shape(shape):
-            preprocess_shape = self.process_dynamic_shape(shape)
-            return self.get_proxy(ascend_op.Expand, (x, preprocess_shape))
-        else:
-            return self.get_proxy(ascend_op.ExpandD, (x, shape))
+        shape = self.get_shape_proxy(shape)
+        return self.get_proxy(ascend_op.ExpandD, (x, shape))

     @register_conversion(torch.ops.aten.slice_backward.default)
     def slice_backward(self, grad, input_shape, dim, start, end, step):
@@ -795,12 +750,7 @@ def maximum(self, a, b):
         a_shape = list(a.node.meta['val'].shape)
         b_shape = list(b.node.meta['val'].shape)
         if np.prod(b_shape) < np.prod(a_shape):
-            if symint_in_shape(a_shape):
-                a_shape = self.process_dynamic_shape(a_shape)
-            else:
-                a_shape = self.get_proxy(
-                    ascend_op.Const, (a_shape, torch.int32, [len(a_shape)]))
-            b = self.get_proxy(ascend_op.BroadcastTo, (b, a_shape))
+            b = self.get_param_proxy(b, None, a_shape)
         if a.node.meta['val'].dtype == torch.float16:
             b = self.get_proxy(ascend_op.Cast, (b, "FLOAT16"))
         return self.get_proxy(ascend_op.Maximum, (a, b))
@@ -813,11 +763,7 @@ def common_process_scalar(self, x, y):
             need_cast = True
         y = self.get_proxy(ascend_op.Const, (y, x_dtype))
         y_shape = list(x.node.meta['val'].shape)
-        if symint_in_shape(y_shape):
-            shape_preprocess = self.process_dynamic_shape(y_shape)
-        else:
-            shape_preprocess = self.get_proxy(
-                ascend_op.Const, (y_shape, torch.int32))
+        shape_preprocess = self.get_shape_proxy(y_shape)
         y = self.get_proxy(ascend_op.BroadcastTo, (y, shape_preprocess))
         if need_cast:
             y = self.get_proxy(ascend_op.Cast, (y, "FLOAT16"))
@@ -844,10 +790,7 @@ def transpose(self, input, dim0, dim1):
         perm = [num for num in range(rank)]
         perm[dim0] = dim1
         perm[dim1] = dim0
-        if symint_in_shape(perm):
-            perm = self.process_dynamic_shape(perm)
-        else:
-            perm = self.get_proxy(ascend_op.Const, (perm, torch.int32))
+        perm = self.get_shape_proxy(perm)
         return self.get_proxy(ascend_op.Transpose, (input, perm))

     @register_conversion(torch.ops.aten.convolution)
