From f2b212e25c01c0e620c95e667fbcd164daed34dd Mon Sep 17 00:00:00 2001
From: cokicm
Date: Wed, 21 Feb 2024 10:44:43 +0100
Subject: [PATCH] Pooling of big.Int instances in the EVM (#124)

* Multiple performance improvements:
- Pooling of big.Int instances on the stack to avoid frequent
allocations/cleanup
- Removing unnecessary big.Int allocation in pop() method and fixing tests
- Avoiding unnecessary op.String() calls when tracer is not initialized.

* - Fixing failed TestNewUnsafePool test.

* Remove unnecessary case

* TestPush0 fix

---------

Co-authored-by: Goran Rojovic
---
 helper/common/common.go                | 40 ++++++++++++++
 helper/common/common_test.go           | 53 ++++++++++++++++++
 state/runtime/evm/instructions_test.go | 75 ++++++++++++++++----------
 state/runtime/evm/state.go             | 23 +++++---
 4 files changed, 155 insertions(+), 36 deletions(-)

diff --git a/helper/common/common.go b/helper/common/common.go
index 3ba8ffe337..bd58bdaced 100644
--- a/helper/common/common.go
+++ b/helper/common/common.go
@@ -375,3 +375,43 @@ func EncodeUint64ToBytes(value uint64) []byte {
 func EncodeBytesToUint64(b []byte) uint64 {
 	return binary.BigEndian.Uint64(b)
 }
+
+// UnsafePool is a generic object pool intended for single-threaded use,
+// avoiding synchronization overhead. It could probably be further improved
+// by using a circular buffer as opposed to a stack.
+type UnsafePool[T any] struct {
+	stack []T
+}
+
+// NewUnsafePool creates a new instance of UnsafePool. Depending on observed
+// usage, pool size should be set on creation to avoid pool resizing.
+func NewUnsafePool[T any]() *UnsafePool[T] {
+	return &UnsafePool[T]{}
+}
+
+// Get retrieves an object from the pool, or allocates a new one if the pool
+// is empty. The allocation logic (i.e., creating a new object of type T) needs to
+// be provided externally, as Go's type system does not allow calling constructors
+// or functions specific to T without an interface.
+func (f *UnsafePool[T]) Get(newFunc func() T) T {
+	n := len(f.stack)
+	if n == 0 {
+		// Allocate a new T instance using the provided newFunc if the stack is empty.
+		return newFunc()
+	}
+
+	obj := f.stack[n-1]
+	f.stack = f.stack[:n-1]
+
+	return obj
+}
+
+// Put returns an object to the pool and executes the reset function if provided.
+// The reset function is used to return the T instance to its initial state.
+func (f *UnsafePool[T]) Put(resetFunc func(T) T, obj T) {
+	if resetFunc != nil {
+		obj = resetFunc(obj)
+	}
+
+	f.stack = append(f.stack, obj)
+}
diff --git a/helper/common/common_test.go b/helper/common/common_test.go
index f820bcc7aa..0adf7c7261 100644
--- a/helper/common/common_test.go
+++ b/helper/common/common_test.go
@@ -168,3 +168,56 @@ func Test_SafeAddUint64(t *testing.T) {
 		})
 	}
 }
