Skip to content

Commit

Permalink
Pooling of big.Int instances in the EVM (#124)
Browse files Browse the repository at this point in the history
* Multiple performance improvements:
- Pooling of big.Int instances on the stack to avoid frequent allocations/cleanup
- Removing unnecessary big.Int allocation in pop() method and fixing tests
- Avoiding unnecessary op.String() calls when tracer is not initialized.

* - Fixing failed TestNewUnsafePool test.

* Remove unnecessary case

* TestPush0 fix

---------

Co-authored-by: Goran Rojovic <[email protected]>
  • Loading branch information
cokicm and goran-ethernal authored Feb 21, 2024
1 parent 536d3d5 commit f2b212e
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 36 deletions.
40 changes: 40 additions & 0 deletions helper/common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -375,3 +375,43 @@ func EncodeUint64ToBytes(value uint64) []byte {
func EncodeBytesToUint64(b []byte) uint64 {
return binary.BigEndian.Uint64(b)
}

// Generic object pool implementation intended to be used in single-threaded
// manner and avoid synchronization overhead. It could be probably additionally
// be improved by using circular buffer as oposed to stack.
type UnsafePool[T any] struct {
stack []T
}

// Creates new instance of UnsafePool. Depending on observed usage, pool size
// should be set on creation to avoid pool resizing
func NewUnsafePool[T any]() *UnsafePool[T] {
return &UnsafePool[T]{}
}

// Get retrieves an object from the unsafepool, or allocates a new one if the pool
// is empty. The allocation logic (i.e., creating a new object of type T) needs to
// be provided externally, as Go's type system does not allow calling constructors
// or functions specific to T without an interface.
func (f *UnsafePool[T]) Get(newFunc func() T) T {
n := len(f.stack)
if n == 0 {
// Allocate a new T instance using the provided newFunc if the stack is empty.
return newFunc()
}

obj := f.stack[n-1]
f.stack = f.stack[:n-1]

return obj
}

// Put returns an object to the pool and executes reset function if provided. Reset
// function is used to return the T instance to initial state.
func (f *UnsafePool[T]) Put(resetFunc func(T) T, obj T) {
if resetFunc != nil {
obj = resetFunc(obj)
}

f.stack = append(f.stack, obj)
}
53 changes: 53 additions & 0 deletions helper/common/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,56 @@ func Test_SafeAddUint64(t *testing.T) {
})
}
}

func TestNewUnsafePool(t *testing.T) {
pool := NewUnsafePool[int]()

require.NotNilf(t, pool, "NewUnsafePool returned nil")

require.Empty(t, pool.stack, "Expected empty pool.")
}

func TestUnsafePoolGetWhenEmpty(t *testing.T) {
pool := NewUnsafePool[int]()
newInt := func() int {
return 1
}

obj := pool.Get(newInt)

require.Equal(t, 1, obj, "Expected 1 from newFunc, got %v", obj)
}

func TestUnsafePoolGetPut(t *testing.T) {
pool := NewUnsafePool[int]()
resetInt := func(i int) int {
return 0
}

// Initially put an object into the pool.
pool.Put(resetInt, 2)

// Retrieve the object, which should now be the reset value.
obj := pool.Get(func() int { return 3 })

// Expecting the original object, not the one from newFunc
require.Equal(t, 0, obj, "Expected 0 from the pool, got %v", obj)

// Test if Get correctly uses newFunc when pool is empty again.
obj = pool.Get(func() int { return 3 })

require.Equal(t, 3, obj, "Expected 3 from newFunc, got %v", obj)
}

func TestUnsafePoolPutWithReset(t *testing.T) {
pool := NewUnsafePool[int]()
resetInt := func(i int) int {
return 0
}

// Put an object into the pool with a reset function.
pool.Put(resetInt, 5)

// Directly check if the object was reset.
require.Equal(t, 0, pool.stack[0], "Expected object to be reset to 0, got %v", pool.stack[0])
}
75 changes: 46 additions & 29 deletions state/runtime/evm/instructions_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package evm

import (
"errors"
"math/big"
"reflect"
"testing"

"github.com/0xPolygon/polygon-edge/chain"
Expand Down Expand Up @@ -37,9 +39,9 @@ func testLogicalOperation(t *testing.T, f instruction, test OperandsLogical, s *
f(s)

if test.expectedResult {
assert.Equal(t, one, s.pop())
assert.Equal(t, one.Uint64(), s.pop().Uint64())
} else {
assert.Equal(t, zero, s.pop())
assert.Equal(t, zero.Uint64(), s.pop().Uint64())
}
}

Expand All @@ -57,7 +59,7 @@ func testArithmeticOperation(t *testing.T, f instruction, test OperandsArithmeti

f(s)

assert.Equal(t, test.expectedResult, s.pop())
assert.Equal(t, test.expectedResult.Uint64(), s.pop().Uint64())
}

func TestAdd(t *testing.T) {
Expand Down Expand Up @@ -355,7 +357,7 @@ func TestPush0(t *testing.T) {
defer closeFn()

opPush0(s)
require.Equal(t, zero, s.pop())
require.Equal(t, zero.Uint64(), s.pop().Uint64())
})

t.Run("single push0 (EIP-3855 disabled)", func(t *testing.T) {
Expand All @@ -378,7 +380,7 @@ func TestPush0(t *testing.T) {
}

for i := 0; i < stackSize; i++ {
require.Equal(t, zero, s.pop())
require.Equal(t, zero.Uint64(), s.pop().Uint64())
}
})
}
Expand Down Expand Up @@ -857,41 +859,28 @@ func TestCallValue(t *testing.T) {
s.msg.Value = value

opCallValue(s)
assert.Equal(t, value, s.pop())
assert.Equal(t, value.Uint64(), s.pop().Uint64())
})

