Skip to content

Commit

Permalink
fix: increment sequence number at every call and create (#102)
Browse files Browse the repository at this point in the history
* fix increment sequence number at every call and create

* fix comments
  • Loading branch information
beer-1 authored Nov 8, 2024
1 parent 883a73e commit 5df5933
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 11 deletions.
3 changes: 2 additions & 1 deletion x/evm/ante/sequence.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func (isd IncrementSequenceDecorator) AnteHandle(ctx sdk.Context, tx sdk.Tx, sim
}

// set a flag in context to indicate that sequence has been incremented in ante handler
ctx = ctx.WithValue(ContextKeySequenceIncremented, true)
incremented := true // use pointer to enable revert after first call
ctx = ctx.WithValue(ContextKeySequenceIncremented, &incremented)
return next(ctx, tx, simulate)
}
36 changes: 26 additions & 10 deletions x/evm/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (ms *msgServerImpl) Create(ctx context.Context, msg *types.MsgCreate) (*typ
}

// handle cosmos<>evm different sequence increment logic
ctx, err = ms.handleSequenceIncremented(ctx, sender, true)
err = ms.handleSequenceIncremented(ctx, sender, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -98,7 +98,7 @@ func (ms *msgServerImpl) Create2(ctx context.Context, msg *types.MsgCreate2) (*t
}

// handle cosmos<>evm different sequence increment logic
ctx, err = ms.handleSequenceIncremented(ctx, sender, true)
err = ms.handleSequenceIncremented(ctx, sender, true)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -163,7 +163,7 @@ func (ms *msgServerImpl) Call(ctx context.Context, msg *types.MsgCall) (*types.M
}

// handle cosmos<>evm different sequence increment logic
ctx, err = ms.handleSequenceIncremented(ctx, sender, false)
err = ms.handleSequenceIncremented(ctx, sender, false)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -258,20 +258,36 @@ func (ms *msgServerImpl) testFeeDenom(ctx context.Context, params types.Params)

// In the Cosmos SDK, the sequence number is incremented in the ante handler.
// In the EVM, the sequence number is incremented during the execution of create and create2 messages.
// However, for call messages, the sequence number is incremented in the ante handler like the Cosmos SDK.
// To prevent double incrementing the sequence number during EVM execution, we need to decrement it here for create messages.
func (k *msgServerImpl) handleSequenceIncremented(ctx context.Context, sender sdk.AccAddress, isCreate bool) (context.Context, error) {
//
// If the sequence number is already incremented in the ante handler and the message is create, decrement the sequence number to prevent double incrementing.
// If the sequence number is not incremented in the ante handler and the message is call, increment the sequence number to ensure proper sequencing.
func (k *msgServerImpl) handleSequenceIncremented(ctx context.Context, sender sdk.AccAddress, isCreate bool) error {
sdkCtx := sdk.UnwrapSDKContext(ctx)
if sdkCtx.Value(evmante.ContextKeySequenceIncremented) == nil {
return nil
}

// decrement sequence of the sender
if isCreate && sdkCtx.Value(evmante.ContextKeySequenceIncremented) != nil {
incremented := sdkCtx.Value(evmante.ContextKeySequenceIncremented).(*bool)
if isCreate && *incremented {
// if the sequence is already incremented, decrement it to prevent double incrementing the sequence number at create.
acc := k.accountKeeper.GetAccount(ctx, sender)
if err := acc.SetSequence(acc.GetSequence() - 1); err != nil {
return ctx, err
return err
}

k.accountKeeper.SetAccount(ctx, acc)
} else if !isCreate && !*incremented {
// if the sequence is not incremented and the message is call, increment the sequence number.
acc := k.accountKeeper.GetAccount(ctx, sender)
if err := acc.SetSequence(acc.GetSequence() + 1); err != nil {
return err
}

k.accountKeeper.SetAccount(ctx, acc)
}

return sdkCtx.WithValue(evmante.ContextKeySequenceIncremented, nil), nil
// set the flag to false
*incremented = false

return nil
}
129 changes: 129 additions & 0 deletions x/evm/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"

"github.com/holiman/uint256"

evmante "github.com/initia-labs/minievm/x/evm/ante"
"github.com/initia-labs/minievm/x/evm/contracts/counter"
"github.com/initia-labs/minievm/x/evm/keeper"
"github.com/initia-labs/minievm/x/evm/types"

"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -176,3 +180,128 @@ func Test_MsgServer_UpdateParams(t *testing.T) {
})
require.ErrorContains(t, err, "sudoMint and sudoBurn")
}

