Skip to content

Commit

Permalink
new version of rate limiter
Browse files Browse the repository at this point in the history
  • Loading branch information
zkokelj committed Jun 3, 2024
1 parent 75c1f19 commit a9e1af2
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 87 deletions.
2 changes: 2 additions & 0 deletions integration/obscurogateway/tengateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ func TestTenGateway(t *testing.T) {
DBType: "sqlite",
TenChainID: 443,
StoreIncomingTxs: true,
RateLimitThreshold: 100,
RateLimitDecay: 100,
}

tenGwContainer := walletextension.NewContainerFromConfig(tenGatewayConf, testlog.Logger())
Expand Down
3 changes: 2 additions & 1 deletion tools/walletextension/common/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ type Config struct {
DBConnectionURL string
TenChainID int
StoreIncomingTxs bool
RateLimit int
RateLimitThreshold int
RateLimitDecay int
}
16 changes: 11 additions & 5 deletions tools/walletextension/main/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,13 @@ const (
storeIncomingTxsDefault = true
storeIncomingTxsUsage = "Flag to enable storing incoming transactions in the database for debugging purposes. Default: true"

rateLimiter = "rateLimiter"
rateLimiterDefault = 100
rateLimiterUsage = "The rate limit for the gateway in time needed between calls to computationally expensive endpoints per user with unique encryptionToken. If set to 0 rate limiting is turned off. Default: 100."
rateLimitThresholdName = "rateLimitThreshold"
rateLimitThresholdDefault = 1000000
rateLimitThresholdUsage = "Rate limit threshold per user. Default: 1000000."

rateLimitDecayName = "rateLimitDecay"
rateLimitDecayDefault = 100
rateLimitDecayUsage = "Rate limit decay per user. Default: 100."
)

