From 1c0ae91c374b751f6d9bed55f872084d4bd6a184 Mon Sep 17 00:00:00 2001
From: Pan Daoxin
Date: Mon, 13 Nov 2023 19:46:55 -0500
Subject: [PATCH] For llama fixed shape.

---
 dicp/dicp/vendor/AscendGraph/codegen/ascend.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
index d0c08c026..9ce749303 100644
--- a/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
+++ b/dicp/dicp/vendor/AscendGraph/codegen/ascend.py
@@ -294,6 +294,8 @@ def gen_call_func(self):
             call_body.writeline(f'''dims = None''')
 
         # generate output shapes
+        extra_stride_str = ''
+        extra_storage_offset_str = ''
         if len(self.sym_in_args) > 0 or len(self.sym_to_inputs) > 0:
             shape_str = f'''output_shape = ['''
             for elem in self.output_args:
@@ -309,8 +311,6 @@ def gen_call_func(self):
                 shape_str += "[" + ','.join(map(str, shape)) + "],"
 
             # process output_shape with modified args
-            extra_stride_str = ''
-            extra_storage_offset_str = ''
             for elem in self.assign_args:
                 shape = list(self.input_args[elem[1]].meta['val'].shape)
                 if len(shape) == 0:
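
Note on the change: the two hunks hoist the empty-string initializations of
extra_stride_str and extra_storage_offset_str out of the symbolic-shape branch,
presumably so both names are always bound before later use; on a fixed-shape
llama graph sym_in_args and sym_to_inputs are empty, the branch is skipped, and
the old placement would leave the locals unbound. Below is a minimal sketch of
that failure mode under this assumption; the function names and sym_inputs
parameter are hypothetical stand-ins, not part of the dicp codebase.

    # A minimal sketch (hypothetical names, not dicp code) of the bug class the
    # patch addresses: locals assigned only inside a dynamic-shape branch are
    # unbound on the fixed-shape path and fail when referenced later.

    def build_call_body_buggy(sym_inputs):
        if sym_inputs:                       # dynamic-shape branch only
            extra_stride_str = ''
            extra_storage_offset_str = ''
            for name in sym_inputs:
                extra_stride_str += f"{name}_stride,"
        # Unconditional use: fixed shape (empty sym_inputs) -> UnboundLocalError.
        return extra_stride_str + extra_storage_offset_str


    def build_call_body_fixed(sym_inputs):
        # Patched layout: initialize before the branch so both names always exist.
        extra_stride_str = ''
        extra_storage_offset_str = ''
        if sym_inputs:
            for name in sym_inputs:
                extra_stride_str += f"{name}_stride,"
        return extra_stride_str + extra_storage_offset_str


    if __name__ == "__main__":
        print(repr(build_call_body_fixed([])))   # '' on the fixed-shape path
        try:
            build_call_body_buggy([])
        except UnboundLocalError as exc:         # subclass of NameError
            print("buggy version:", exc)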