t.Run("Msg Value nil", func(t *testing.T) {
s, cancelFn := getState(&chain.ForksInTime{})
defer cancelFn()

opCallValue(s)
assert.Equal(t, zero, s.pop())
assert.Equal(t, zero.Uint64(), s.pop().Uint64())
})
}

func TestCallDataLoad(t *testing.T) {
t.Run("NonZeroOffset", func(t *testing.T) {
s, cancelFn := getState(&chain.ForksInTime{})
defer cancelFn()

s.push(one)

s.msg = &runtime.Contract{Input: big.NewInt(7).Bytes()}

opCallDataLoad(s)
assert.Equal(t, zero, s.pop())
})
t.Run("ZeroOffset", func(t *testing.T) {
s, cancelFn := getState(&chain.ForksInTime{})
defer cancelFn()
s, cancelFn := getState(&chain.ForksInTime{})
defer cancelFn()

s.push(zero)
s.push(one)

s.msg = &runtime.Contract{Input: big.NewInt(7).Bytes()}
s.msg = &runtime.Contract{Input: big.NewInt(7).Bytes()}

opCallDataLoad(s)
assert.NotEqual(t, zero, s.pop())
})
opCallDataLoad(s)
assert.Equal(t, zero.Uint64(), s.pop().Uint64())
}

func TestCallDataSize(t *testing.T) {
Expand Down Expand Up @@ -1013,7 +1002,7 @@ func TestExtCodeHash(t *testing.T) {
opExtCodeHash(s)

assert.Equal(t, s.gas, gasLeft)
assert.Equal(t, one, s.pop())
assert.Equal(t, one.Uint64(), s.pop().Uint64())
})

t.Run("NonIstanbul", func(t *testing.T) {
Expand All @@ -1032,7 +1021,7 @@ func TestExtCodeHash(t *testing.T) {

opExtCodeHash(s)
assert.Equal(t, gasLeft, s.gas)
assert.Equal(t, zero, s.pop())
assert.Equal(t, zero.Uint64(), s.pop().Uint64())
})

t.Run("NoForks", func(t *testing.T) {
Expand Down Expand Up @@ -2288,11 +2277,39 @@ func Test_opReturnDataCopy(t *testing.T) {

opReturnDataCopy(state)

assert.Equal(t, test.resultState, state)
assert.True(t, CompareStates(test.resultState, state))
})
}
}

// Since the state is complex structure, here is the specialized comparison
// function that checks significant fields. This function should be updated
// to suite future needs.
func CompareStates(a *state, b *state) bool {
// Compare simple fields
if a.ip != b.ip || a.lastGasCost != b.lastGasCost || a.sp != b.sp || !errors.Is(a.err, b.err) || a.stop != b.stop || a.gas != b.gas {
return false
}

// Deep compare slices
if !reflect.DeepEqual(a.code, b.code) || !reflect.DeepEqual(a.tmp, b.tmp) || !reflect.DeepEqual(a.returnData, b.returnData) || !reflect.DeepEqual(a.memory, b.memory) {
return false
}

// Deep comparison of stacks
if len(a.stack) != len(b.stack) {
return false
}

for i := range a.stack {
if a.stack[i].Cmp(b.stack[i]) != 0 {
return false
}
}

return true
}

func Test_opCall(t *testing.T) {
t.Parallel()

Expand Down
23 changes: 16 additions & 7 deletions state/runtime/evm/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ type state struct {

returnData []byte
ret []byte

unsafepool common.UnsafePool[*big.Int]
}

func (c *state) reset() {
Expand All @@ -95,6 +97,13 @@ func (c *state) reset() {
c.memory[i] = 0
}

// Before stack cleanup, return instances of big.Int to the pool
// for the future usage
for i := range c.stack {
c.unsafepool.Put(func(x *big.Int) *big.Int {
return x.SetInt64(0)
}, c.stack[i])
}
c.stack = c.stack[:0]
c.tmp = c.tmp[:0]
c.ret = c.ret[:0]
Expand Down Expand Up @@ -136,7 +145,10 @@ func (c *state) push1() *big.Int {
return c.stack[c.sp-1]
}

v := big.NewInt(0)
v := c.unsafepool.Get(func() *big.Int {
return big.NewInt(0)
})

c.stack = append(c.stack, v)
c.sp++

Expand Down Expand Up @@ -180,10 +192,6 @@ func (c *state) pop() *big.Int {
o := c.stack[c.sp-1]
c.sp--

if o.Cmp(zero) == 0 {
return big.NewInt(0)
}

return o
}

Expand Down Expand Up @@ -261,8 +269,9 @@ func (c *state) Run() ([]byte, error) {
// execute the instruction
inst.inst(c)

c.captureExecution(op.String(), ipCopy, gasCopy, gasCopy-c.gas)

if c.host.GetTracer() != nil {
c.captureExecution(op.String(), ipCopy, gasCopy, gasCopy-c.gas)
}
// check if stack size exceeds the max size
if c.sp > stackSize {
c.exit(&runtime.StackOverflowError{StackLen: c.sp, Limit: stackSize})
Expand Down

0 comments on commit f2b212e

Please sign in to comment.