refact dynamic_shape #418

Merged: 1 commit merged on Nov 14, 2023
Changes from all commits
135 changes: 39 additions & 96 deletions dicp/dicp/vendor/AscendGraph/conversion.py
@@ -10,6 +10,7 @@
)
import numpy as np
import torch.fx.traceback as fx_traceback
from torch._subclasses import FakeTensor
import dicp.vendor.AscendGraph.ascend_op as ascend_op
from dicp.vendor.AscendGraph.codegen.utils import (
symint_in_shape,
@@ -95,14 +96,29 @@ def generate_sym_int(elem):
# concat all ops
return self.get_proxy(ascend_op.ConcatD, (x_names, 0))

def get_shape_proxy(self, shape):
if isinstance(shape, torch.fx.proxy.Proxy) or isinstance(shape, FakeTensor):
return shape
elif isinstance(shape, list) and symint_in_shape(shape):
return self.process_dynamic_shape(shape)
else:
return self.get_proxy(
ascend_op.Const, (shape, torch.int32, [len(shape)]))

def get_param_proxy(self, param, type, target_shape):
if not isinstance(param, torch.fx.proxy.Proxy) and not isinstance(param, FakeTensor):
param = param if isinstance(param, list) else [param]
param = self.get_proxy(
ascend_op.Const, (param, type, [len(param)]))
shape_op = self.get_shape_proxy(target_shape)
param = self.get_proxy(ascend_op.BroadcastTo, (param, shape_op))
return param

def mul_scalar(self, x, y):
out_dtype = fx_traceback.get_current_meta()['val'].dtype
const_dtype = torch.float32 if out_dtype == torch.float16 else out_dtype
y_op = self.get_proxy(ascend_op.Const, ([y], const_dtype, [1]))
y_shape = list(x.node.meta['val'].shape)
if symint_in_shape(y_shape):
y_shape_op = self.process_dynamic_shape(y_shape)
y_op = self.get_proxy(ascend_op.BroadcastTo, (y_op, y_shape_op))
y_op = self.get_param_proxy(y, const_dtype, y_shape)
if out_dtype == torch.float16:
y_op = self.get_proxy(ascend_op.Cast, (y_op, "FLOAT16"))
return self.get_proxy(ascend_op.Mul, (x, y_op))
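
Note: the two helpers added above, get_shape_proxy and get_param_proxy, centralize the symint_in_shape / process_dynamic_shape / Const / BroadcastTo pattern that the remaining hunks delete from individual conversions. A minimal sketch of the branch order get_shape_proxy encodes, using plain-Python stand-ins (ConstOp and DynamicShapeOp are hypothetical placeholders, not dicp classes):

from dataclasses import dataclass
from typing import Any, List, Union

def symint_in_shape(shape: List[Any]) -> bool:
    # Stand-in: treat any non-int dimension as symbolic (a SymInt).
    return any(not isinstance(dim, int) for dim in shape)

@dataclass
class ConstOp:
    # Folded static shape: a 1-D int32 constant holding the dims.
    values: List[int]
    dtype: str
    dims: List[int]

@dataclass
class DynamicShapeOp:
    # Shape still containing SymInts; built op-by-op at runtime.
    shape: List[Any]

def get_shape_proxy(shape: Union[list, ConstOp, DynamicShapeOp]) -> Any:
    if isinstance(shape, (ConstOp, DynamicShapeOp)):        # already proxy-like: pass through
        return shape
    if isinstance(shape, list) and symint_in_shape(shape):  # symbolic dims present
        return DynamicShapeOp(shape)
    return ConstOp(shape, "int32", [len(shape)])            # fully static shape

print(get_shape_proxy([2, 3, 4]))     # ConstOp(values=[2, 3, 4], dtype='int32', dims=[3])
print(get_shape_proxy([2, "s0", 4]))  # DynamicShapeOp(shape=[2, 's0', 4])

get_param_proxy layers the same idea on scalars: a non-proxy value is wrapped into a Const and then broadcast to the target shape obtained from get_shape_proxy.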
@@ -140,13 +156,9 @@ def mul(self, x, y):
y_dtype = y.node.meta['val'].dtype
# handling with broadcasting cases
if np.prod(x_shape) < np.prod(y_shape):
if symint_in_shape(y_shape):
y_shape_op = self.process_dynamic_shape(y_shape)
x = self.get_proxy(ascend_op.BroadcastTo, (x, y_shape_op))
x = self.get_param_proxy(x, None, y_shape)
elif np.prod(x_shape) > np.prod(y_shape):
if symint_in_shape(x_shape):
x_shape_op = self.process_dynamic_shape(x_shape)
y = self.get_proxy(ascend_op.BroadcastTo, (y, x_shape_op))
y = self.get_param_proxy(y, None, x_shape)
if x_dtype != out_dtype:
x = self.get_proxy(
ascend_op.Cast, (x, get_ascend_dtype(out_dtype)), {})
@@ -237,11 +249,8 @@ def div(self, x, y):
assert y != 0
out_dtype = fx_traceback.get_current_meta()['val'].dtype
const_dtype = torch.float32 if out_dtype == torch.float16 else out_dtype
y_op = self.get_proxy(ascend_op.Const, ([y], const_dtype, []))
y_shape = list(x.node.meta['val'].shape)
if symint_in_shape(y_shape):
y_shape_op = self.process_dynamic_shape(y_shape)
y_op = self.get_proxy(ascend_op.BroadcastTo, (y_op, y_shape_op))
y_op = self.get_param_proxy(y, const_dtype, y_shape)
if out_dtype == torch.float16:
y_op = self.get_proxy(ascend_op.Cast, (y_op, "FLOAT16"), {})
return self.get_proxy(ascend_op.Div, (x, y_op), {})
@@ -258,17 +267,8 @@ def slice(self, x, dim=0, start=None, end=None, step=1):
assert start >= 0 and start < x_shape[dim]
offset = [0] * len(x_shape)
offset[dim] = start
if symint_in_shape(offset):
offset = self.process_dynamic_shape(offset)
else:
offset = self.get_proxy(
ascend_op.Const, (offset, torch.int32, [len(offset)]))
size = None
if symint_in_shape(y_shape):
size = self.process_dynamic_shape(y_shape)
else:
size = self.get_proxy(
ascend_op.Const, (y_shape, torch.int32, [len(y_shape)]))
offset = self.get_shape_proxy(offset)
size = self.get_shape_proxy(y_shape)
return self.get_proxy(ascend_op.Slice, (x, offset, size))

@register_conversion(aten.bernoulli.p)
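
For reference, a quick worked example of the offset and size lists that the slice conversion above now hands to get_shape_proxy (the shapes here are hypothetical, not from a real trace):

# aten.slice along dim=1 starting at 2 on a [4, 6] input, producing [4, 3].
x_shape = [4, 6]
y_shape = [4, 3]              # output shape recorded in the node's meta
dim, start = 1, 2

offset = [0] * len(x_shape)
offset[dim] = start           # -> [0, 2]
size = y_shape                # -> [4, 3]
print(offset, size)
# Both lists then go through get_shape_proxy: static here, so each folds into
# an int32 Const; a SymInt in either would route through process_dynamic_shape.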
@@ -319,23 +319,10 @@ def select(self, x, dim, index):
size.append(v - offset[i])
else:
size.append(end - offset[i])

if symint_in_shape(offset):
offset = self.process_dynamic_shape(offset)
else:
offset = self.get_proxy(
ascend_op.Const, (offset, torch.int32, [len(offset)]))
if symint_in_shape(size):
size = self.process_dynamic_shape(size)
else:
size = self.get_proxy(
ascend_op.Const, (size, torch.int32, [len(size)]))
offset = self.get_shape_proxy(offset)
size = self.get_shape_proxy(size)
slice = self.get_proxy(ascend_op.Slice, (x, offset, size))
if symint_in_shape(y_shape):
y_shape = self.process_dynamic_shape(y_shape)
else:
y_shape = self.get_proxy(
ascend_op.Const, (y_shape, torch.int32, [len(y_shape)]))
y_shape = self.get_shape_proxy(y_shape)
return self.get_proxy(ascend_op.Reshape, (slice, y_shape))

@register_conversion(_operator.add)
@@ -382,10 +369,7 @@ def view(self, x, size):
raise RuntimeError(
"cannot handle with both negative and symint!")
shape = real_shape
if symint_in_shape(shape):
shape = self.process_dynamic_shape(shape)
else:
shape = self.get_proxy(ascend_op.Const, (shape, torch.int32))
shape = self.get_shape_proxy(shape)
if x.node.meta["val"].dtype == torch.complex64:
real = self.get_proxy(ascend_op.Identity, (x, 0))
imag = self.get_proxy(ascend_op.Identity, (x, 1))
@@ -434,28 +418,16 @@ def eq(self, a, b):
if not isinstance(b, torch.fx.proxy.Proxy):
assert isinstance(b, int)
b_shape = list(a.node.meta['val'].shape)
scalar_op = self.get_proxy(ascend_op.Const, (b, torch.int64))
if symint_in_shape(b_shape):
b_shape = self.process_dynamic_shape(b_shape)
else:
b_shape = self.get_proxy(
ascend_op.Const, (b_shape, torch.int32))
b = self.get_proxy(ascend_op.BroadcastTo, (scalar_op, b_shape))
b = self.get_param_proxy(b, torch.int64, b_shape)
return self.get_proxy(ascend_op.Equal, (a, b))

@register_conversion([aten.lt.Scalar, aten.lt.Tensor])
def lt(self, x, y):
if not isinstance(y, torch.fx.proxy.Proxy):
x_dtype = x.node.meta['val'].dtype
const_dtype = torch.float32 if x_dtype == torch.float16 else x_dtype
scalar_op = self.get_proxy(ascend_op.Const, (y, const_dtype))
y_shape = list(x.node.meta['val'].shape)
if symint_in_shape(y_shape):
y_shape_op = self.process_dynamic_shape(y_shape)
else:
y_shape_op = self.get_proxy(
ascend_op.Const, (y_shape, torch.int32))
y = self.get_proxy(ascend_op.BroadcastTo, (scalar_op, y_shape_op))
y = self.get_param_proxy(y, const_dtype, y_shape)
if x_dtype == torch.float16:
y = self.get_proxy(ascend_op.Cast, (y, "FLOAT16"))
return self.get_proxy(ascend_op.Less, (x, y))
@@ -466,13 +438,8 @@ def masked_fill(self, x, mask, value):
const_dtype = torch.float32 if x_dtype == torch.float16 else x_dtype
if str(value) == "-inf":
value = -3.4028235e38
scalar_op = self.get_proxy(ascend_op.Const, (value, const_dtype))
mask_shape = list(mask.node.meta['val'].shape)
if symint_in_shape(mask_shape):
value = self.process_dynamic_shape(mask_shape)
else:
value = self.get_proxy(ascend_op.Const, (mask_shape, torch.int32))
value = self.get_proxy(ascend_op.BroadcastTo, (scalar_op, value))
value = self.get_param_proxy(value, const_dtype, mask_shape)
if x_dtype == torch.float16:
value = self.get_proxy(ascend_op.Cast, (value, "FLOAT16"))
return self.get_proxy(ascend_op.MaskedFill, (x, mask, value))
@@ -482,12 +449,7 @@ def scatter(self, var, dim, index, value):
assert isinstance(dim, int)
index_shape = list(index.node.meta['val'].shape)
if isinstance(value, torch.fx.proxy.Proxy):
preprocess = None
if symint_in_shape(index_shape):
preprocess = self.process_dynamic_shape(index_shape)
else:
preprocess = self.get_proxy(
ascend_op.Const, (index_shape, torch.int32))
preprocess = self.get_shape_proxy(index_shape)
value = self.get_proxy(ascend_op.Reshape, (value, preprocess))
else:
out_dtype = fx_traceback.get_current_meta()['val'].dtype
@@ -558,11 +520,7 @@ def full(self, dims, value, dtype=torch.float32, layout=torch.strided,
dim.node, 'meta') else dim for dim in dims]
if isinstance(value, torch.fx.proxy.Proxy) and hasattr(value.node, 'meta'):
value = value.node.meta['val']
if symint_in_shape(dims):
dims = self.process_dynamic_shape(dims)
else:
dims = self.get_proxy(
ascend_op.Const, (dims, torch.int32, [len(dims)]))
dims = self.get_shape_proxy(dims)
value = self.get_proxy(ascend_op.Const, ([value], torch_dtype, []))
return self.get_proxy(ascend_op.Fill, (dims, value))

@@ -739,11 +697,8 @@ def expand(self, x, shape):
x = self.get_proxy(ascend_op.Cast, (x, "INT32"))
shape = [dim.meta['val'] if hasattr(
dim, 'meta') else dim for dim in shape]
if isinstance(shape, list) and symint_in_shape(shape):
preprocess_shape = self.process_dynamic_shape(shape)
return self.get_proxy(ascend_op.Expand, (x, preprocess_shape))
else:
return self.get_proxy(ascend_op.ExpandD, (x, shape))
shape = self.get_shape_proxy(shape)
return self.get_proxy(ascend_op.ExpandD, (x, shape))

@register_conversion(torch.ops.aten.slice_backward.default)
def slice_backward(self, grad, input_shape, dim, start, end, step):
@@ -795,12 +750,7 @@ def maximum(self, a, b):
a_shape = list(a.node.meta['val'].shape)
b_shape = list(b.node.meta['val'].shape)
if np.prod(b_shape) < np.prod(a_shape):
if symint_in_shape(a_shape):
a_shape = self.process_dynamic_shape(a_shape)
else:
a_shape = self.get_proxy(
ascend_op.Const, (a_shape, torch.int32, [len(a_shape)]))
b = self.get_proxy(ascend_op.BroadcastTo, (b, a_shape))
b = self.get_param_proxy(b, None, a_shape)
if a.node.meta['val'].dtype == torch.float16:
b = self.get_proxy(ascend_op.Cast, (b, "FLOAT16"))
return self.get_proxy(ascend_op.Maximum, (a, b))
@@ -813,11 +763,7 @@ def common_process_scalar(self, x, y):
need_cast = True
y = self.get_proxy(ascend_op.Const, (y, x_dtype))
y_shape = list(x.node.meta['val'].shape)
if symint_in_shape(y_shape):
shape_preprocess = self.process_dynamic_shape(y_shape)
else:
shape_preprocess = self.get_proxy(
ascend_op.Const, (y_shape, torch.int32))
shape_preprocess = self.get_shape_proxy(y_shape)
y = self.get_proxy(ascend_op.BroadcastTo, (y, shape_preprocess))
if need_cast:
y = self.get_proxy(ascend_op.Cast, (y, "FLOAT16"))
@@ -844,10 +790,7 @@ def transpose(self, input, dim0, dim1):
perm = [num for num in range(rank)]
perm[dim0] = dim1
perm[dim1] = dim0
if symint_in_shape(perm):
perm = self.process_dynamic_shape(perm)
else:
perm = self.get_proxy(ascend_op.Const, (perm, torch.int32))
perm = self.get_shape_proxy(perm)
return self.get_proxy(ascend_op.Transpose, (input, perm))

@register_conversion(torch.ops.aten.convolution)
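
Finally, a small worked example of the permutation list built by the transpose conversion above before it is passed to get_shape_proxy (rank and dims here are hypothetical):

# Swap dims 1 and 2 of a rank-4 tensor.
rank, dim0, dim1 = 4, 1, 2
perm = [num for num in range(rank)]   # [0, 1, 2, 3]
perm[dim0] = dim1
perm[dim1] = dim0
print(perm)                           # [0, 2, 1, 3]
# get_shape_proxy then folds this static perm into an int32 Const for Transpose.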