func Test_MsgServer_NonceIncrement_Call(t *testing.T) {
ctx, input := createDefaultTestInput(t)
_, _, addr := keyPubAddr()
caller := common.BytesToAddress(addr.Bytes())

counterBz, err := hexutil.Decode(counter.CounterBin)
require.NoError(t, err)

retBz, contractAddr, _, err := input.EVMKeeper.EVMCreate(ctx, caller, counterBz, nil, nil)
require.NoError(t, err)
require.NotEmpty(t, retBz)
require.Len(t, contractAddr, 20)

parsed, err := counter.CounterMetaData.GetAbi()
require.NoError(t, err)

// increment sequence
incremented := true
ctx = ctx.WithValue(evmante.ContextKeySequenceIncremented, &incremented)
acc := input.AccountKeeper.GetAccount(ctx, addr)
seq := acc.GetSequence() + 1
acc.SetSequence(seq)
input.AccountKeeper.SetAccount(ctx, acc)

inputBz, err := parsed.Pack("increase")
require.NoError(t, err)

// should not increment sequence
msgServer := keeper.NewMsgServerImpl(&input.EVMKeeper)
res, err := msgServer.Call(ctx, &types.MsgCall{
Sender: addr.String(),
ContractAddr: contractAddr.Hex(),
Input: hexutil.Encode(inputBz),
})
require.NoError(t, err)
require.Equal(t, "0x", res.Result)
require.NotEmpty(t, res.Logs)

acc = input.AccountKeeper.GetAccount(ctx, addr)
require.Equal(t, seq, acc.GetSequence())

// call again should increment sequence
res, err = msgServer.Call(ctx, &types.MsgCall{
Sender: addr.String(),
ContractAddr: contractAddr.Hex(),
Input: hexutil.Encode(inputBz),
})
require.NoError(t, err)
require.Equal(t, "0x", res.Result)
require.NotEmpty(t, res.Logs)

acc = input.AccountKeeper.GetAccount(ctx, addr)
require.Equal(t, seq+1, acc.GetSequence())

// create should increment sequence
_, err = msgServer.Create(ctx, &types.MsgCreate{
Sender: addr.String(),
Code: counter.CounterBin,
})
require.NoError(t, err)

acc = input.AccountKeeper.GetAccount(ctx, addr)
require.Equal(t, seq+2, acc.GetSequence())
}

func Test_MsgServer_NonceIncrement_Create(t *testing.T) {
ctx, input := createDefaultTestInput(t)
_, _, addr := keyPubAddr()
caller := common.BytesToAddress(addr.Bytes())

counterBz, err := hexutil.Decode(counter.CounterBin)
require.NoError(t, err)

retBz, contractAddr, _, err := input.EVMKeeper.EVMCreate(ctx, caller, counterBz, nil, nil)
require.NoError(t, err)
require.NotEmpty(t, retBz)
require.Len(t, contractAddr, 20)

parsed, err := counter.CounterMetaData.GetAbi()
require.NoError(t, err)

// increment sequence
incremented := true
ctx = ctx.WithValue(evmante.ContextKeySequenceIncremented, &incremented)
acc := input.AccountKeeper.GetAccount(ctx, addr)
seq := acc.GetSequence() + 1
acc.SetSequence(seq)
input.AccountKeeper.SetAccount(ctx, acc)

inputBz, err := parsed.Pack("increase")
require.NoError(t, err)

// should not increment sequence
msgServer := keeper.NewMsgServerImpl(&input.EVMKeeper)
_, err = msgServer.Create(ctx, &types.MsgCreate{
Sender: addr.String(),
Code: counter.CounterBin,
})
require.NoError(t, err)

acc = input.AccountKeeper.GetAccount(ctx, addr)
require.Equal(t, seq, acc.GetSequence())

// call again should increment sequence
_, err = msgServer.Call(ctx, &types.MsgCall{
Sender: addr.String(),
ContractAddr: contractAddr.Hex(),
Input: hexutil.Encode(inputBz),
})
require.NoError(t, err)

acc = input.AccountKeeper.GetAccount(ctx, addr)
require.Equal(t, seq+1, acc.GetSequence())

// create should increment sequence
_, err = msgServer.Create(ctx, &types.MsgCreate{
Sender: addr.String(),
Code: counter.CounterBin,
})
require.NoError(t, err)

acc = input.AccountKeeper.GetAccount(ctx, addr)
require.Equal(t, seq+2, acc.GetSequence())
}

0 comments on commit 5df5933

Please sign in to comment.