Skip to content

Commit

Permalink
refactor users to have a wallet, not be a wallet
Browse files Browse the repository at this point in the history
  • Loading branch information
BedrockSquirrel committed Feb 20, 2024
1 parent 53e195c commit 72e9e2f
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 149 deletions.
2 changes: 1 addition & 1 deletion integration/networktest/actions/native_fund_actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func (s *SendNativeFunds) Run(ctx context.Context, _ networktest.NetworkConnecto
if err != nil {
return ctx, err
}
txHash, err := user.SendFunds(ctx, target.Address(), s.Amount)
txHash, err := user.SendFunds(ctx, target.Wallet().Address(), s.Amount)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions integration/networktest/actions/setup_actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ func (c *CreateTestUser) Run(ctx context.Context, network networktest.NetworkCon
if err != nil {
return ctx, fmt.Errorf("failed to get required gateway URL: %w", err)
}
user, err = userwallet.NewGatewayUser(wal.PrivateKey(), gwURL, logger)
user, err = userwallet.NewGatewayUser(wal, gwURL, logger)
if err != nil {
return ctx, fmt.Errorf("failed to create gateway user: %w", err)
}
} else {
// traffic sim users are round robin-ed onto the validators for now (todo (@matt) - make that overridable)
user = userwallet.NewUserWallet(wal.PrivateKey(), network.ValidatorRPCAddress(c.UserID%network.NumValidators()), logger)
user = userwallet.NewUserWallet(wal, network.ValidatorRPCAddress(c.UserID%network.NumValidators()), logger)
}
return storeTestUser(ctx, c.UserID, user), nil
}
Expand All @@ -58,7 +58,7 @@ func (a *AllocateFaucetFunds) Run(ctx context.Context, network networktest.Netwo
if err != nil {
return ctx, err
}
return ctx, network.AllocateFaucetFunds(ctx, user.Address())
return ctx, network.AllocateFaucetFunds(ctx, user.Wallet().Address())
}

