diff --git a/helper/common/common.go b/helper/common/common.go index 3ba8ffe337..57ec57fd7d 100644 --- a/helper/common/common.go +++ b/helper/common/common.go @@ -375,3 +375,48 @@ func EncodeUint64ToBytes(value uint64) []byte { func EncodeBytesToUint64(b []byte) uint64 { return binary.BigEndian.Uint64(b) } + +// Generic object pool implementation intended to be used in single-threaded +// manner and avoid synchronization overhead. It could be probably additionally +// be improved by using circular buffer as oposed to stack. +type UnsafePool[T any] struct { + stack []T +} + +// Creates new instance of UnsafePool. Depending on observed usage, pool size +// should be set on creation to avoid pool resizing +func NewUnsafePool[T any]() *UnsafePool[T] { + return &UnsafePool[T]{} +} + +// Get retrieves an object from the unsafepool, or allocates a new one if the pool +// is empty. The allocation logic (i.e., creating a new object of type T) needs to +// be provided externally, as Go's type system does not allow calling constructors +// or functions specific to T without an interface. +func (f *UnsafePool[T]) Get(newFunc func() T) T { + n := len(f.stack) + if n == 0 { + // Allocate a new T instance using the provided newFunc if the stack is empty. + return newFunc() + } + + obj := f.stack[n-1] + f.stack = f.stack[:n-1] + + return obj +} + +// Put returns an object to the pool and executes reset function if provided. Reset +// function is used to return the T instance to initial state. +func (f *UnsafePool[T]) Put(resetFunc func(T) T, obj T) { + if resetFunc != nil { + obj = resetFunc(obj) + } + + f.stack = append(f.stack, obj) +} + +// Clear the content of the pool +func (f *UnsafePool[T]) Clear() { + f.stack = f.stack[:0] +} diff --git a/helper/common/common_test.go b/helper/common/common_test.go index f820bcc7aa..e7168c0662 100644 --- a/helper/common/common_test.go +++ b/helper/common/common_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -168,3 +169,80 @@ func Test_SafeAddUint64(t *testing.T) { }) } } + +func TestNewUnsafePool(t *testing.T) { + pool := NewUnsafePool[int]() + + if pool == nil { + t.Errorf("NewUnsafePool returned nil") + } + + if len(pool.stack) != 0 { + t.Errorf("Expected empty pool, got %v", pool.stack) + } +} + +func TestUnsafePoolGetWhenEmpty(t *testing.T) { + pool := NewUnsafePool[int]() + newInt := func() int { + return 1 + } + + obj := pool.Get(newInt) + if obj != 1 { + t.Errorf("Expected 1 from newFunc, got %v", obj) + } +} + +func TestUnsafePoolGetPut(t *testing.T) { + pool := NewUnsafePool[int]() + resetInt := func(i int) int { + return 0 + } + + // Initially put an object into the pool. + pool.Put(resetInt, 2) + + // Retrieve the object, which should now be the reset value. + obj := pool.Get(func() int { return 3 }) + if obj != 0 { // Expecting the original object, not the one from newFunc + t.Errorf("Expected 2 from the pool, got %v", obj) + } + + // Test if Get correctly uses newFunc when pool is empty again. + obj = pool.Get(func() int { return 3 }) + if obj != 3 { + t.Errorf("Expected 3 from newFunc, got %v", obj) + } +} + +func TestUnsafePoolPutWithReset(t *testing.T) { + pool := NewUnsafePool[int]() + resetInt := func(i int) int { + return 0 + } + + // Put an object into the pool with a reset function. + pool.Put(resetInt, 5) + + // Directly check if the object was reset. + if pool.stack[0] != 0 { + t.Errorf("Expected object to be reset to 0, got %v", pool.stack[0]) + } +} + +func TestUnsafePoolClear(t *testing.T) { + pool := NewUnsafePool[int]() + resetInt := func(i int) int { + return 0 + } + + // Put an object into the pool with a reset function. + pool.Put(resetInt, 1) + + assert.Len(t, pool.stack, 1, "Expected pool stack 1") + + pool.Clear() + + assert.Len(t, pool.stack, 0, "Expected pool stack 0") +} diff --git a/state/runtime/evm/instructions_test.go b/state/runtime/evm/instructions_test.go index 82592593d7..0830da9d64 100644 --- a/state/runtime/evm/instructions_test.go +++ b/state/runtime/evm/instructions_test.go @@ -1,7 +1,9 @@ package evm import ( + "errors" "math/big" + "reflect" "testing" "github.com/0xPolygon/polygon-edge/chain" @@ -37,9 +39,9 @@ func testLogicalOperation(t *testing.T, f instruction, test OperandsLogical, s * f(s) if test.expectedResult { - assert.Equal(t, one, s.pop()) + assert.Equal(t, one.Uint64(), s.pop().Uint64()) } else { - assert.Equal(t, zero, s.pop()) + assert.Equal(t, zero.Uint64(), s.pop().Uint64()) } } @@ -57,7 +59,7 @@ func testArithmeticOperation(t *testing.T, f instruction, test OperandsArithmeti f(s) - assert.Equal(t, test.expectedResult, s.pop()) + assert.Equal(t, test.expectedResult.Uint64(), s.pop().Uint64()) } func TestAdd(t *testing.T) { @@ -355,7 +357,7 @@ func TestPush0(t *testing.T) { defer closeFn() opPush0(s) - require.Equal(t, zero, s.pop()) + require.Equal(t, zero.Uint64(), s.pop().Uint64()) }) t.Run("single push0 (EIP-3855 disabled)", func(t *testing.T) { @@ -857,7 +859,7 @@ func TestCallValue(t *testing.T) { s.msg.Value = value opCallValue(s) - assert.Equal(t, value, s.pop()) + assert.Equal(t, value.Uint64(), s.pop().Uint64()) }) t.Run("Msg Value nil", func(t *testing.T) { @@ -865,7 +867,7 @@ func TestCallValue(t *testing.T) { defer cancelFn() opCallValue(s) - assert.Equal(t, zero, s.pop()) + assert.Equal(t, zero.Uint64(), s.pop().Uint64()) }) } @@ -879,7 +881,7 @@ func TestCallDataLoad(t *testing.T) { s.msg = &runtime.Contract{Input: big.NewInt(7).Bytes()} opCallDataLoad(s) - assert.Equal(t, zero, s.pop()) + assert.Equal(t, zero.Uint64(), s.pop().Uint64()) }) t.Run("ZeroOffset", func(t *testing.T) { s, cancelFn := getState(&chain.ForksInTime{}) @@ -890,7 +892,7 @@ func TestCallDataLoad(t *testing.T) { s.msg = &runtime.Contract{Input: big.NewInt(7).Bytes()} opCallDataLoad(s) - assert.NotEqual(t, zero, s.pop()) + assert.NotEqual(t, zero.Uint64(), s.pop().Uint64()) }) } @@ -1013,7 +1015,7 @@ func TestExtCodeHash(t *testing.T) { opExtCodeHash(s) assert.Equal(t, s.gas, gasLeft) - assert.Equal(t, one, s.pop()) + assert.Equal(t, one.Uint64(), s.pop().Uint64()) }) t.Run("NonIstanbul", func(t *testing.T) { @@ -1032,7 +1034,7 @@ func TestExtCodeHash(t *testing.T) { opExtCodeHash(s) assert.Equal(t, gasLeft, s.gas) - assert.Equal(t, zero, s.pop()) + assert.Equal(t, zero.Uint64(), s.pop().Uint64()) }) t.Run("NoForks", func(t *testing.T) { @@ -2288,11 +2290,39 @@ func Test_opReturnDataCopy(t *testing.T) { opReturnDataCopy(state) - assert.Equal(t, test.resultState, state) + assert.True(t, CompareStates(test.resultState, state)) }) } } +// Since the state is complex structure, here is the specialized comparison +// function that checks significant fields. This function should be updated +// to suite future needs. +func CompareStates(a *state, b *state) bool { + // Compare simple fields + if a.ip != b.ip || a.lastGasCost != b.lastGasCost || a.sp != b.sp || !errors.Is(a.err, b.err) || a.stop != b.stop || a.gas != b.gas { + return false + } + + // Deep compare slices + if !reflect.DeepEqual(a.code, b.code) || !reflect.DeepEqual(a.tmp, b.tmp) || !reflect.DeepEqual(a.returnData, b.returnData) || !reflect.DeepEqual(a.memory, b.memory) { + return false + } + + // Deep comparison of stacks + if len(a.stack) != len(b.stack) { + return false + } + + for i := range a.stack { + if a.stack[i].Cmp(b.stack[i]) != 0 { + return false + } + } + + return true +} + func Test_opCall(t *testing.T) { t.Parallel() diff --git a/state/runtime/evm/state.go b/state/runtime/evm/state.go index 72c1b8d70e..8f6fa2ebd0 100644 --- a/state/runtime/evm/state.go +++ b/state/runtime/evm/state.go @@ -76,6 +76,8 @@ type state struct { returnData []byte ret []byte + + unsafepool common.UnsafePool[*big.Int] } func (c *state) reset() { @@ -95,6 +97,13 @@ func (c *state) reset() { c.memory[i] = 0 } + // Before stack cleanup, return instances of big.Int to the pool + // for the future usage + for i := range c.stack { + c.unsafepool.Put(func(x *big.Int) *big.Int { + return x.SetInt64(0) + }, c.stack[i]) + } c.stack = c.stack[:0] c.tmp = c.tmp[:0] c.ret = c.ret[:0] @@ -136,7 +145,10 @@ func (c *state) push1() *big.Int { return c.stack[c.sp-1] } - v := big.NewInt(0) + v := c.unsafepool.Get(func() *big.Int { + return big.NewInt(0) + }) + c.stack = append(c.stack, v) c.sp++ @@ -180,10 +192,6 @@ func (c *state) pop() *big.Int { o := c.stack[c.sp-1] c.sp-- - if o.Cmp(zero) == 0 { - return big.NewInt(0) - } - return o } @@ -261,8 +269,9 @@ func (c *state) Run() ([]byte, error) { // execute the instruction inst.inst(c) - c.captureExecution(op.String(), ipCopy, gasCopy, gasCopy-c.gas) - + if c.host.GetTracer() != nil { + c.captureExecution(op.String(), ipCopy, gasCopy, gasCopy-c.gas) + } // check if stack size exceeds the max size if c.sp > stackSize { c.exit(&runtime.StackOverflowError{StackLen: c.sp, Limit: stackSize})