From a9e1af279f0ddfdd14c000aa866c43d264c00fc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C5=BDiga=20Kokelj?= Date: Mon, 3 Jun 2024 17:19:04 +0200 Subject: [PATCH] new version of rate limiter --- integration/obscurogateway/tengateway_test.go | 2 + tools/walletextension/common/config.go | 3 +- tools/walletextension/main/cli.go | 16 ++- .../ratelimiter/rate_limiter.go | 41 +++++-- .../walletextension/rpcapi/blockchain_api.go | 101 +++++++----------- tools/walletextension/rpcapi/utils.go | 19 ++-- .../rpcapi/wallet_extension.go | 2 +- 7 files changed, 97 insertions(+), 87 deletions(-) diff --git a/integration/obscurogateway/tengateway_test.go b/integration/obscurogateway/tengateway_test.go index e06acf6e3b..127bcdbda6 100644 --- a/integration/obscurogateway/tengateway_test.go +++ b/integration/obscurogateway/tengateway_test.go @@ -74,6 +74,8 @@ func TestTenGateway(t *testing.T) { DBType: "sqlite", TenChainID: 443, StoreIncomingTxs: true, + RateLimitThreshold: 100, + RateLimitDecay: 100, } tenGwContainer := walletextension.NewContainerFromConfig(tenGatewayConf, testlog.Logger()) diff --git a/tools/walletextension/common/config.go b/tools/walletextension/common/config.go index dc50007cfe..e0ba3dc2f4 100644 --- a/tools/walletextension/common/config.go +++ b/tools/walletextension/common/config.go @@ -14,5 +14,6 @@ type Config struct { DBConnectionURL string TenChainID int StoreIncomingTxs bool - RateLimit int + RateLimitThreshold int + RateLimitDecay int } diff --git a/tools/walletextension/main/cli.go b/tools/walletextension/main/cli.go index 5e9901e7fe..6544235f43 100644 --- a/tools/walletextension/main/cli.go +++ b/tools/walletextension/main/cli.go @@ -60,9 +60,13 @@ const ( storeIncomingTxsDefault = true storeIncomingTxsUsage = "Flag to enable storing incoming transactions in the database for debugging purposes. Default: true" - rateLimiter = "rateLimiter" - rateLimiterDefault = 100 - rateLimiterUsage = "The rate limit for the gateway in time needed between calls to computationally expensive endpoints per user with unique encryptionToken. If set to 0 rate limiting is turned off. Default: 100." + rateLimitThresholdName = "rateLimitThreshold" + rateLimitThresholdDefault = 1000000 + rateLimitThresholdUsage = "Rate limit threshold per user. Default: 1000000." + + rateLimitDecayName = "rateLimitDecay" + rateLimitDecayDefault = 100 + rateLimitDecayUsage = "Rate limit decay per user. Default: 100." ) func parseCLIArgs() wecommon.Config { @@ -79,7 +83,8 @@ func parseCLIArgs() wecommon.Config { dbConnectionURL := flag.String(dbConnectionURLFlagName, dbConnectionURLFlagDefault, dbConnectionURLFlagUsage) tenChainID := flag.Int(tenChainIDName, tenChainIDDefault, tenChainIDFlagUsage) storeIncomingTransactions := flag.Bool(storeIncomingTxs, storeIncomingTxsDefault, storeIncomingTxsUsage) - rateLimit := flag.Int(rateLimiter, rateLimiterDefault, rateLimiterUsage) + rateLimitThreshold := flag.Int(rateLimitThresholdName, rateLimitThresholdDefault, rateLimitThresholdUsage) + rateLimitDecay := flag.Int(rateLimitDecayName, rateLimitDecayDefault, rateLimitDecayUsage) flag.Parse() return wecommon.Config{ @@ -95,6 +100,7 @@ func parseCLIArgs() wecommon.Config { DBConnectionURL: *dbConnectionURL, TenChainID: *tenChainID, StoreIncomingTxs: *storeIncomingTransactions, - RateLimit: *rateLimit, + RateLimitThreshold: *rateLimitThreshold, + RateLimitDecay: *rateLimitDecay, } } diff --git a/tools/walletextension/ratelimiter/rate_limiter.go b/tools/walletextension/ratelimiter/rate_limiter.go index ed6d850e6b..11d5292dc8 100644 --- a/tools/walletextension/ratelimiter/rate_limiter.go +++ b/tools/walletextension/ratelimiter/rate_limiter.go @@ -5,33 +5,54 @@ import ( "time" ) +type Score struct { + lastRequest time.Time + score uint32 +} + type RateLimiter struct { mu sync.Mutex - users map[string]time.Time - threshold time.Duration + users map[string]Score + threshold uint32 + decay uint32 } -func NewRateLimiter(threshold time.Duration) *RateLimiter { +func NewRateLimiter(threshold uint32, decay uint32) *RateLimiter { return &RateLimiter{ - users: make(map[string]time.Time), + users: make(map[string]Score), threshold: threshold, + decay: decay, } } -func (rl *RateLimiter) Allow(userID string) bool { +func (rl *RateLimiter) Allow(userID string, weightOfTheCall uint32) bool { rl.mu.Lock() defer rl.mu.Unlock() + // Allow all requests if the threshold is 0 if rl.threshold == 0 { return true } now := time.Now() - if lastRequest, exists := rl.users[userID]; exists { - if now.Sub(lastRequest) < rl.threshold { - return false + userScore, exists := rl.users[userID] + if !exists { + // Create a new entry for the user if not exists + rl.users[userID] = Score{lastRequest: now, score: weightOfTheCall} + } else { + // Calculate the decay based on the time passed + decayTime := uint32(now.Sub(userScore.lastRequest).Seconds()) + decayedScore := int64(userScore.score) - int64(decayTime)*int64(rl.decay) + newScore := decayedScore + int64(weightOfTheCall) + + // Ensure score does not become negative + if newScore < 0 { + newScore = 0 } + + // Update user's score and last request time + rl.users[userID] = Score{lastRequest: now, score: uint32(newScore)} } - rl.users[userID] = now - return true + + return rl.users[userID].score <= rl.threshold } diff --git a/tools/walletextension/rpcapi/blockchain_api.go b/tools/walletextension/rpcapi/blockchain_api.go index c537e421f7..6e787824c0 100644 --- a/tools/walletextension/rpcapi/blockchain_api.go +++ b/tools/walletextension/rpcapi/blockchain_api.go @@ -3,8 +3,6 @@ package rpcapi import ( "context" "encoding/json" - "fmt" - "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/common" @@ -18,18 +16,6 @@ type BlockChainAPI struct { we *Services } -func withRateLimit(ctx context.Context, we *Services, fn func() (any, error)) (any, error) { - userIDHex, err := extractUserIDHex(ctx, we) - if err != nil { - return nil, err - } - - if !we.RateLimiter.Allow(userIDHex) { - return nil, fmt.Errorf("rate limit exceeded") - } - return fn() -} - func NewBlockChainAPI(we *Services) *BlockChainAPI { return &BlockChainAPI{we} } @@ -48,28 +34,25 @@ func (api *BlockChainAPI) BlockNumber() hexutil.Uint64 { } func (api *BlockChainAPI) GetBalance(ctx context.Context, address common.Address, blockNrOrHash rpc.BlockNumberOrHash) (*hexutil.Big, error) { - result, err := withRateLimit(ctx, api.we, func() (any, error) { - return ExecAuthRPC[hexutil.Big]( - ctx, - api.we, - &ExecCfg{ - cacheCfg: &CacheCfg{ - CacheTypeDynamic: func() CacheStrategy { - return cacheBlockNumberOrHash(blockNrOrHash) - }, + return ExecAuthRPC[hexutil.Big]( + ctx, + api.we, + &ExecCfg{ + cacheCfg: &CacheCfg{ + CacheTypeDynamic: func() CacheStrategy { + return cacheBlockNumberOrHash(blockNrOrHash) }, - account: &address, - tryUntilAuthorised: true, // the user can request the balance of a contract account }, - "eth_getBalance", - address, - blockNrOrHash, - ) - }) - if err != nil { - return nil, err - } - return result.(*hexutil.Big), nil + account: &address, + tryUntilAuthorised: true, // the user can request the balance of a contract account + calculateRateLimitScore: func() uint32 { + return 100 + }, + }, + "eth_getBalance", + address, + blockNrOrHash, + ) } // Result structs for GetProof @@ -223,39 +206,29 @@ func (api *BlockChainAPI) Call(ctx context.Context, args gethapi.TransactionArgs } func (api *BlockChainAPI) EstimateGas(ctx context.Context, args gethapi.TransactionArgs, blockNrOrHash *rpc.BlockNumberOrHash, overrides *StateOverride) (hexutil.Uint64, error) { - result, err := withRateLimit(ctx, api.we, func() (any, error) { - return ExecAuthRPC[hexutil.Uint64]( - ctx, - api.we, - &ExecCfg{ - cacheCfg: &CacheCfg{ - CacheTypeDynamic: func() CacheStrategy { - if blockNrOrHash != nil { - return cacheBlockNumberOrHash(*blockNrOrHash) - } - return LatestBatch - }, - }, - computeFromCallback: func(user *GWUser) *common.Address { - return searchFromAndData(user.GetAllAddresses(), args) - }, - adjustArgs: func(acct *GWAccount) []any { - argsClone := populateFrom(acct, args) - return []any{argsClone, blockNrOrHash, overrides} - }, - tryAll: true, + resp, err := ExecAuthRPC[hexutil.Uint64](ctx, api.we, &ExecCfg{ + cacheCfg: &CacheCfg{ + CacheTypeDynamic: func() CacheStrategy { + if blockNrOrHash != nil { + return cacheBlockNumberOrHash(*blockNrOrHash) + } + return LatestBatch }, - "eth_estimateGas", - args, - blockNrOrHash, - overrides, - ) - }) - - if err != nil { + }, + computeFromCallback: func(user *GWUser) *common.Address { + return searchFromAndData(user.GetAllAddresses(), args) + }, + adjustArgs: func(acct *GWAccount) []any { + argsClone := populateFrom(acct, args) + return []any{argsClone, blockNrOrHash, overrides} + }, + // is this a security risk? + tryAll: true, + }, "eth_estimateGas", args, blockNrOrHash, overrides) + if resp == nil { return 0, err } - return result.(hexutil.Uint64), nil + return *resp, err } func populateFrom(acct *GWAccount, args gethapi.TransactionArgs) gethapi.TransactionArgs { diff --git a/tools/walletextension/rpcapi/utils.go b/tools/walletextension/rpcapi/utils.go index 5ee2885100..43757b186c 100644 --- a/tools/walletextension/rpcapi/utils.go +++ b/tools/walletextension/rpcapi/utils.go @@ -46,12 +46,13 @@ const ( var rpcNotImplemented = fmt.Errorf("rpc endpoint not implemented") type ExecCfg struct { - account *gethcommon.Address - computeFromCallback func(user *GWUser) *gethcommon.Address - tryAll bool - tryUntilAuthorised bool - adjustArgs func(acct *GWAccount) []any - cacheCfg *CacheCfg + account *gethcommon.Address + computeFromCallback func(user *GWUser) *gethcommon.Address + tryAll bool + tryUntilAuthorised bool + adjustArgs func(acct *GWAccount) []any + cacheCfg *CacheCfg + calculateRateLimitScore func() uint32 } type CacheStrategy uint8 @@ -101,6 +102,12 @@ func ExecAuthRPC[R any](ctx context.Context, w *Services, cfg *ExecCfg, method s return nil, err } + if cfg.calculateRateLimitScore != nil { + if !w.RateLimiter.Allow(hexutils.BytesToHex(userID), cfg.calculateRateLimitScore()) { + return nil, fmt.Errorf("rate limit exceeded") + } + } + user, err := getUser(userID, w) if err != nil { return nil, err diff --git a/tools/walletextension/rpcapi/wallet_extension.go b/tools/walletextension/rpcapi/wallet_extension.go index 9ed08b8501..bba05047ac 100644 --- a/tools/walletextension/rpcapi/wallet_extension.go +++ b/tools/walletextension/rpcapi/wallet_extension.go @@ -94,7 +94,7 @@ func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.Storage cfg := pool.NewDefaultPoolConfig() cfg.MaxTotal = 200 // todo - what is the right number - rateLimiter := ratelimiter.NewRateLimiter(time.Duration(config.RateLimit) * time.Millisecond) + rateLimiter := ratelimiter.NewRateLimiter(uint32(config.RateLimitThreshold), uint32(config.RateLimitDecay)) services := Services{ HostAddrHTTP: hostAddrHTTP,