diff --git a/.github/workflows/manual-deploy-dexynth-gateway.yml b/.github/workflows/manual-deploy-dexynth-gateway.yml index 5271686154..cd86515f4a 100644 --- a/.github/workflows/manual-deploy-dexynth-gateway.yml +++ b/.github/workflows/manual-deploy-dexynth-gateway.yml @@ -138,4 +138,5 @@ jobs: --log-opt max-file=3 --log-opt max-size=10m \ ${{ vars.DOCKER_BUILD_TAG_GATEWAY_DEXYNTH }} \ -host=0.0.0.0 -port=8080 -portWS=81 -nodeHost=${{ vars.L2_RPC_URL_VALIDATOR_DEXYNTH }} -verbose=true \ - -logPath=sys_out -dbType=mariaDB -dbConnectionURL="obscurouser:${{ secrets.OBSCURO_GATEWAY_MARIADB_USER_PWD }}@tcp(obscurogateway-mariadb-${{ github.event.inputs.testnet_type }}.uksouth.cloudapp.azure.com:3306)/ogdb"' + -logPath=sys_out -dbType=mariaDB -dbConnectionURL="obscurouser:${{ secrets.OBSCURO_GATEWAY_MARIADB_USER_PWD }}@tcp(obscurogateway-mariadb-${{ github.event.inputs.testnet_type }}.uksouth.cloudapp.azure.com:3306)/ogdb" \ + -rateLimitUserComputeTime=${{ vars.GATEWAY_RATE_LIMIT_USER_COMPUTE_TIME }} -rateLimitWindow=${{ vars.GATEWAY_RATE_LIMIT_WINDOW }} -maxConcurrentRequestsPerUser=${{ vars.GATEWAY_MAX_CONCURRENT_REQUESTS_PER_USER }} ' diff --git a/.github/workflows/manual-deploy-obscuro-gateway.yml b/.github/workflows/manual-deploy-obscuro-gateway.yml index 4d063d70a7..ea85edf663 100644 --- a/.github/workflows/manual-deploy-obscuro-gateway.yml +++ b/.github/workflows/manual-deploy-obscuro-gateway.yml @@ -137,7 +137,8 @@ jobs: && cd /home/obscuro/go-obscuro/ \ && docker run -d -p 80:80 -p 81:81 --name ${{ github.event.inputs.testnet_type }}-OG-${{ GITHUB.RUN_NUMBER }} \ -e OBSCURO_GATEWAY_VERSION="${{ GITHUB.RUN_NUMBER }}-${{ GITHUB.SHA }}" \ - --log-opt max-file=3 --log-opt max-size=10m \ - ${{ vars.DOCKER_BUILD_TAG_GATEWAY }} \ - -host=0.0.0.0 -port=8080 -portWS=81 -nodeHost=${{ vars.L2_RPC_URL_VALIDATOR }} -verbose=true \ - -logPath=sys_out -dbType=mariaDB -dbConnectionURL="obscurouser:${{ secrets.OBSCURO_GATEWAY_MARIADB_USER_PWD }}@tcp(obscurogateway-mariadb-${{ github.event.inputs.testnet_type }}.uksouth.cloudapp.azure.com:3306)/ogdb"' + --log-opt max-file=3 --log-opt max-size=10m \ + ${{ vars.DOCKER_BUILD_TAG_GATEWAY }} \ + -host=0.0.0.0 -port=8080 -portWS=81 -nodeHost=${{ vars.L2_RPC_URL_VALIDATOR }} -verbose=true \ + -logPath=sys_out -dbType=mariaDB -dbConnectionURL="obscurouser:${{ secrets.OBSCURO_GATEWAY_MARIADB_USER_PWD }}@tcp(obscurogateway-mariadb-${{ github.event.inputs.testnet_type }}.uksouth.cloudapp.azure.com:3306)/ogdb" \ + -rateLimitUserComputeTime=${{ vars.GATEWAY_RATE_LIMIT_USER_COMPUTE_TIME }} -rateLimitWindow=${{ vars.GATEWAY_RATE_LIMIT_WINDOW }} -maxConcurrentRequestsPerUser=${{ vars.GATEWAY_MAX_CONCURRENT_REQUESTS_PER_USER }} ' diff --git a/contracts/package.json b/contracts/package.json index 9c8c4ea73b..5ba922d48b 100644 --- a/contracts/package.json +++ b/contracts/package.json @@ -26,5 +26,8 @@ "ethers": "^6.6.0", "hardhat-ignore-warnings": "^0.2.6", "ten-hardhat-plugin": "^0.0.9" + }, + "peerDependencies": { + "@nomicfoundation/hardhat-verify" : "2.0.8" } } diff --git a/go/common/subscription/new_heads_manager.go b/go/common/subscription/new_heads_manager.go index b39ebc21ef..4ebbb8ce02 100644 --- a/go/common/subscription/new_heads_manager.go +++ b/go/common/subscription/new_heads_manager.go @@ -88,7 +88,7 @@ func (nhs *NewHeadsService) onNewBatch(head *common.BatchHeader) error { var msg any = head if nhs.convertToEthHeader { - msg = convertBatchHeader(head) + msg = ConvertBatchHeader(head) } nhs.notifiersMutex.Lock() @@ -130,10 +130,11 @@ func (nhs *NewHeadsService) HealthStatus(context.Context) host.HealthStatus { return &host.BasicErrHealthStatus{} } -func convertBatchHeader(head *common.BatchHeader) *types.Header { +func ConvertBatchHeader(head *common.BatchHeader) *types.Header { return &types.Header{ ParentHash: head.ParentHash, UncleHash: gethcommon.Hash{}, + Coinbase: head.Coinbase, Root: head.Root, TxHash: head.TxHash, ReceiptHash: head.ReceiptHash, diff --git a/go/host/rpc/clientapi/client_api_eth.go b/go/host/rpc/clientapi/client_api_eth.go index ded05e75e7..0435302ea5 100644 --- a/go/host/rpc/clientapi/client_api_eth.go +++ b/go/host/rpc/clientapi/client_api_eth.go @@ -187,6 +187,18 @@ func (api *EthereumAPI) GetStorageAt(ctx context.Context, encryptedParams common return *enclaveResponse, nil } +func (api *EthereumAPI) MaxPriorityFeePerGas(_ context.Context) (*hexutil.Big, error) { + // todo - implement with the gas mechanics + header, err := api.host.Storage().FetchHeadBatchHeader() + if err != nil { + api.logger.Error("Unable to retrieve header for fee history.", log.ErrKey, err) + return nil, fmt.Errorf("unable to retrieve MaxPriorityFeePerGas") + } + + // just return the base fee? + return (*hexutil.Big)(header.BaseFee), err +} + // FeeHistory is a placeholder for an RPC method required by MetaMask/Remix. // rpc.DecimalOrHex -> []byte func (api *EthereumAPI) FeeHistory(context.Context, string, rpc.BlockNumber, []float64) (*FeeHistoryResult, error) { diff --git a/integration/obscurogateway/tengateway_test.go b/integration/obscurogateway/tengateway_test.go index f411d35a4b..43778e5d35 100644 --- a/integration/obscurogateway/tengateway_test.go +++ b/integration/obscurogateway/tengateway_test.go @@ -64,16 +64,19 @@ func TestTenGateway(t *testing.T) { createTenNetwork(t, startPort) tenGatewayConf := wecommon.Config{ - WalletExtensionHost: "127.0.0.1", - WalletExtensionPortHTTP: startPort + integration.DefaultTenGatewayHTTPPortOffset, - WalletExtensionPortWS: startPort + integration.DefaultTenGatewayWSPortOffset, - NodeRPCHTTPAddress: fmt.Sprintf("127.0.0.1:%d", startPort+integration.DefaultHostRPCHTTPOffset), - NodeRPCWebsocketAddress: fmt.Sprintf("127.0.0.1:%d", startPort+integration.DefaultHostRPCWSOffset), - LogPath: "sys_out", - VerboseFlag: false, - DBType: "sqlite", - TenChainID: 443, - StoreIncomingTxs: true, + WalletExtensionHost: "127.0.0.1", + WalletExtensionPortHTTP: startPort + integration.DefaultTenGatewayHTTPPortOffset, + WalletExtensionPortWS: startPort + integration.DefaultTenGatewayWSPortOffset, + NodeRPCHTTPAddress: fmt.Sprintf("127.0.0.1:%d", startPort+integration.DefaultHostRPCHTTPOffset), + NodeRPCWebsocketAddress: fmt.Sprintf("127.0.0.1:%d", startPort+integration.DefaultHostRPCWSOffset), + LogPath: "sys_out", + VerboseFlag: false, + DBType: "sqlite", + TenChainID: 443, + StoreIncomingTxs: true, + RateLimitUserComputeTime: 200 * time.Millisecond, + RateLimitWindow: 1 * time.Second, + RateLimitMaxConcurrentRequests: 3, } tenGwContainer := walletextension.NewContainerFromConfig(tenGatewayConf, testlog.Logger()) @@ -111,6 +114,7 @@ func TestTenGateway(t *testing.T) { "testDifferentMessagesOnRegister": testDifferentMessagesOnRegister, "testInvokeNonSensitiveMethod": testInvokeNonSensitiveMethod, "testGetStorageAtForReturningUserID": testGetStorageAtForReturningUserID, + "testRateLimiter": testRateLimiter, } { t.Run(name, func(t *testing.T) { test(t, httpURL, wsURL, w) @@ -124,6 +128,45 @@ func TestTenGateway(t *testing.T) { assert.NoError(t, err) } +func testRateLimiter(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { + user0, err := NewGatewayUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) + require.NoError(t, err) + testlog.Logger().Info("Created user with encryption token", "t", user0.tgClient.UserID()) + // register the user so we can call the endpoints that require authentication + err = user0.RegisterAccounts() + require.NoError(t, err) + + // call BalanceAt - fist call should be successful + _, err = user0.HTTPClient.BalanceAt(context.Background(), user0.Wallets[0].Address(), nil) + require.NoError(t, err) + + // sleep for a period of time to allow the rate limiter to reset + time.Sleep(1 * time.Second) + + // first call after the rate limiter reset should be successful + _, err = user0.HTTPClient.BalanceAt(context.Background(), user0.Wallets[0].Address(), nil) + require.NoError(t, err) + + address := user0.Wallets[0].Address() + + // make 1000 requests with the same user to "spam" the gateway + for i := 0; i < 1000; i++ { + msg := ethereum.CallMsg{ + From: address, + To: &address, // Example: self-call to the user's address + Gas: uint64(i), + Data: nil, + } + + user0.HTTPClient.EstimateGas(context.Background(), msg) + } + + // after 1000 requests, the rate limiter should block the user + _, err = user0.HTTPClient.BalanceAt(context.Background(), user0.Wallets[0].Address(), nil) + require.Error(t, err) + require.Equal(t, "rate limit exceeded", err.Error()) +} + func testNewHeadsSubscription(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { user0, err := NewGatewayUser([]wallet.Wallet{w, datagenerator.RandomWallet(integration.TenChainID)}, httpURL, wsURL) require.NoError(t, err) @@ -466,6 +509,7 @@ func testErrorHandling(t *testing.T, httpURL, wsURL string, w wallet.Wallet) { `{"jsonrpc":"2.0","method":"eth_getBalance","params":["0xA58C60cc047592DE97BF1E8d2f225Fc5D959De77", "latest"],"id":1,"extra":"extra_field"}`, `{"jsonrpc":"2.0","method":"eth_sendTransaction","params":[["0xA58C60cc047592DE97BF1E8d2f225Fc5D959De77", "0x1234"]],"id":1}`, `{"jsonrpc":"2.0","method":"eth_getTransactionByHash","params":["0x0000000000000000000000000000000000000000000000000000000000000000"],"id":1}`, + `{"jsonrpc":"2.0","method":"eth_maxPriorityFeePerGas","params":[],"id":1}`, } { // ensure the geth request is issued correctly (should return 200 ok with jsonRPCError) _, response, err := httputil.PostDataJSON(ogClient.HTTP(), []byte(req)) diff --git a/tools/walletextension/common/config.go b/tools/walletextension/common/config.go index 98092f8c7b..26f43cf459 100644 --- a/tools/walletextension/common/config.go +++ b/tools/walletextension/common/config.go @@ -1,17 +1,22 @@ package common +import "time" + // Config contains the configuration required by the WalletExtension. type Config struct { - WalletExtensionHost string - WalletExtensionPortHTTP int - WalletExtensionPortWS int - NodeRPCHTTPAddress string - NodeRPCWebsocketAddress string - LogPath string - DBPathOverride string // Overrides the database file location. Used in tests. - VerboseFlag bool - DBType string - DBConnectionURL string - TenChainID int - StoreIncomingTxs bool + WalletExtensionHost string + WalletExtensionPortHTTP int + WalletExtensionPortWS int + NodeRPCHTTPAddress string + NodeRPCWebsocketAddress string + LogPath string + DBPathOverride string // Overrides the database file location. Used in tests. + VerboseFlag bool + DBType string + DBConnectionURL string + TenChainID int + StoreIncomingTxs bool + RateLimitUserComputeTime time.Duration + RateLimitWindow time.Duration + RateLimitMaxConcurrentRequests int } diff --git a/tools/walletextension/main/cli.go b/tools/walletextension/main/cli.go index 0e7f564c4c..16f1a00709 100644 --- a/tools/walletextension/main/cli.go +++ b/tools/walletextension/main/cli.go @@ -3,6 +3,7 @@ package main import ( "flag" "fmt" + "time" wecommon "github.com/ten-protocol/go-ten/tools/walletextension/common" ) @@ -59,6 +60,18 @@ const ( storeIncomingTxs = "storeIncomingTxs" storeIncomingTxsDefault = true storeIncomingTxsUsage = "Flag to enable storing incoming transactions in the database for debugging purposes. Default: true" + + rateLimitUserComputeTimeName = "rateLimitUserComputeTime" + rateLimitUserComputeTimeDefault = 10 * time.Second + rateLimitUserComputeTimeUsage = "rateLimitUserComputeTime represents how much compute time is user allowed to used in rateLimitWindow time. If rateLimitUserComputeTime is set to 0, rate limiting is turned off. Default: 10s." + + rateLimitWindowName = "rateLimitWindow" + rateLimitWindowDefault = 1 * time.Minute + rateLimitWindowUsage = "rateLimitWindow represents time window in which we allow one user to use compute time defined with rateLimitUserComputeTimeMs Default: 1m" + + rateLimitMaxConcurrentRequestsName = "maxConcurrentRequestsPerUser" + rateLimitMaxConcurrentRequestsDefault = 3 + rateLimitMaxConcurrentRequestsUsage = "Number of concurrent requests allowed per user. Default: 3" ) func parseCLIArgs() wecommon.Config { @@ -75,20 +88,26 @@ func parseCLIArgs() wecommon.Config { dbConnectionURL := flag.String(dbConnectionURLFlagName, dbConnectionURLFlagDefault, dbConnectionURLFlagUsage) tenChainID := flag.Int(tenChainIDName, tenChainIDDefault, tenChainIDFlagUsage) storeIncomingTransactions := flag.Bool(storeIncomingTxs, storeIncomingTxsDefault, storeIncomingTxsUsage) + rateLimitUserComputeTime := flag.Duration(rateLimitUserComputeTimeName, rateLimitUserComputeTimeDefault, rateLimitUserComputeTimeUsage) + rateLimitWindow := flag.Duration(rateLimitWindowName, rateLimitWindowDefault, rateLimitWindowUsage) + rateLimitMaxConcurrentRequests := flag.Int(rateLimitMaxConcurrentRequestsName, rateLimitMaxConcurrentRequestsDefault, rateLimitMaxConcurrentRequestsUsage) flag.Parse() return wecommon.Config{ - WalletExtensionHost: *walletExtensionHost, - WalletExtensionPortHTTP: *walletExtensionPort, - WalletExtensionPortWS: *walletExtensionPortWS, - NodeRPCHTTPAddress: fmt.Sprintf("%s:%d", *nodeHost, *nodeHTTPPort), - NodeRPCWebsocketAddress: fmt.Sprintf("%s:%d", *nodeHost, *nodeWebsocketPort), - LogPath: *logPath, - DBPathOverride: *databasePath, - VerboseFlag: *verboseFlag, - DBType: *dbType, - DBConnectionURL: *dbConnectionURL, - TenChainID: *tenChainID, - StoreIncomingTxs: *storeIncomingTransactions, + WalletExtensionHost: *walletExtensionHost, + WalletExtensionPortHTTP: *walletExtensionPort, + WalletExtensionPortWS: *walletExtensionPortWS, + NodeRPCHTTPAddress: fmt.Sprintf("%s:%d", *nodeHost, *nodeHTTPPort), + NodeRPCWebsocketAddress: fmt.Sprintf("%s:%d", *nodeHost, *nodeWebsocketPort), + LogPath: *logPath, + DBPathOverride: *databasePath, + VerboseFlag: *verboseFlag, + DBType: *dbType, + DBConnectionURL: *dbConnectionURL, + TenChainID: *tenChainID, + StoreIncomingTxs: *storeIncomingTransactions, + RateLimitUserComputeTime: *rateLimitUserComputeTime, + RateLimitWindow: *rateLimitWindow, + RateLimitMaxConcurrentRequests: *rateLimitMaxConcurrentRequests, } } diff --git a/tools/walletextension/ratelimiter/rate_limiter.go b/tools/walletextension/ratelimiter/rate_limiter.go new file mode 100644 index 0000000000..f2be19ac41 --- /dev/null +++ b/tools/walletextension/ratelimiter/rate_limiter.go @@ -0,0 +1,230 @@ +package ratelimiter + +import ( + "math" + "sync" + "time" + + gethlog "github.com/ethereum/go-ethereum/log" + + "github.com/google/uuid" + + "github.com/ethereum/go-ethereum/common" +) + +// RequestInterval represents an interval for a request with a start and optional end timestamp. +type RequestInterval struct { + Start time.Time + End *time.Time // can be nil if the request is not over yet +} + +// RateLimitUser represents a user with a map of current requests. +type RateLimitUser struct { + CurrentRequests map[uuid.UUID]RequestInterval +} + +// zeroUUID is a zero UUID returned when no new request is added. +var zeroUUID uuid.UUID + +// AddRequest adds a new request interval to a user's current requests and returns the UUID. +func (rl *RateLimiter) AddRequest(userID common.Address, interval RequestInterval) uuid.UUID { + rl.mu.Lock() + defer rl.mu.Unlock() + + user, exists := rl.users[userID] + if !exists { + user = &RateLimitUser{ + CurrentRequests: make(map[uuid.UUID]RequestInterval), + } + rl.users[userID] = user + } + id := uuid.New() + user.CurrentRequests[id] = interval + return id +} + +// SetRequestEnd updates the end time of a request interval given its UUID. +func (rl *RateLimiter) SetRequestEnd(userID common.Address, id uuid.UUID) { + if user, userExists := rl.users[userID]; userExists { + if request, requestExists := user.CurrentRequests[id]; requestExists { + rl.mu.Lock() + defer rl.mu.Unlock() + now := time.Now() + request.End = &now + user.CurrentRequests[id] = request + } else { + rl.logger.Info("Request with ID %s not found for user %s.", id, userID.Hex()) + } + } else { + rl.logger.Info("User %s not found while trying to update the request.", userID.Hex()) + } +} + +// CountOpenRequests counts the number of requests without an End time set. +func (rl *RateLimiter) CountOpenRequests(userID common.Address) int { + rl.mu.Lock() + defer rl.mu.Unlock() + + var count int + if user, exists := rl.users[userID]; exists { + for _, interval := range user.CurrentRequests { + if interval.End == nil { + count++ + } + } + } + return count +} + +// SumComputeTime sums the compute time for requests within the rate limiter's window +// and returns it as uint32 milliseconds. +func (rl *RateLimiter) SumComputeTime(userID common.Address) time.Duration { + rl.mu.Lock() + defer rl.mu.Unlock() + + var totalComputeTime time.Duration + if user, exists := rl.users[userID]; exists { + cutoff := time.Now().Add(-rl.window) + for _, interval := range user.CurrentRequests { + // if the request has ended and it's within the window, add the compute time + if interval.End != nil && interval.End.After(cutoff) { + totalComputeTime += interval.End.Sub(interval.Start) + } + // if the request hasn't ended yet, add the compute time until now + if interval.End == nil { + totalComputeTime += time.Since(interval.Start) + } + } + } + return totalComputeTime +} + +type RateLimiter struct { + mu sync.Mutex + users map[common.Address]*RateLimitUser + userComputeTime time.Duration + window time.Duration + maxConcurrentRequests uint32 + totalRequests uint64 + rateLimitedRequests uint64 + logger gethlog.Logger +} + +// IncrementTotalRequests increments the total requests counter by 1 with thread safety. +func (rl *RateLimiter) IncrementTotalRequests() { + rl.mu.Lock() + defer rl.mu.Unlock() + rl.totalRequests++ +} + +// IncrementRateLimitedRequests increments the total requests counter by 1 with thread safety. +func (rl *RateLimiter) IncrementRateLimitedRequests() { + rl.mu.Lock() + defer rl.mu.Unlock() + rl.rateLimitedRequests++ +} + +// GetMaxConcurrentRequest returns the maximum number of concurrent requests allowed. +func (rl *RateLimiter) GetMaxConcurrentRequest() uint32 { + rl.mu.Lock() + defer rl.mu.Unlock() + return rl.maxConcurrentRequests +} + +// GetUserComputeTime returns the user compute time +func (rl *RateLimiter) GetUserComputeTime() time.Duration { + rl.mu.Lock() + defer rl.mu.Unlock() + return rl.userComputeTime +} + +func NewRateLimiter(rateLimitUserComputeTime time.Duration, rateLimitWindow time.Duration, concurrentRequestsLimit uint32, logger gethlog.Logger) *RateLimiter { + rl := &RateLimiter{ + users: make(map[common.Address]*RateLimitUser), + userComputeTime: rateLimitUserComputeTime, + window: rateLimitWindow, + maxConcurrentRequests: concurrentRequestsLimit, + logger: logger, + } + go rl.logRateLimitedStats() + go rl.periodicPrune() + return rl +} + +// Allow checks if the user is allowed to make a request based on the rate limit threshold +// before comparing to the threshold also decays the score of the user based on the decay rate +func (rl *RateLimiter) Allow(userID common.Address) (bool, uuid.UUID) { + // If the userComputeTime is 0, allow all requests (rate limiting is disabled) + if rl.GetUserComputeTime() == 0 { + return true, zeroUUID + } + // Increment the total requests counter for statistics + rl.IncrementTotalRequests() + + // Check if the user has reached the maximum number of concurrent requests + if uint32(rl.CountOpenRequests(userID)) >= rl.GetMaxConcurrentRequest() { + rl.IncrementRateLimitedRequests() + rl.logger.Info("User %s has reached the maximum number of concurrent requests.", userID.Hex()) + return false, zeroUUID + } + + // Check if user is in limits of rate limiting + userComputeTimeForUser := rl.SumComputeTime(userID) + if userComputeTimeForUser > rl.userComputeTime { + rl.IncrementRateLimitedRequests() + rl.logger.Info("User %s has reached the rate limit threshold.", userID.Hex()) + return false, zeroUUID + } + + requestUUID := rl.AddRequest(userID, RequestInterval{Start: time.Now()}) + return true, requestUUID +} + +// PruneRequests deletes all requests that have ended before the rate limiter's window. +func (rl *RateLimiter) PruneRequests() { + rl.mu.Lock() + defer rl.mu.Unlock() + startTime := time.Now() + // delete all the requests that have + cutoff := time.Now().Add(-rl.window) + for userID, user := range rl.users { + for id, interval := range user.CurrentRequests { + if interval.End != nil && interval.End.Before(cutoff) { + delete(user.CurrentRequests, id) + } + } + if len(user.CurrentRequests) == 0 { + delete(rl.users, userID) + } + } + timeTaken := time.Since(startTime) + if timeTaken > 1*time.Second { + rl.logger.Warn("PruneRequests completed in %s", timeTaken) + } +} + +// periodically prunes the requests that have ended before the rate limiter's window every 10 * window milliseconds +func (rl *RateLimiter) periodicPrune() { + for { + time.Sleep(rl.window * 10) + rl.PruneRequests() + } +} + +func (rl *RateLimiter) logRateLimitedStats() { + for { + time.Sleep(30 * time.Minute) + rl.mu.Lock() + totalRequests := rl.totalRequests + rateLimitedRequests := rl.rateLimitedRequests + rl.totalRequests = 0 + rl.rateLimitedRequests = 0 + rl.mu.Unlock() + + rateLimitedPercentage := float64(rateLimitedRequests) / float64(totalRequests) * 100 + if math.IsNaN(rateLimitedPercentage) { + rateLimitedPercentage = 0 + } + rl.logger.Info("Total requests: %d, Rate-limited requests: %d (%.4f%%)", totalRequests, rateLimitedRequests, rateLimitedPercentage) + } +} diff --git a/tools/walletextension/rpcapi/blockchain_api.go b/tools/walletextension/rpcapi/blockchain_api.go index aeb6a01ed9..57d8ccfbf0 100644 --- a/tools/walletextension/rpcapi/blockchain_api.go +++ b/tools/walletextension/rpcapi/blockchain_api.go @@ -12,6 +12,7 @@ import ( "github.com/ten-protocol/go-ten/go/common" "github.com/ten-protocol/go-ten/go/common/gethapi" "github.com/ten-protocol/go-ten/go/common/privacy" + "github.com/ten-protocol/go-ten/go/common/subscription" "github.com/ten-protocol/go-ten/lib/gethfork/rpc" ) @@ -100,7 +101,7 @@ func (api *BlockChainAPI) GetHeaderByHash(ctx context.Context, hash gethcommon.H } func (api *BlockChainAPI) GetBlockByNumber(ctx context.Context, number rpc.BlockNumber, fullTx bool) (map[string]interface{}, error) { - resp, err := UnauthenticatedTenRPCCall[map[string]interface{}]( + resp, err := UnauthenticatedTenRPCCall[common.BatchHeader]( ctx, api.we, &CacheCfg{ @@ -111,15 +112,25 @@ func (api *BlockChainAPI) GetBlockByNumber(ctx context.Context, number rpc.Block if resp == nil { return nil, err } - return *resp, err + + // convert to geth header and marshall + header := subscription.ConvertBatchHeader(resp) + fields := RPCMarshalHeader(header) + addExtraTenFields(fields, resp) + return fields, err } func (api *BlockChainAPI) GetBlockByHash(ctx context.Context, hash gethcommon.Hash, fullTx bool) (map[string]interface{}, error) { - resp, err := UnauthenticatedTenRPCCall[map[string]interface{}](ctx, api.we, &CacheCfg{CacheType: LongLiving}, "eth_getBlockByHash", hash, fullTx) + resp, err := UnauthenticatedTenRPCCall[common.BatchHeader](ctx, api.we, &CacheCfg{CacheType: LongLiving}, "eth_getBlockByHash", hash, fullTx) if resp == nil { return nil, err } - return *resp, err + + // convert to geth header and marshall + header := subscription.ConvertBatchHeader(resp) + fields := RPCMarshalHeader(header) + addExtraTenFields(fields, resp) + return fields, err } func (api *BlockChainAPI) GetCode(ctx context.Context, address gethcommon.Address, blockNrOrHash rpc.BlockNumberOrHash) (hexutil.Bytes, error) { @@ -338,3 +349,52 @@ func extractCustomQueryAddress(params any) (*gethcommon.Address, error) { address := gethcommon.HexToAddress(addressStr) return &address, nil } + +// RPCMarshalHeader converts the given header to the RPC output . +// duplicated from go-ethereum +func RPCMarshalHeader(head *types.Header) map[string]interface{} { + result := map[string]interface{}{ + "number": (*hexutil.Big)(head.Number), + "hash": head.Hash(), + "parentHash": head.ParentHash, + "nonce": head.Nonce, + "mixHash": head.MixDigest, + "sha3Uncles": head.UncleHash, + "logsBloom": head.Bloom, + "stateRoot": head.Root, + "miner": head.Coinbase, + "difficulty": (*hexutil.Big)(head.Difficulty), + "extraData": hexutil.Bytes(head.Extra), + "gasLimit": hexutil.Uint64(head.GasLimit), + "gasUsed": hexutil.Uint64(head.GasUsed), + "timestamp": hexutil.Uint64(head.Time), + "transactionsRoot": head.TxHash, + "receiptsRoot": head.ReceiptHash, + } + if head.BaseFee != nil { + result["baseFeePerGas"] = (*hexutil.Big)(head.BaseFee) + } + if head.WithdrawalsHash != nil { + result["withdrawalsRoot"] = head.WithdrawalsHash + } + if head.BlobGasUsed != nil { + result["blobGasUsed"] = hexutil.Uint64(*head.BlobGasUsed) + } + if head.ExcessBlobGas != nil { + result["excessBlobGas"] = hexutil.Uint64(*head.ExcessBlobGas) + } + if head.ParentBeaconRoot != nil { + result["parentBeaconBlockRoot"] = head.ParentBeaconRoot + } + return result +} + +func addExtraTenFields(fields map[string]interface{}, header *common.BatchHeader) { + fields["l1Proof"] = header.L1Proof + fields["signature"] = header.Signature + fields["crossChainMessages"] = header.CrossChainMessages + fields["inboundCrossChainHash"] = header.LatestInboundCrossChainHash + fields["inboundCrossChainHeight"] = header.LatestInboundCrossChainHeight + fields["crossChainTreeHash"] = header.CrossChainRoot + fields["crossChainTree"] = header.CrossChainTree +} diff --git a/tools/walletextension/rpcapi/filter_api.go b/tools/walletextension/rpcapi/filter_api.go index b144160ae8..af154b6793 100644 --- a/tools/walletextension/rpcapi/filter_api.go +++ b/tools/walletextension/rpcapi/filter_api.go @@ -197,6 +197,12 @@ func (api *FilterAPI) GetLogs(ctx context.Context, crit common.FilterCriteria) ( return nil, err } + rateLimitAllowed, requestUUID := api.we.RateLimiter.Allow(gethcommon.Address(userID)) + defer api.we.RateLimiter.SetRequestEnd(gethcommon.Address(userID), requestUUID) + if !rateLimitAllowed { + return nil, fmt.Errorf("rate limit exceeded") + } + res, err := withCache( api.we.Cache, &CacheCfg{ diff --git a/tools/walletextension/rpcapi/utils.go b/tools/walletextension/rpcapi/utils.go index 2bab927659..a90eacb810 100644 --- a/tools/walletextension/rpcapi/utils.go +++ b/tools/walletextension/rpcapi/utils.go @@ -101,6 +101,12 @@ func ExecAuthRPC[R any](ctx context.Context, w *Services, cfg *ExecCfg, method s return nil, err } + rateLimitAllowed, requestUUID := w.RateLimiter.Allow(gethcommon.Address(userID)) + defer w.RateLimiter.SetRequestEnd(gethcommon.Address(userID), requestUUID) + if !rateLimitAllowed { + return nil, fmt.Errorf("rate limit exceeded") + } + cacheArgs := []any{userID, method} cacheArgs = append(cacheArgs, args...) diff --git a/tools/walletextension/rpcapi/wallet_extension.go b/tools/walletextension/rpcapi/wallet_extension.go index 0d2b742791..fae47871f7 100644 --- a/tools/walletextension/rpcapi/wallet_extension.go +++ b/tools/walletextension/rpcapi/wallet_extension.go @@ -31,6 +31,7 @@ import ( "github.com/ten-protocol/go-ten/go/common/stopcontrol" "github.com/ten-protocol/go-ten/go/common/viewingkey" "github.com/ten-protocol/go-ten/tools/walletextension/common" + "github.com/ten-protocol/go-ten/tools/walletextension/ratelimiter" "github.com/ten-protocol/go-ten/tools/walletextension/storage" ) @@ -44,6 +45,7 @@ type Services struct { stopControl *stopcontrol.StopControl version string Cache cache.Cache + RateLimiter *ratelimiter.RateLimiter // the OG maintains a connection pool of rpc connections to underlying nodes rpcHTTPConnPool *pool.ObjectPool rpcWSConnPool *pool.ObjectPool @@ -92,6 +94,8 @@ func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.Storage cfg := pool.NewDefaultPoolConfig() cfg.MaxTotal = 200 // todo - what is the right number + rateLimiter := ratelimiter.NewRateLimiter(config.RateLimitUserComputeTime, config.RateLimitWindow, uint32(config.RateLimitMaxConcurrentRequests), logger) + services := Services{ HostAddrHTTP: hostAddrHTTP, HostAddrWS: hostAddrWS, @@ -101,6 +105,7 @@ func NewServices(hostAddrHTTP string, hostAddrWS string, storage storage.Storage stopControl: stopControl, version: version, Cache: newGatewayCache, + RateLimiter: rateLimiter, rpcHTTPConnPool: pool.NewObjectPool(context.Background(), factoryHTTP, cfg), rpcWSConnPool: pool.NewObjectPool(context.Background(), factoryWS, cfg), Config: config,