+
+func TestNewUnsafePool(t *testing.T) {
+	pool := NewUnsafePool[int]()
+
+	require.NotNilf(t, pool, "NewUnsafePool returned nil")
+
+	require.Empty(t, pool.stack, "Expected empty pool.")
+}
+
+func TestUnsafePoolGetWhenEmpty(t *testing.T) {
+	pool := NewUnsafePool[int]()
+	newInt := func() int {
+		return 1
+	}
+
+	obj := pool.Get(newInt)
+
+	require.Equal(t, 1, obj, "Expected 1 from newFunc, got %v", obj)
+}
+
+func TestUnsafePoolGetPut(t *testing.T) {
+	pool := NewUnsafePool[int]()
+	resetInt := func(i int) int {
+		return 0
+	}
+
+	// Initially put an object into the pool.
+	pool.Put(resetInt, 2)
+
+	// Retrieve the object, which should now be the reset value.
+	obj := pool.Get(func() int { return 3 })
+
+	// Expecting the original object, not the one from newFunc
+	require.Equal(t, 0, obj, "Expected 0 from the pool, got %v", obj)
+
+	// Test if Get correctly uses newFunc when pool is empty again.
+	obj = pool.Get(func() int { return 3 })
+
+	require.Equal(t, 3, obj, "Expected 3 from newFunc, got %v", obj)
+}
+
+func TestUnsafePoolPutWithReset(t *testing.T) {
+	pool := NewUnsafePool[int]()
+	resetInt := func(i int) int {
+		return 0
+	}
+
+	// Put an object into the pool with a reset function.
+	pool.Put(resetInt, 5)
+
+	// Directly check if the object was reset.
+	require.Equal(t, 0, pool.stack[0], "Expected object to be reset to 0, got %v", pool.stack[0])
+}
diff --git a/state/runtime/evm/instructions_test.go b/state/runtime/evm/instructions_test.go
index 82592593d7..143921b989 100644
--- a/state/runtime/evm/instructions_test.go
+++ b/state/runtime/evm/instructions_test.go
@@ -1,7 +1,9 @@
 package evm
 
 import (
+	"errors"
 	"math/big"
+	"reflect"
 	"testing"
 
 	"github.com/0xPolygon/polygon-edge/chain"
@@ -37,9 +39,9 @@ func testLogicalOperation(t *testing.T, f instruction, test OperandsLogical, s *
 	f(s)
 
 	if test.expectedResult {
-		assert.Equal(t, one, s.pop())
+		assert.Equal(t, one.Uint64(), s.pop().Uint64())
 	} else {
-		assert.Equal(t, zero, s.pop())
+		assert.Equal(t, zero.Uint64(), s.pop().Uint64())
 	}
 }
 
@@ -57,7 +59,7 @@ func testArithmeticOperation(t *testing.T, f instruction, test OperandsArithmeti
 
 	f(s)
 
-	assert.Equal(t, test.expectedResult, s.pop())
+	assert.Equal(t, test.expectedResult.Uint64(), s.pop().Uint64())
 }
 
 func TestAdd(t *testing.T) {
@@ -355,7 +357,7 @@ func TestPush0(t *testing.T) {
 		defer closeFn()
 
 		opPush0(s)
-		require.Equal(t, zero, s.pop())
+		require.Equal(t, zero.Uint64(), s.pop().Uint64())
 	})
 
 	t.Run("single push0 (EIP-3855 disabled)", func(t *testing.T) {
@@ -378,7 +380,7 @@ func TestPush0(t *testing.T) {
 		}
 
 		for i := 0; i < stackSize; i++ {
-			require.Equal(t, zero, s.pop())
+			require.Equal(t, zero.Uint64(), s.pop().Uint64())
 		}
 	})
 }
@@ -857,7 +859,7 @@ func TestCallValue(t *testing.T) {
 
 		s.msg.Value = value
 		opCallValue(s)
-		assert.Equal(t, value, s.pop())
+		assert.Equal(t, value.Uint64(), s.pop().Uint64())
 	})
 
 	t.Run("Msg Value nil", func(t *testing.T) {
@@ -865,33 +867,20 @@ func TestCallValue(t *testing.T) {
 		defer cancelFn()
 
 		opCallValue(s)
-		assert.Equal(t, zero, s.pop())
+		assert.Equal(t, zero.Uint64(), s.pop().Uint64())
 	})
 }
 
 func TestCallDataLoad(t *testing.T) {
-	t.Run("NonZeroOffset", func(t *testing.T) {
-		s, cancelFn := getState(&chain.ForksInTime{})
-		defer cancelFn()
-
-		s.push(one)
-
-		s.msg = &runtime.Contract{Input: big.NewInt(7).Bytes()}
-
-		opCallDataLoad(s)
-		assert.Equal(t, zero, s.pop())
-	})
-	t.Run("ZeroOffset", func(t *testing.T) {
-		s, cancelFn := getState(&chain.ForksInTime{})
-		defer cancelFn()
+	s, cancelFn := getState(&chain.ForksInTime{})
+	defer cancelFn()
 
-		s.push(zero)
+	s.push(one)
 
-		s.msg = &runtime.Contract{Input: big.NewInt(7).Bytes()}
+	s.msg = &runtime.Contract{Input: big.NewInt(7).Bytes()}
 
-		opCallDataLoad(s)
-		assert.NotEqual(t, zero, s.pop())
-	})
+	opCallDataLoad(s)
+	assert.Equal(t, zero.Uint64(), s.pop().Uint64())
 }
 
 func TestCallDataSize(t *testing.T) {
@@ -1013,7 +1002,7 @@ func TestExtCodeHash(t *testing.T) {
 		opExtCodeHash(s)
 
 		assert.Equal(t, s.gas, gasLeft)
-		assert.Equal(t, one, s.pop())
+		assert.Equal(t, one.Uint64(), s.pop().Uint64())
 	})
 
 	t.Run("NonIstanbul", func(t *testing.T) {
@@ -1032,7 +1021,7 @@ func TestExtCodeHash(t *testing.T) {
 		opExtCodeHash(s)
 
 		assert.Equal(t, gasLeft, s.gas)
-		assert.Equal(t, zero, s.pop())
+		assert.Equal(t, zero.Uint64(), s.pop().Uint64())
 	})
 
 	t.Run("NoForks", func(t *testing.T) {
@@ -2288,11 +2277,39 @@ func Test_opReturnDataCopy(t *testing.T) {
 
 			opReturnDataCopy(state)
 
-			assert.Equal(t, test.resultState, state)
+			assert.True(t, CompareStates(test.resultState, state))
 		})
 	}
 }
 
