Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#86dtbrzdm - Fix class instantiation causing stack error #1264

Merged
merged 1 commit into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions boa3/internal/compiler/codegenerator/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,19 +522,25 @@ def get_symbol(self, identifier: str, scope: ISymbol | None = None, is_internal:
return found_id, found_symbol
return identifier, Type.none

def initialize_static_fields(self) -> bool:
def initialize_static_fields(self) -> tuple[bool, bool]:
"""
Converts the signature of the method
:return: whether there are static fields to be initialized
:return: whether there are static fields to be initialized and if they can be generated already
"""
can_init_static_fields = False
has_static_fields = False
default_result = (has_static_fields, can_init_static_fields)

if not self.can_init_static_fields:
return False
return default_result
if self.initialized_static_fields:
return False
return default_result

num_static_fields = len(self._statics)
if num_static_fields > 0:
has_static_fields = num_static_fields > 0
can_init_static_fields = True
if has_static_fields:
init_data = bytearray([num_static_fields])
self.__insert1(OpcodeInfo.INITSSLOT, init_data)

Expand All @@ -548,7 +554,7 @@ def initialize_static_fields(self) -> bool:
init_method.init_bytecode = self.last_code
self.symbol_table[constants.INITIALIZE_METHOD_ID] = init_method

return num_static_fields > 0
return has_static_fields, can_init_static_fields

def end_initialize(self):
"""
Expand Down
38 changes: 20 additions & 18 deletions boa3/internal/compiler/codegenerator/codegeneratorvisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ def visit_Module(self, module: ast.Module) -> GeneratorData:
for stmt in function_stmts:
self.visit(stmt)

if self.generator.initialize_static_fields():
has_static_fields, can_initialize_static_fields = self.generator.initialize_static_fields()
if can_initialize_static_fields:
last_symbols = self.symbols # save to revert in the end and not compromise consequent visits
class_non_static_stmts = []

Expand Down Expand Up @@ -243,23 +244,24 @@ def visit_Module(self, module: ast.Module) -> GeneratorData:
class_non_static_stmts.append(cls_fun)
self.symbols = last_symbols # don't use inner scopes to evaluate the other globals

# to generate the 'initialize' method for Neo
self._log_info(f"Compiling '{constants.INITIALIZE_METHOD_ID}' function")
self._is_generating_initialize = True
for stmt in global_stmts:
cur_tree = self._tree
cur_filename = self.filename
if hasattr(stmt, 'origin'):
if hasattr(stmt.origin, 'filename'):
self.set_filename(stmt.origin.filename)
self._tree = stmt.origin

self.visit(stmt)
self.filename = cur_filename
self._tree = cur_tree

self._is_generating_initialize = False
self.generator.end_initialize()
if has_static_fields:
# to generate the 'initialize' method for Neo
self._log_info(f"Compiling '{constants.INITIALIZE_METHOD_ID}' function")
self._is_generating_initialize = True
for stmt in global_stmts:
cur_tree = self._tree
cur_filename = self.filename
if hasattr(stmt, 'origin'):
if hasattr(stmt.origin, 'filename'):
self.set_filename(stmt.origin.filename)
self._tree = stmt.origin

self.visit(stmt)
self.filename = cur_filename
self._tree = cur_tree

self._is_generating_initialize = False
self.generator.end_initialize()

# generate any symbol inside classes that's not variables AFTER generating 'initialize' method
for stmt in class_non_static_stmts:
Expand Down
27 changes: 27 additions & 0 deletions boa3_test/test_sc/class_test/ClassInitWithStaticVariable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Any

from boa3.sc.compiletime import public

FOO = "bar"


class MyNFT:
def __init__(self, shape: str, color: str, background: str, size: str) -> None:
self.shape = shape
self.color = color
self.background = background
self.size = size

def export(self) -> dict:
return {
'shape': self.shape,
'color': self.color,
'background': self.background,
'size': self.size
}


@public
def test() -> Any:
nft = MyNFT('Rectangle', 'Blue', 'Black', 'Small')
return nft
16 changes: 16 additions & 0 deletions boa3_test/tests/compiler_tests/test_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,3 +489,19 @@ async def test_return_dict_with_class_attributes(self):
}
result, _ = await self.call('test_pair', [], return_type=dict[str,str])
self.assertEqual(expected_result, result)

async def test_class_init_with_static_variable_no_optimization(self):
await self.set_up_contract('ClassInitWithStaticVariable.py', optimize=False)
from boa3_test.test_sc.class_test.ClassInitWithStaticVariable import MyNFT

expected_result = MyNFT('Rectangle', 'Blue', 'Black', 'Small')
result, _ = await self.call('test', [], return_type=list)
self.assertObjectEqual(expected_result, result)

async def test_class_init_with_static_variable_optimized(self):
await self.set_up_contract('ClassInitWithStaticVariable.py', optimize=True)
from boa3_test.test_sc.class_test.ClassInitWithStaticVariable import MyNFT

expected_result = MyNFT('Rectangle', 'Blue', 'Black', 'Small')
result, _ = await self.call('test', [], return_type=list)
self.assertObjectEqual(expected_result, result)
Loading