func (a *AllocateFaucetFunds) Verify(_ context.Context, _ networktest.NetworkConnector) error {
Expand Down
3 changes: 2 additions & 1 deletion integration/networktest/env/testnet.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@ func NewTestnetConnectorWithFaucetAccount(seqRPCAddr string, validatorRPCAddress
if err != nil {
panic(err)
}
wal := wallet.NewInMemoryWalletFromPK(big.NewInt(integration.TenChainID), ecdsaKey, testlog.Logger())
return &testnetConnector{
seqRPCAddress: seqRPCAddr,
validatorRPCAddresses: validatorRPCAddressses,
faucetWallet: userwallet.NewUserWallet(ecdsaKey, validatorRPCAddressses[0], testlog.Logger(), userwallet.WithChainID(big.NewInt(integration.TenChainID))),
faucetWallet: userwallet.NewUserWallet(wal, validatorRPCAddressses[0], testlog.Logger()),
l1RPCURL: l1RPCAddress,
tenGatewayURL: tenGatewayURL,
}
Expand Down
95 changes: 14 additions & 81 deletions integration/networktest/userwallet/authclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,18 @@ package userwallet

import (
"context"
"crypto/ecdsa"
"errors"
"fmt"
"math/big"
"time"

gethcommon "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
gethlog "github.com/ethereum/go-ethereum/log"
"github.com/ten-protocol/go-ten/go/common/retry"
"github.com/ten-protocol/go-ten/go/obsclient"
"github.com/ten-protocol/go-ten/go/rpc"
"github.com/ten-protocol/go-ten/integration"
"github.com/ten-protocol/go-ten/go/wallet"
)

const (
Expand All @@ -26,41 +24,19 @@ const (
// AuthClientUser is a test user that uses the auth client to talk to directly to a node
// Note: AuthClientUser is **not** thread-safe for a single wallet (creates nonce conflicts etc.)
type AuthClientUser struct {
privateKey *ecdsa.PrivateKey
publicKey *ecdsa.PublicKey
accountAddress gethcommon.Address
chainID *big.Int
rpcEndpoint string

// state managed by the wallet
nonce uint64
wal wallet.Wallet
rpcEndpoint string

client *obsclient.AuthObsClient // lazily initialised and authenticated on first usage
logger gethlog.Logger
}

// Option modifies a AuthClientUser. See below for options, in the form `WithXxx(xxx)` that can be chained into constructor
type Option func(wallet *AuthClientUser)

func NewUserWallet(pk *ecdsa.PrivateKey, rpcEndpoint string, logger gethlog.Logger, opts ...Option) *AuthClientUser {
publicKeyECDSA, ok := pk.Public().(*ecdsa.PublicKey)
if !ok {
// this shouldn't happen
logger.Crit("error casting public key to ECDSA")
}
wal := &AuthClientUser{
privateKey: pk,
publicKey: publicKeyECDSA,
accountAddress: crypto.PubkeyToAddress(*publicKeyECDSA),
chainID: big.NewInt(integration.TenChainID), // default, overridable using `WithChainID(...) opt`
rpcEndpoint: rpcEndpoint,
logger: logger,
func NewUserWallet(wal wallet.Wallet, rpcEndpoint string, logger gethlog.Logger) *AuthClientUser {
return &AuthClientUser{
wal: wal,
rpcEndpoint: rpcEndpoint,
logger: logger,
}
// apply any optional config to the wallet
for _, opt := range opts {
opt(wal)
}
return wal
}

func (s *AuthClientUser) SendFunds(ctx context.Context, addr gethcommon.Address, value *big.Int) (*gethcommon.Hash, error) {
Expand All @@ -70,7 +46,6 @@ func (s *AuthClientUser) SendFunds(ctx context.Context, addr gethcommon.Address,
}

txData := &types.LegacyTx{
Nonce: s.nonce,
Value: value,
To: &addr,
}
Expand All @@ -85,7 +60,7 @@ func (s *AuthClientUser) SendFunds(ctx context.Context, addr gethcommon.Address,
}

func (s *AuthClientUser) SendTransaction(ctx context.Context, tx types.TxData) (*gethcommon.Hash, error) {
signedTx, err := s.SignTransaction(tx)
signedTx, err := s.wal.SignTransaction(tx)
if err != nil {
return nil, fmt.Errorf("unable to sign transaction - %w", err)
}
Expand All @@ -97,7 +72,7 @@ func (s *AuthClientUser) SendTransaction(ctx context.Context, tx types.TxData) (

txHash := signedTx.Hash()
// transaction has been sent, we increment the nonce
s.nonce++
s.wal.GetNonceAndIncrement()
return &txHash, nil
}

Expand All @@ -115,50 +90,14 @@ func (s *AuthClientUser) AwaitReceipt(ctx context.Context, txHash *gethcommon.Ha
return receipt, err
}

func (s *AuthClientUser) Address() gethcommon.Address {
return s.accountAddress
}

func (s *AuthClientUser) SignTransaction(tx types.TxData) (*types.Transaction, error) {
return s.SignTransactionForChainID(tx, s.chainID)
}

func (s *AuthClientUser) SignTransactionForChainID(tx types.TxData, chainID *big.Int) (*types.Transaction, error) {
return types.SignNewTx(s.privateKey, types.NewLondonSigner(chainID), tx)
}

func (s *AuthClientUser) GetNonce() uint64 {
return s.nonce
}

func (s *AuthClientUser) PrivateKey() *ecdsa.PrivateKey {
return s.privateKey
}

//
// These methods allow the user to comply with the wallet.Wallet interface
//

func (s *AuthClientUser) ChainID() *big.Int {
return s.chainID
}

func (s *AuthClientUser) SetNonce(_ uint64) {
panic("AuthClientUser is designed to manage its own nonce - this method exists to support legacy interface methods")
}

func (s *AuthClientUser) GetNonceAndIncrement() uint64 {
panic("AuthClientUser is designed to manage its own nonce - this method exists to support legacy interface methods")
}

// EnsureClientSetup creates an authenticated RPC client (with a viewing key generated, signed and registered) when first called
// Also fetches current nonce value.
func (s *AuthClientUser) EnsureClientSetup(ctx context.Context) error {
if s.client != nil {
// client already setup
return nil
}
authClient, err := obsclient.DialWithAuth(s.rpcEndpoint, s, s.logger)
authClient, err := obsclient.DialWithAuth(s.rpcEndpoint, s.wal, s.logger)
if err != nil {
return err
}
Expand All @@ -169,7 +108,7 @@ func (s *AuthClientUser) EnsureClientSetup(ctx context.Context) error {
if err != nil {
return fmt.Errorf("unable to fetch client nonce - %w", err)
}
s.nonce = nonce
s.wal.SetNonce(nonce)

return nil
}
Expand All @@ -187,12 +126,6 @@ func (s *AuthClientUser) Init(ctx context.Context) (*AuthClientUser, error) {
return s, s.EnsureClientSetup(ctx)
}

// UserWalletOptions can be passed into the constructor to override default values
// e.g. NewUserWallet(pk, rpcAddr, logger, WithChainId(123))
// NewUserWallet(pk, rpcAddr, logger, WithChainId(123), WithRPCTimeout(20*time.Second)), )

func WithChainID(chainID *big.Int) Option {
return func(wallet *AuthClientUser) {
wallet.chainID = chainID
}
func (s *AuthClientUser) Wallet() wallet.Wallet {
return s.wal
}
76 changes: 16 additions & 60 deletions integration/networktest/userwallet/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package userwallet

import (
"context"
"crypto/ecdsa"
"errors"
"fmt"
"math/big"
Expand All @@ -11,20 +10,16 @@ import (
"github.com/ethereum/go-ethereum"
gethcommon "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethclient"
gethlog "github.com/ethereum/go-ethereum/log"
"github.com/ten-protocol/go-ten/go/common/retry"
"github.com/ten-protocol/go-ten/go/rpc"
"github.com/ten-protocol/go-ten/integration"
"github.com/ten-protocol/go-ten/go/wallet"
"github.com/ten-protocol/go-ten/tools/walletextension/lib"
)

type GatewayUser struct {
privateKey *ecdsa.PrivateKey
publicKey *ecdsa.PublicKey
accountAddress gethcommon.Address
chainID *big.Int
wal wallet.Wallet

gwLib *lib.TGLib // TenGateway utility
client *ethclient.Client
Expand All @@ -35,20 +30,14 @@ type GatewayUser struct {
logger gethlog.Logger
}

func NewGatewayUser(pk *ecdsa.PrivateKey, gatewayURL string, logger gethlog.Logger) (*GatewayUser, error) {
publicKeyECDSA, ok := pk.Public().(*ecdsa.PublicKey)
if !ok {
// this shouldn't happen
logger.Crit("error casting public key to ECDSA")
}

func NewGatewayUser(wal wallet.Wallet, gatewayURL string, logger gethlog.Logger) (*GatewayUser, error) {
gwLib := lib.NewTenGatewayLibrary(gatewayURL, "") // not providing wsURL for now, add if we need it

err := gwLib.Join()
if err != nil {
return nil, fmt.Errorf("failed to join TenGateway: %w", err)
}
err = gwLib.RegisterAccount(pk, crypto.PubkeyToAddress(*publicKeyECDSA))
err = gwLib.RegisterAccount(wal.PrivateKey(), wal.Address())
if err != nil {
return nil, fmt.Errorf("failed to register account with TenGateway: %w", err)
}
Expand All @@ -58,19 +47,14 @@ func NewGatewayUser(pk *ecdsa.PrivateKey, gatewayURL string, logger gethlog.Logg
return nil, fmt.Errorf("failed to dial TenGateway HTTP: %w", err)
}

fmt.Printf("Registered acc with TenGateway: %s (%s)\n", crypto.PubkeyToAddress(*publicKeyECDSA).Hex(), gwLib.HTTP())
fmt.Printf("Registered acc with TenGateway: %s (%s)\n", wal.Address(), gwLib.HTTP())

wal := &GatewayUser{
privateKey: pk,
publicKey: publicKeyECDSA,
accountAddress: crypto.PubkeyToAddress(*publicKeyECDSA),
chainID: big.NewInt(integration.TenChainID),
gwLib: gwLib,
client: client,
logger: logger,
}

return wal, nil
return &GatewayUser{
wal: wal,
gwLib: gwLib,
client: client,
logger: logger,
}, nil
}

func (g *GatewayUser) SendFunds(ctx context.Context, addr gethcommon.Address, value *big.Int) (*gethcommon.Hash, error) {
Expand All @@ -85,14 +69,14 @@ func (g *GatewayUser) SendFunds(ctx context.Context, addr gethcommon.Address, va
}
txData.GasPrice = gasPrice
gasLimit, err := g.client.EstimateGas(ctx, ethereum.CallMsg{
From: g.accountAddress,
From: g.wal.Address(),
To: &addr,
})
if err != nil {
return nil, fmt.Errorf("unable to estimate gas - %w", err)
}
txData.Gas = gasLimit
signedTx, err := g.SignTransaction(txData)
signedTx, err := g.wal.SignTransaction(txData)
if err != nil {
return nil, fmt.Errorf("unable to sign transaction - %w", err)
}
Expand Down Expand Up @@ -123,37 +107,9 @@ func (g *GatewayUser) AwaitReceipt(ctx context.Context, txHash *gethcommon.Hash)
}

func (g *GatewayUser) NativeBalance(ctx context.Context) (*big.Int, error) {
return g.client.BalanceAt(ctx, g.accountAddress, nil)
}

func (g *GatewayUser) Address() gethcommon.Address {
return g.accountAddress
}

func (g *GatewayUser) SignTransaction(tx types.TxData) (*types.Transaction, error) {
return g.SignTransactionForChainID(tx, g.chainID)
}

func (g *GatewayUser) SignTransactionForChainID(tx types.TxData, chainID *big.Int) (*types.Transaction, error) {
return types.SignNewTx(g.privateKey, types.NewLondonSigner(chainID), tx)
}

func (g *GatewayUser) SetNonce(_ uint64) {
panic("GatewayUser is designed to manage its own nonce - this method exists to support legacy interface methods")
}

func (g *GatewayUser) GetNonceAndIncrement() uint64 {
panic("GatewayUser is designed to manage its own nonce - this method exists to support legacy interface methods")
}

func (g *GatewayUser) GetNonce() uint64 {
return g.nonce
}

func (g *GatewayUser) ChainID() *big.Int {
return g.chainID
return g.client.BalanceAt(ctx, g.wal.Address(), nil)
}

func (g *GatewayUser) PrivateKey() *ecdsa.PrivateKey {
return g.privateKey
func (g *GatewayUser) Wallet() wallet.Wallet {
return g.wal
}
3 changes: 1 addition & 2 deletions integration/networktest/userwallet/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@ import (
//
// This abstraction allows us to use the same tests for both types of users
type User interface {
Wallet() wallet.Wallet
SendFunds(ctx context.Context, addr gethcommon.Address, value *big.Int) (*gethcommon.Hash, error)
AwaitReceipt(ctx context.Context, txHash *gethcommon.Hash) (*types.Receipt, error)
NativeBalance(ctx context.Context) (*big.Int, error)
SignTransaction(tx types.TxData) (*types.Transaction, error)
wallet.Wallet
}
2 changes: 1 addition & 1 deletion integration/simulation/devnetwork/dev_network.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func (s *InMemDevNetwork) startNodes() {
}
}(v)
}
s.faucet = userwallet.NewUserWallet(s.networkWallets.L2FaucetWallet.PrivateKey(), s.SequencerRPCAddress(), s.logger)
s.faucet = userwallet.NewUserWallet(s.networkWallets.L2FaucetWallet, s.SequencerRPCAddress(), s.logger)
}

func (s *InMemDevNetwork) startTenGateway() {
Expand Down

0 comments on commit 72e9e2f

Please sign in to comment.