From 776d9bc2fb65a7b29ac4b6a82db1a156c0bf3c33 Mon Sep 17 00:00:00 2001 From: Mirella de Medeiros Date: Tue, 4 Jun 2024 16:31:57 -0300 Subject: [PATCH] #86dtbrzdm - Fix class instantiation causing stack error --- .../compiler/codegenerator/codegenerator.py | 18 ++++++--- .../codegenerator/codegeneratorvisitor.py | 38 ++++++++++--------- .../class_test/ClassInitWithStaticVariable.py | 27 +++++++++++++ boa3_test/tests/compiler_tests/test_class.py | 16 ++++++++ 4 files changed, 75 insertions(+), 24 deletions(-) create mode 100644 boa3_test/test_sc/class_test/ClassInitWithStaticVariable.py diff --git a/boa3/internal/compiler/codegenerator/codegenerator.py b/boa3/internal/compiler/codegenerator/codegenerator.py index 5666aaff..6aeb494e 100644 --- a/boa3/internal/compiler/codegenerator/codegenerator.py +++ b/boa3/internal/compiler/codegenerator/codegenerator.py @@ -522,19 +522,25 @@ def get_symbol(self, identifier: str, scope: ISymbol | None = None, is_internal: return found_id, found_symbol return identifier, Type.none - def initialize_static_fields(self) -> bool: + def initialize_static_fields(self) -> tuple[bool, bool]: """ Converts the signature of the method - :return: whether there are static fields to be initialized + :return: whether there are static fields to be initialized and if they can be generated already """ + can_init_static_fields = False + has_static_fields = False + default_result = (has_static_fields, can_init_static_fields) + if not self.can_init_static_fields: - return False + return default_result if self.initialized_static_fields: - return False + return default_result num_static_fields = len(self._statics) - if num_static_fields > 0: + has_static_fields = num_static_fields > 0 + can_init_static_fields = True + if has_static_fields: init_data = bytearray([num_static_fields]) self.__insert1(OpcodeInfo.INITSSLOT, init_data) @@ -548,7 +554,7 @@ def initialize_static_fields(self) -> bool: init_method.init_bytecode = self.last_code self.symbol_table[constants.INITIALIZE_METHOD_ID] = init_method - return num_static_fields > 0 + return has_static_fields, can_init_static_fields def end_initialize(self): """ diff --git a/boa3/internal/compiler/codegenerator/codegeneratorvisitor.py b/boa3/internal/compiler/codegenerator/codegeneratorvisitor.py index 7146e8c3..86b47fd8 100644 --- a/boa3/internal/compiler/codegenerator/codegeneratorvisitor.py +++ b/boa3/internal/compiler/codegenerator/codegeneratorvisitor.py @@ -205,7 +205,8 @@ def visit_Module(self, module: ast.Module) -> GeneratorData: for stmt in function_stmts: self.visit(stmt) - if self.generator.initialize_static_fields(): + has_static_fields, can_initialize_static_fields = self.generator.initialize_static_fields() + if can_initialize_static_fields: last_symbols = self.symbols # save to revert in the end and not compromise consequent visits class_non_static_stmts = [] @@ -243,23 +244,24 @@ def visit_Module(self, module: ast.Module) -> GeneratorData: class_non_static_stmts.append(cls_fun) self.symbols = last_symbols # don't use inner scopes to evaluate the other globals - # to generate the 'initialize' method for Neo - self._log_info(f"Compiling '{constants.INITIALIZE_METHOD_ID}' function") - self._is_generating_initialize = True - for stmt in global_stmts: - cur_tree = self._tree - cur_filename = self.filename - if hasattr(stmt, 'origin'): - if hasattr(stmt.origin, 'filename'): - self.set_filename(stmt.origin.filename) - self._tree = stmt.origin - - self.visit(stmt) - self.filename = cur_filename - self._tree = cur_tree - - self._is_generating_initialize = False - self.generator.end_initialize() + if has_static_fields: + # to generate the 'initialize' method for Neo + self._log_info(f"Compiling '{constants.INITIALIZE_METHOD_ID}' function") + self._is_generating_initialize = True + for stmt in global_stmts: + cur_tree = self._tree + cur_filename = self.filename + if hasattr(stmt, 'origin'): + if hasattr(stmt.origin, 'filename'): + self.set_filename(stmt.origin.filename) + self._tree = stmt.origin + + self.visit(stmt) + self.filename = cur_filename + self._tree = cur_tree + + self._is_generating_initialize = False + self.generator.end_initialize() # generate any symbol inside classes that's not variables AFTER generating 'initialize' method for stmt in class_non_static_stmts: diff --git a/boa3_test/test_sc/class_test/ClassInitWithStaticVariable.py b/boa3_test/test_sc/class_test/ClassInitWithStaticVariable.py new file mode 100644 index 00000000..664dfcd6 --- /dev/null +++ b/boa3_test/test_sc/class_test/ClassInitWithStaticVariable.py @@ -0,0 +1,27 @@ +from typing import Any + +from boa3.sc.compiletime import public + +FOO = "bar" + + +class MyNFT: + def __init__(self, shape: str, color: str, background: str, size: str) -> None: + self.shape = shape + self.color = color + self.background = background + self.size = size + + def export(self) -> dict: + return { + 'shape': self.shape, + 'color': self.color, + 'background': self.background, + 'size': self.size + } + + +@public +def test() -> Any: + nft = MyNFT('Rectangle', 'Blue', 'Black', 'Small') + return nft diff --git a/boa3_test/tests/compiler_tests/test_class.py b/boa3_test/tests/compiler_tests/test_class.py index 34a52047..b1ab848e 100644 --- a/boa3_test/tests/compiler_tests/test_class.py +++ b/boa3_test/tests/compiler_tests/test_class.py @@ -489,3 +489,19 @@ async def test_return_dict_with_class_attributes(self): } result, _ = await self.call('test_pair', [], return_type=dict[str,str]) self.assertEqual(expected_result, result) + + async def test_class_init_with_static_variable_no_optimization(self): + await self.set_up_contract('ClassInitWithStaticVariable.py', optimize=False) + from boa3_test.test_sc.class_test.ClassInitWithStaticVariable import MyNFT + + expected_result = MyNFT('Rectangle', 'Blue', 'Black', 'Small') + result, _ = await self.call('test', [], return_type=list) + self.assertObjectEqual(expected_result, result) + + async def test_class_init_with_static_variable_optimized(self): + await self.set_up_contract('ClassInitWithStaticVariable.py', optimize=True) + from boa3_test.test_sc.class_test.ClassInitWithStaticVariable import MyNFT + + expected_result = MyNFT('Rectangle', 'Blue', 'Black', 'Small') + result, _ = await self.call('test', [], return_type=list) + self.assertObjectEqual(expected_result, result)