
Commit

Merge branch 'main' into daoxin/fix_hf_transformer_precision
Conflicts:
	dicp/dicp/vendor/AscendGraph/codegen/ascend.py
	dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py
pdx1989 committed Nov 13, 2023
2 parents 2756db4 + 6d835b3 commit 1faad24
Showing 6 changed files with 43 additions and 101 deletions.
16 changes: 10 additions & 6 deletions dicp/dicp/dynamo_bridge/operator.py
@@ -49,18 +49,22 @@ def name(self):
# def infer_result(self, *args, **kwargs):
# pass

def __call__(self, *args, **kwargs):
def get_meta(x):
return x if not hasattr(x, 'meta') else x.meta['val']
new_args = tree_map(get_meta, args)

def get_fake_mode_from_args(self, args):
fake_mode = None
tmp_args, _ = tree_flatten(new_args)
tmp_args, _ = tree_flatten(args)
for arg in tmp_args:
if isinstance(arg, FakeTensor):
fake_mode = arg.fake_mode
break
fake_mode = self.fake_mode if fake_mode is None else fake_mode
return fake_mode

def __call__(self, *args, **kwargs):
def get_meta(x):
return x if not hasattr(x, 'meta') else x.meta['val']
new_args = tree_map(get_meta, args)

fake_mode = self.get_fake_mode_from_args(new_args)

def make_faketensor(x):
if not isinstance(x, torch.Tensor) or (isinstance(x, FakeTensor) and x.fake_mode == fake_mode):
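
A minimal, self-contained sketch of the refactored pattern above (the class name and the default FakeTensorMode fallback are illustrative, not the actual dicp.dynamo_bridge.operator.Operator): get_fake_mode_from_args scans the flattened arguments for the first FakeTensor and falls back to the operator's own mode, while __call__ only unwraps fx meta['val'] entries before delegating to it.

import torch
from torch.utils._pytree import tree_flatten, tree_map
from torch._subclasses import FakeTensor, FakeTensorMode


class OperatorSketch:
    def __init__(self):
        # Fallback mode used when no argument carries a FakeTensor.
        self.fake_mode = FakeTensorMode()

    def get_fake_mode_from_args(self, args):
        # Take the fake_mode of the first FakeTensor found in the flattened args.
        fake_mode = None
        flat_args, _ = tree_flatten(args)
        for arg in flat_args:
            if isinstance(arg, FakeTensor):
                fake_mode = arg.fake_mode
                break
        return self.fake_mode if fake_mode is None else fake_mode

    def __call__(self, *args, **kwargs):
        # fx nodes carry their example value in node.meta['val']; unwrap before inspecting.
        def get_meta(x):
            return x if not hasattr(x, 'meta') else x.meta['val']
        new_args = tree_map(get_meta, args)
        return self.get_fake_mode_from_args(new_args)

The TopsGraph change further down removes its local copy of Operator and imports this shared implementation instead.
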
19 changes: 17 additions & 2 deletions dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -291,7 +291,22 @@ def gen_call_func(self):
shape_str = shape_str[:-1] + f''']'''
call_body.writeline(shape_str)
else:
call_body.writeline(f'''output_shape = None''')
call_body.writeline('''output_shape = None''')

out_stride_str = '''out_stride = ['''
out_storage_offset_str = '''out_storage_offset = ['''
for elem in self.output_args:
if hasattr(elem, 'meta'):
elem = elem.meta['val']
stride = list(elem.stride())
if len(stride) == 0:
raise RuntimeError("Error handling empty output_stride")
out_stride_str += '[' + ','.join(map(str, stride)) + '],'
out_storage_offset_str += str(elem.storage_offset()) + ','
out_stride_str = out_stride_str[:-1] + ']'
out_storage_offset_str = out_storage_offset_str[:-1] + ']'
call_body.writeline(out_stride_str)
call_body.writeline(out_storage_offset_str)

call_body.splice("""
import torch_dipu
@@ -306,7 +321,7 @@ def gen_call_func(self):
del tmp_arg
""", strip=True)
call_body.writeline(f"({','.join(self.args)}) = args")
call_str = ['output_tensor = kernel_cpp_0(args, dims, output_shape)']
call_str = ['output_tensor = kernel_cpp_0(args, dims, output_shape, out_stride, out_storage_offset)']

if precision_check and self.aten_graph is not None:
# import aten graph
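
To show what the new codegen branch produces, here is a stand-alone reproduction of the string building (the example output tensors are invented; the real code reads them from the self.output_args metas). It records each output's stride and storage offset so the generated call can hand them to the runtime:

import torch

outputs = [torch.empty(2, 3), torch.empty(4, 4).t()]  # hypothetical output metas

out_stride_str = 'out_stride = ['
out_storage_offset_str = 'out_storage_offset = ['
for elem in outputs:
    stride = list(elem.stride())
    if len(stride) == 0:
        raise RuntimeError("Error handling empty output_stride")
    out_stride_str += '[' + ','.join(map(str, stride)) + '],'
    out_storage_offset_str += str(elem.storage_offset()) + ','
out_stride_str = out_stride_str[:-1] + ']'
out_storage_offset_str = out_storage_offset_str[:-1] + ']'

print(out_stride_str)            # out_stride = [[3,1],[1,4]]
print(out_storage_offset_str)    # out_storage_offset = [0,0]

Passing these two lists to kernel_cpp_0 alongside output_shape is what lets load_and_run.py (next file) rebuild non-contiguous outputs instead of assuming contiguous layouts.
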
18 changes: 12 additions & 6 deletions dicp/dicp/vendor/AscendGraph/codegen/load_and_run.py
@@ -283,16 +283,18 @@ def _prepare_input(self, images, dims):
check_ret("acl.mdl.set_dataset_tensor_desc", ret)
assert (dataset == self.input_dataset)

def _prepare_output(self, output_tensor):
def _prepare_output(self, output_tensor, out_stride, out_storage_offset):
for i in range(self.num_outputs):
item = torch.empty(
self.output_dims[i], dtype=self.output_dtypes[i], device=dipu_device_str)
item = item.as_strided(
self.output_dims[i], out_stride[i], out_storage_offset[i])
output_tensor.append(item)
ret = acl.update_data_buffer(
self.output_data_buffers[i], item.data_ptr(), self.output_size[i])
check_ret("acl.update_data_buffer", ret)

def _prepare_dynamic_output(self, output_tensor):
def _prepare_dynamic_output(self, output_tensor, out_stride, out_storage_offset):
for i in range(self.num_outputs):
tot_size = 1
for elem in self.output_shape[i]:
@@ -303,12 +305,15 @@ def _prepare_dynamic_output(self, output_tensor):
self.output_size[i] = tot_size
item = torch.empty(
self.output_dims[i], dtype=self.output_dtypes[i], device=dipu_device_str)
item = item.as_strided(
self.output_dims[i], out_stride[i], out_storage_offset[i])

output_tensor.append(item)
ret = acl.update_data_buffer(
self.output_data_buffers[i], item.data_ptr(), self.output_size[i])
check_ret("acl.update_data_buffer", ret)

def run(self, images, dims=None, output_shape=None):
def run(self, images, dims=None, output_shape=None, out_stride=None, out_storage_offset=None):
self.output_shape = output_shape
assert len(images) > 0
for img in images:
@@ -319,7 +324,7 @@ def run(self, images, dims=None, output_shape=None):
if self.output_shape:
self._prepare_dynamic_output(output)
else:
self._prepare_output(output)
self._prepare_output(output, out_stride, out_storage_offset)
self.forward()
self._destroy_databuffer()
return output
@@ -342,8 +347,9 @@ def __init__(self, device_id, model_path) -> None:
atexit.register(self.cleanup)
self.exe = AscendExecutor(device_id, model_path)

def run(self, images, dims=None, output_shape=None):
return self.exe.run(images, dims, output_shape)
def run(self, images, dims=None, output_shape=None,
out_stride=None, out_storage_offset=None):
return self.exe.run(images, dims, output_shape, out_stride, out_storage_offset)

def cleanup(self):
if hasattr(self, 'exe'):
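
The as_strided calls added to _prepare_output and _prepare_dynamic_output re-view the freshly allocated buffer with the layout recorded at codegen time. A small CPU-only illustration (dims, stride, and offset are made up; the real runner allocates on the dipu device and then registers the buffer with ACL):

import torch

dims = [4, 4]
out_stride = [1, 4]       # e.g. a transposed, column-major result
out_storage_offset = 0

item = torch.empty(dims, dtype=torch.float32)
item = item.as_strided(dims, out_stride, out_storage_offset)

print(item.stride())         # (1, 4)
print(item.is_contiguous())  # False

Without the re-striding, every output would come back contiguous regardless of the layout the graph expects.
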
1 change: 0 additions & 1 deletion dicp/dicp/vendor/AscendGraph/conversion.py
@@ -861,7 +861,6 @@ def transpose(self, input, dim0, dim1):
perm = [num for num in range(rank)]
perm[dim0] = dim1
perm[dim1] = dim0
ops = []
if symint_in_shape(perm):
perm = self.process_dynamic_shape(perm)
else:
86 changes: 2 additions & 84 deletions dicp/dicp/vendor/TopsGraph/tops_op.py
@@ -3,10 +3,8 @@
from typing import Tuple
import operator

from contextlib import nullcontext
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._subclasses import FakeTensor, FakeTensorMode
from torch._functorch import config
from dicp.dynamo_bridge.operator import Operator
from torch._subclasses import FakeTensorMode
aten = torch.ops.aten


@@ -22,86 +20,6 @@ def binary_device_check(name, lhs_t, rhs_t):
return lhs_t.device


class Operator():
__name__: str
_singleton = None

def __init__(self, name_):
super().__init__()
self.__name__ = name_
if torch.__version__.startswith("2.0"):
self.shape_env = ShapeEnv() if config.use_dynamic_shapes else None
self.fake_mode = (
FakeTensorMode(shape_env=self.shape_env)
if config.use_fake_tensor
else nullcontext()
)
elif torch.__version__.startswith("2.1"):
self.shape_env = ShapeEnv() if torch._dynamo.config.dynamic_shapes else None
self.fake_mode = (
FakeTensorMode(shape_env=self.shape_env)
if config.fake_tensor_allow_meta
else nullcontext()
)
else:
raise ValueError(
f"unsupported dicp torch version: {torch.__version__}")

@classmethod
def get_singleton(cls):
args = [None] * (cls.__init__.__code__.co_argcount - 1)
if cls._singleton is None:
cls._singleton = cls(*args)
return cls._singleton

def name(self):
return self.__name__

def get_fake_mode_from_args(self, args):
fake_mode = None
for arg in args:
if isinstance(arg, FakeTensor):
fake_mode = arg.fake_mode
break
elif isinstance(arg, list):
for x in arg:
if isinstance(x, FakeTensor):
fake_mode = x.fake_mode
break
if fake_mode is not None:
break
if fake_mode is None:
fake_mode = self.fake_mode
return fake_mode

def __call__(self, *args, **kwargs):
new_args = []
for arg in args:
if isinstance(arg, list):
new_args.append([x if not hasattr(x, 'meta')
else x.meta['val'] for x in arg])
else:
new_args.append(arg if not hasattr(
arg, 'meta') else arg.meta['val'])
new_args = tuple(new_args)

fake_mode = self.get_fake_mode_from_args(new_args)

tmp_args = []
for arg in new_args:
if not isinstance(arg, torch.Tensor) or isinstance(arg, FakeTensor):
tmp_args.append(arg)
else:
tmp_args.append(FakeTensor.from_tensor(arg, fake_mode))
new_args = tuple(tmp_args)

try:
ret = self.torch_op(*new_args, **kwargs)
except Exception:
ret = None
return ret


class Add(Operator):
def __init__(self, a, b, **kwargs):
super().__init__("Add")
4 changes: 2 additions & 2 deletions dipu/scripts/autogen_diopi_wrapper/diopi_functions.yaml
@@ -1310,7 +1310,7 @@
interface: diopiMm(ctx, out, self, mat2)

- schema: "matmul(Tensor self, Tensor other) -> Tensor"
register_op: False
device: [droplet]
custom_code_at_the_beginning: |
const auto shapeA = self.sizes();
const auto shapeB = other.sizes();
@@ -1349,7 +1349,7 @@
interface: diopiMatmul(ctx, out, self, other)

- schema: "matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)"
register_op: False
device: [droplet]
interface: diopiMatmul(ctx, out, self, other)

- schema: "cumsum.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)"
