Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Session key fixes and test #2138

Merged
merged 2 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ go 1.21.11
replace github.com/docker/docker => github.com/docker/docker v20.10.3-0.20220224222438-c78f6963a1c0+incompatible

require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0
github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v1.1.0
github.com/FantasyJony/openzeppelin-merkle-tree-go v1.1.3
github.com/Microsoft/go-winio v0.6.2
Expand Down Expand Up @@ -60,7 +61,6 @@ require (
require (
filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/azure-sdk-for-go v68.0.0+incompatible // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
github.com/DataDog/zstd v1.5.6 // indirect
Expand Down
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0hS+6+I79yEDJBqVNcqUzU=
github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0 h1:JZg6HRh6W6U4OLl6lk7BZ7BLisIzM9dG1R50zUk9C/M=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.16.0/go.mod h1:YL1xnZ6QejvQHWJrX/AvhFl4WW4rqHVoKspWNVwFk0M=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc=
Expand Down
143 changes: 142 additions & 1 deletion integration/tengateway/tengateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ import (
"testing"
"time"

"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ten-protocol/go-ten/go/common/gethapi"

"github.com/ten-protocol/go-ten/go/responses"

"github.com/ten-protocol/go-ten/lib/gethfork/rpc"

"github.com/ten-protocol/go-ten/tools/walletextension"
Expand Down Expand Up @@ -51,7 +56,7 @@ func init() { //nolint:gochecknoinits
LogDir: testLogs,
TestType: "tengateway",
TestSubtype: "test",
LogLevel: log.LvlInfo,
LogLevel: log.LvlTrace,
})
}

Expand Down Expand Up @@ -115,6 +120,7 @@ func TestTenGateway(t *testing.T) {
"testInvokeNonSensitiveMethod": testInvokeNonSensitiveMethod,
"testGetStorageAtForReturningUserID": testGetStorageAtForReturningUserID,
"testRateLimiter": testRateLimiter,
"testSessionKeys": testSessionKeys,
} {
t.Run(name, func(t *testing.T) {
test(t, startPort, httpURL, wsURL, w)
Expand Down Expand Up @@ -167,6 +173,141 @@ func testRateLimiter(t *testing.T, _ int, httpURL, wsURL string, w wallet.Wallet
require.Equal(t, "rate limit exceeded", err.Error())
}

func testSessionKeys(t *testing.T, _ int, httpURL, wsURL string, w wallet.Wallet) {
user0, err := NewGatewayUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL)
require.NoError(t, err)
testlog.Logger().Info("Created user with encryption token", "t", user0.tgClient.UserID())
// register the user so we can call the endpoints that require authentication
err = user0.RegisterAccounts()
require.NoError(t, err)

var amountToTransfer int64 = 1_000_000_000_000_000_000
// Transfer some funds to user1 to be able to make transactions
_, err = transferETHToAddress(user0.HTTPClient, user0.Wallets[0], user0.Wallets[0].Address(), amountToTransfer)
require.NoError(t, err)

// call BalanceAt - fist call should be successful
_, err = user0.HTTPClient.BalanceAt(context.Background(), user0.Wallets[0].Address(), nil)
require.NoError(t, err)

contractAddr := deployContract(t, w, user0)

// create session key
skAddr, err := user0.HTTPClient.StorageAt(context.Background(), gethcommon.HexToAddress(common.CreateSessionKeyCQMethod), gethcommon.Hash{}, nil)
require.NoError(t, err)
skAddress := gethcommon.BytesToAddress(skAddr)

// move some funds to the SK
var skAmount int64 = 100_000_000_000_000_000
_, err = transferETHToAddress(user0.HTTPClient, user0.Wallets[0], skAddress, skAmount)
require.NoError(t, err)

// activate SK
_, err = user0.HTTPClient.StorageAt(context.Background(), gethcommon.HexToAddress(common.ActivateSessionKeyCQMethod), gethcommon.Hash{}, nil)
require.NoError(t, err)

skNonce := uint64(0)

// interact with the contract - unsigned tx calling "sendRawTransaction"
contractInteractionData, err := eventsContractABI.Pack("setMessage", "user0PrivateEvent")
require.NoError(t, err)
rec, err := interactWithSmartContractUnsigned(user0.HTTPClient, true, skNonce, contractAddr, contractInteractionData, nil)
require.NoError(t, err)
require.Equal(t, uint64(0x1), rec.Status)

// move money back - unsigned tx calling "sendTransaction"
skNonce++
rec1, err := interactWithSmartContractUnsigned(user0.HTTPClient, false, skNonce, user0.Wallets[0].Address(), nil, big.NewInt(1_000))
require.NoError(t, err)
require.Equal(t, uint64(0x1), rec1.Status)

// deactivate
_, err = user0.HTTPClient.StorageAt(context.Background(), gethcommon.HexToAddress(common.DeactivateSessionKeyCQMethod), gethcommon.Hash{}, nil)
require.NoError(t, err)

// interact with the contract - unsigned - should fail
skNonce++
rec2, err := interactWithSmartContractUnsigned(user0.HTTPClient, false, skNonce, contractAddr, contractInteractionData, nil)
require.Error(t, err)
require.Nil(t, rec2)
}

func deployContract(t *testing.T, w wallet.Wallet, user0 *GatewayUser) gethcommon.Address {
// deploy events contract
deployTx := &types.LegacyTx{
Nonce: w.GetNonceAndIncrement(),
Gas: uint64(1_000_000),
GasPrice: gethcommon.Big1,
Data: gethcommon.FromHex(eventsContractBytecode),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Whats this eventsContract? It's not in the PR

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same ... don't see the events contract in this PR

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's a generic contract. It doesn't really matter for this PR. The idea here is to just send unsigned transactions, and check that they get signed

}

err := getFeeAndGas(user0.HTTPClient, w, deployTx)
require.NoError(t, err)

signedTx, err := w.SignTransaction(deployTx)
require.NoError(t, err)

err = user0.HTTPClient.SendTransaction(context.Background(), signedTx)
require.NoError(t, err)

contractReceipt, err := integrationCommon.AwaitReceiptEth(context.Background(), user0.HTTPClient, signedTx.Hash(), time.Minute)
require.NoError(t, err)
return contractReceipt.ContractAddress
}

func interactWithSmartContractUnsigned(client *ethclient.Client, sendRaw bool, nonce uint64, contractAddress gethcommon.Address, contractInteractionData []byte, value *big.Int) (*types.Receipt, error) {
var result responses.GasPriceType
err := client.Client().CallContext(context.Background(), &result, tenrpc.GasPrice)
if err != nil {
return nil, err
}

var txHash gethcommon.Hash

if sendRaw {
interactionTx := types.LegacyTx{
Nonce: nonce,
To: &contractAddress,
Gas: uint64(10_000_000),
GasPrice: result.ToInt(),
Data: contractInteractionData,
Value: value,
}
unSignedTx := types.NewTx(&interactionTx)
blob, err := unSignedTx.MarshalBinary()
if err != nil {
return nil, err
}
err = client.Client().CallContext(context.Background(), &txHash, "eth_sendRawTransaction", hexutil.Encode(blob))
if err != nil {
return nil, err
}
} else {
n := hexutil.Uint64(nonce)
g := hexutil.Uint64(10_000_000)
d := hexutil.Bytes(contractInteractionData)
interactionTx := gethapi.TransactionArgs{
Nonce: &n,
To: &contractAddress,
Gas: &g,
GasPrice: &result,
Data: &d,
Value: (*hexutil.Big)(value),
}
err = client.Client().CallContext(context.Background(), &txHash, "eth_sendTransaction", interactionTx)
if err != nil {
return nil, err
}
}

txReceipt, err := integrationCommon.AwaitReceiptEth(context.Background(), client, txHash, 10*time.Second)
if err != nil {
return nil, err
}

return txReceipt, nil
}

func testNewHeadsSubscription(t *testing.T, _ int, httpURL, wsURL string, w wallet.Wallet) {
user0, err := NewGatewayUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL)
require.NoError(t, err)
Expand Down
46 changes: 36 additions & 10 deletions tools/walletextension/rpcapi/transaction_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package rpcapi

import (
"context"
"fmt"

"github.com/ten-protocol/go-ten/tools/walletextension/cache"

Expand Down Expand Up @@ -103,13 +104,26 @@ func (s *TransactionAPI) GetTransactionReceipt(ctx context.Context, hash common.
}

func (s *TransactionAPI) SendTransaction(ctx context.Context, args gethapi.TransactionArgs) (common.Hash, error) {
//txRec, err := ExecAuthRPC[common.Hash](ctx, s.we, &AuthExecCfg{account: args.From, timeout: sendTransactionDuration}, "eth_sendTransaction", args)
//if err != nil {
// return common.Hash{}, err
//}
//return *txRec, err
// not implemented for now. We might use this for session keys.
return common.Hash{}, rpcNotImplemented
user, err := extractUserForRequest(ctx, s.we)
if err != nil {
return common.Hash{}, err
}
if !user.ActiveSK {
return common.Hash{}, fmt.Errorf("please activate session key")
}

// when there is an active Session Key, sign all incoming transactions with that SK
signedTx, err := s.we.SKManager.SignTx(ctx, user, args.ToTransaction())
if err != nil {
return common.Hash{}, err
}

blob, err := signedTx.MarshalBinary()
if err != nil {
return common.Hash{}, err
}

return s.sendRawTx(ctx, blob)
}

type SignTransactionResult struct {
Expand All @@ -127,16 +141,28 @@ func (s *TransactionAPI) SendRawTransaction(ctx context.Context, input hexutil.B
return common.Hash{}, err
}

signedTx := input
signedTxBlob := input
// when there is an active Session Key, sign all incoming transactions with that SK
if user.ActiveSK && user.SessionKey != nil {
signedTx, err = s.we.SKManager.SignTx(ctx, user, input)
tx := new(types.Transaction)
if err = tx.UnmarshalBinary(input); err != nil {
return common.Hash{}, err
}
signedTx, err := s.we.SKManager.SignTx(ctx, user, tx)
if err != nil {
return common.Hash{}, err
}
signedTxBlob, err = signedTx.MarshalBinary()
if err != nil {
return common.Hash{}, err
}
}

txRec, err := ExecAuthRPC[common.Hash](ctx, s.we, &AuthExecCfg{tryAll: true, timeout: sendTransactionDuration}, "eth_sendRawTransaction", signedTx)
return s.sendRawTx(ctx, signedTxBlob)
}

func (s *TransactionAPI) sendRawTx(ctx context.Context, input hexutil.Bytes) (common.Hash, error) {
txRec, err := ExecAuthRPC[common.Hash](ctx, s.we, &AuthExecCfg{tryAll: true, timeout: sendTransactionDuration}, "eth_sendRawTransaction", input)
if err != nil {
return common.Hash{}, err
}
Expand Down
23 changes: 11 additions & 12 deletions tools/walletextension/services/sk_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package services
import (
"context"
"fmt"
"math/big"

"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/core/types"

"github.com/ethereum/go-ethereum/crypto"
Expand All @@ -23,7 +23,7 @@ import (
// From the POV of the Ten network - a session key is a normal account key
type SKManager interface {
CreateSessionKey(user *common.GWUser) (*common.GWSessionKey, error)
SignTx(ctx context.Context, user *common.GWUser, input hexutil.Bytes) (hexutil.Bytes, error)
SignTx(ctx context.Context, user *common.GWUser, input *types.Transaction) (*types.Transaction, error)
}

type skManager struct {
Expand Down Expand Up @@ -92,17 +92,16 @@ func (m *skManager) createSK(user *common.GWUser) (*common.GWSessionKey, error)
}, nil
}

func (m *skManager) SignTx(ctx context.Context, user *common.GWUser, input hexutil.Bytes) (hexutil.Bytes, error) {
tx := new(types.Transaction)
if err := tx.UnmarshalBinary(input); err != nil {
return hexutil.Bytes{}, err
}

signer := types.NewLondonSigner(tx.ChainId())
func (m *skManager) SignTx(ctx context.Context, user *common.GWUser, tx *types.Transaction) (*types.Transaction, error) {
prvKey := user.SessionKey.PrivateKey.ExportECDSA()
signer := types.NewCancunSigner(big.NewInt(int64(m.config.TenChainID)))

tx, err := types.SignTx(tx, signer, user.SessionKey.PrivateKey.ExportECDSA())
stx, err := types.SignTx(tx, signer, prvKey)
if err != nil {
return hexutil.Bytes{}, err
return nil, err
}
return tx.MarshalBinary()

m.logger.Debug("Signed transaction with session key", "stxHash", stx.Hash().Hex())

return stx, nil
}
5 changes: 3 additions & 2 deletions tools/walletextension/storage/database/common/db_types.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package common

import (
"crypto/x509"
"fmt"

"github.com/ethereum/go-ethereum/crypto"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto/ecies"
"github.com/ten-protocol/go-ten/go/common/viewingkey"
Expand Down Expand Up @@ -50,7 +51,7 @@ func (userDB *GWUserDB) ToGWUser() (*wecommon.GWUser, error) {
}

if userDB.SessionKey != nil {
ecdsaPrivateKey, err := x509.ParseECPrivateKey(userDB.SessionKey.PrivateKey)
ecdsaPrivateKey, err := crypto.ToECDSA(userDB.SessionKey.PrivateKey)
if err != nil {
return nil, fmt.Errorf("failed to parse ECDSA private key: %w", err)
}
Expand Down