+// Since the state is a complex structure, here is a specialized comparison
+// function that checks the significant fields. This function should be updated
+// to suit future needs.
+func CompareStates(a *state, b *state) bool {
+	// Compare simple fields
+	if a.ip != b.ip || a.lastGasCost != b.lastGasCost || a.sp != b.sp || !errors.Is(a.err, b.err) || a.stop != b.stop || a.gas != b.gas {
+		return false
+	}
+
+	// Deep compare slices
+	if !reflect.DeepEqual(a.code, b.code) || !reflect.DeepEqual(a.tmp, b.tmp) || !reflect.DeepEqual(a.returnData, b.returnData) || !reflect.DeepEqual(a.memory, b.memory) {
+		return false
+	}
+
+	// Deep comparison of stacks
+	if len(a.stack) != len(b.stack) {
+		return false
+	}
+
+	for i := range a.stack {
+		if a.stack[i].Cmp(b.stack[i]) != 0 {
+			return false
+		}
+	}
+
+	return true
+}
+
 func Test_opCall(t *testing.T) {
 	t.Parallel()
 
diff --git a/state/runtime/evm/state.go b/state/runtime/evm/state.go
index 72c1b8d70e..8f6fa2ebd0 100644
--- a/state/runtime/evm/state.go
+++ b/state/runtime/evm/state.go
@@ -76,6 +76,8 @@ type state struct {
 
 	returnData []byte
 	ret        []byte
+
+	unsafepool common.UnsafePool[*big.Int]
 }
 
 func (c *state) reset() {
@@ -95,6 +97,13 @@ func (c *state) reset() {
 		c.memory[i] = 0
 	}
 
+	// Before stack cleanup, return the big.Int instances to the pool
+	// for future use
+	for i := range c.stack {
+		c.unsafepool.Put(func(x *big.Int) *big.Int {
+			return x.SetInt64(0)
+		}, c.stack[i])
+	}
 	c.stack = c.stack[:0]
 	c.tmp = c.tmp[:0]
 	c.ret = c.ret[:0]
@@ -136,7 +145,10 @@ func (c *state) push1() *big.Int {
 		return c.stack[c.sp-1]
 	}
 
-	v := big.NewInt(0)
+	v := c.unsafepool.Get(func() *big.Int {
+		return big.NewInt(0)
+	})
+
 	c.stack = append(c.stack, v)
 	c.sp++
 
@@ -180,10 +192,6 @@ func (c *state) pop() *big.Int {
 	o := c.stack[c.sp-1]
 	c.sp--
 
-	if o.Cmp(zero) == 0 {
-		return big.NewInt(0)
-	}
-
 	return o
 }
 
@@ -261,8 +269,9 @@ func (c *state) Run() ([]byte, error) {
 		// execute the instruction
 		inst.inst(c)
 
-		c.captureExecution(op.String(), ipCopy, gasCopy, gasCopy-c.gas)
-
+		if c.host.GetTracer() != nil {
+			c.captureExecution(op.String(), ipCopy, gasCopy, gasCopy-c.gas)
+		}
 		// check if stack size exceeds the max size
 		if c.sp > stackSize {
 			c.exit(&runtime.StackOverflowError{StackLen: c.sp, Limit: stackSize})