func parseCLIArgs() wecommon.Config {
Expand All @@ -79,7 +83,8 @@ func parseCLIArgs() wecommon.Config {
dbConnectionURL := flag.String(dbConnectionURLFlagName, dbConnectionURLFlagDefault, dbConnectionURLFlagUsage)
tenChainID := flag.Int(tenChainIDName, tenChainIDDefault, tenChainIDFlagUsage)
storeIncomingTransactions := flag.Bool(storeIncomingTxs, storeIncomingTxsDefault, storeIncomingTxsUsage)
rateLimit := flag.Int(rateLimiter, rateLimiterDefault, rateLimiterUsage)
rateLimitThreshold := flag.Int(rateLimitThresholdName, rateLimitThresholdDefault, rateLimitThresholdUsage)
rateLimitDecay := flag.Int(rateLimitDecayName, rateLimitDecayDefault, rateLimitDecayUsage)
flag.Parse()

return wecommon.Config{
Expand All @@ -95,6 +100,7 @@ func parseCLIArgs() wecommon.Config {
DBConnectionURL: *dbConnectionURL,
TenChainID: *tenChainID,
StoreIncomingTxs: *storeIncomingTransactions,
RateLimit: *rateLimit,
RateLimitThreshold: *rateLimitThreshold,
RateLimitDecay: *rateLimitDecay,
}
}
41 changes: 31 additions & 10 deletions tools/walletextension/ratelimiter/rate_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,54 @@ import (
"time"
)

type Score struct {
lastRequest time.Time
score uint32
}

type RateLimiter struct {
mu sync.Mutex
users map[string]time.Time
threshold time.Duration
users map[string]Score
threshold uint32
decay uint32
}

func NewRateLimiter(threshold time.Duration) *RateLimiter {
func NewRateLimiter(threshold uint32, decay uint32) *RateLimiter {
return &RateLimiter{
users: make(map[string]time.Time),
users: make(map[string]Score),
threshold: threshold,
decay: decay,
}
}

func (rl *RateLimiter) Allow(userID string) bool {
func (rl *RateLimiter) Allow(userID string, weightOfTheCall uint32) bool {
rl.mu.Lock()
defer rl.mu.Unlock()

// Allow all requests if the threshold is 0
if rl.threshold == 0 {
return true
}

now := time.Now()
if lastRequest, exists := rl.users[userID]; exists {
if now.Sub(lastRequest) < rl.threshold {
return false
userScore, exists := rl.users[userID]
if !exists {
// Create a new entry for the user if not exists
rl.users[userID] = Score{lastRequest: now, score: weightOfTheCall}
} else {
// Calculate the decay based on the time passed
decayTime := uint32(now.Sub(userScore.lastRequest).Seconds())
decayedScore := int64(userScore.score) - int64(decayTime)*int64(rl.decay)
newScore := decayedScore + int64(weightOfTheCall)

// Ensure score does not become negative
if newScore < 0 {
newScore = 0
}

// Update user's score and last request time
rl.users[userID] = Score{lastRequest: now, score: uint32(newScore)}
}
rl.users[userID] = now
return true

return rl.users[userID].score <= rl.threshold
}
101 changes: 37 additions & 64 deletions tools/walletextension/rpcapi/blockchain_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package rpcapi
import (
"context"
"encoding/json"

Check failure on line 5 in tools/walletextension/rpcapi/blockchain_api.go

View workflow job for this annotation

GitHub Actions / lint

File is not `goimports`-ed (goimports)
"fmt"

"github.com/ethereum/go-ethereum/core/types"

"github.com/ethereum/go-ethereum/common"
Expand All @@ -18,18 +16,6 @@ type BlockChainAPI struct {
we *Services
}

func withRateLimit(ctx context.Context, we *Services, fn func() (any, error)) (any, error) {
userIDHex, err := extractUserIDHex(ctx, we)
if err != nil {
return nil, err
}

if !we.RateLimiter.Allow(userIDHex) {
return nil, fmt.Errorf("rate limit exceeded")
}
return fn()
}

func NewBlockChainAPI(we *Services) *BlockChainAPI {
return &BlockChainAPI{we}
}
Expand All @@ -48,28 +34,25 @@ func (api *BlockChainAPI) BlockNumber() hexutil.Uint64 {
}

func (api *BlockChainAPI) GetBalance(ctx context.Context, address common.Address, blockNrOrHash rpc.BlockNumberOrHash) (*hexutil.Big, error) {
result, err := withRateLimit(ctx, api.we, func() (any, error) {
return ExecAuthRPC[hexutil.Big](
ctx,
api.we,
&ExecCfg{
cacheCfg: &CacheCfg{
CacheTypeDynamic: func() CacheStrategy {
return cacheBlockNumberOrHash(blockNrOrHash)
},
return ExecAuthRPC[hexutil.Big](
ctx,
api.we,
&ExecCfg{
cacheCfg: &CacheCfg{
CacheTypeDynamic: func() CacheStrategy {
return cacheBlockNumberOrHash(blockNrOrHash)
},
account: &address,
tryUntilAuthorised: true, // the user can request the balance of a contract account
},
"eth_getBalance",
address,
blockNrOrHash,
)
})
if err != nil {
return nil, err
}
return result.(*hexutil.Big), nil
account: &address,
tryUntilAuthorised: true, // the user can request the balance of a contract account
calculateRateLimitScore: func() uint32 {
return 100
},
},
"eth_getBalance",
address,
blockNrOrHash,
)
}

// Result structs for GetProof
Expand Down Expand Up @@ -223,39 +206,29 @@ func (api *BlockChainAPI) Call(ctx context.Context, args gethapi.TransactionArgs
}

func (api *BlockChainAPI) EstimateGas(ctx context.Context, args gethapi.TransactionArgs, blockNrOrHash *rpc.BlockNumberOrHash, overrides *StateOverride) (hexutil.Uint64, error) {
result, err := withRateLimit(ctx, api.we, func() (any, error) {
return ExecAuthRPC[hexutil.Uint64](
ctx,
api.we,
&ExecCfg{
cacheCfg: &CacheCfg{
CacheTypeDynamic: func() CacheStrategy {
if blockNrOrHash != nil {
return cacheBlockNumberOrHash(*blockNrOrHash)
}
return LatestBatch
},
},
computeFromCallback: func(user *GWUser) *common.Address {
return searchFromAndData(user.GetAllAddresses(), args)
},
adjustArgs: func(acct *GWAccount) []any {
argsClone := populateFrom(acct, args)
return []any{argsClone, blockNrOrHash, overrides}
},
tryAll: true,
resp, err := ExecAuthRPC[hexutil.Uint64](ctx, api.we, &ExecCfg{
cacheCfg: &CacheCfg{
CacheTypeDynamic: func() CacheStrategy {
if blockNrOrHash != nil {
return cacheBlockNumberOrHash(*blockNrOrHash)
}
return LatestBatch
},
"eth_estimateGas",
args,
blockNrOrHash,
overrides,
)
})

if err != nil {
},
computeFromCallback: func(user *GWUser) *common.Address {
return searchFromAndData(user.GetAllAddresses(), args)
},
adjustArgs: func(acct *GWAccount) []any {
argsClone := populateFrom(acct, args)
return []any{argsClone, blockNrOrHash, overrides}
},
// is this a security risk?
tryAll: true,
}, "eth_estimateGas", args, blockNrOrHash, overrides)
if resp == nil {
return 0, err
}
return result.(hexutil.Uint64), nil
return *resp, err
}

func populateFrom(acct *GWAccount, args gethapi.TransactionArgs) gethapi.TransactionArgs {
Expand Down
19 changes: 13 additions & 6 deletions tools/walletextension/rpcapi/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ const (
var rpcNotImplemented = fmt.Errorf("rpc endpoint not implemented")

type ExecCfg struct {
account *gethcommon.Address
computeFromCallback func(user *GWUser) *gethcommon.Address
tryAll bool
tryUntilAuthorised bool
adjustArgs func(acct *GWAccount) []any
cacheCfg *CacheCfg
account *gethcommon.Address
computeFromCallback func(user *GWUser) *gethcommon.Address
tryAll bool
tryUntilAuthorised bool
adjustArgs func(acct *GWAccount) []any
cacheCfg *CacheCfg
calculateRateLimitScore func() uint32
}

type CacheStrategy uint8
Expand Down Expand Up @@ -101,6 +102,12 @@ func ExecAuthRPC[R any](ctx context.Context, w *Services, cfg *ExecCfg, method s
return nil, err
}

if cfg.calculateRateLimitScore != nil {
if !w.RateLimiter.Allow(hexutils.BytesToHex(userID), cfg.calculateRateLimitScore()) {
return nil, fmt.Errorf("rate limit exceeded")
}
}

user, err := getUser(userID, w)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion tools/walletextension/rpcapi/wallet_extension.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.Storage
cfg := pool.NewDefaultPoolConfig()
cfg.MaxTotal = 200 // todo - what is the right number

rateLimiter := ratelimiter.NewRateLimiter(time.Duration(config.RateLimit) * time.Millisecond)
rateLimiter := ratelimiter.NewRateLimiter(uint32(config.RateLimitThreshold), uint32(config.RateLimitDecay))

services := Services{
HostAddrHTTP: hostAddrHTTP,
Expand Down

0 comments on commit a9e1af2

Please sign in to comment.