From 0291751ff2e55c0f5aa11f4eded52a5dd566c077 Mon Sep 17 00:00:00 2001 From: millken Date: Mon, 15 Apr 2024 14:27:59 +0800 Subject: [PATCH] [api] Add ratelimit for websocket API (#4031) --- api/config.go | 19 ++++++++------- api/serverV2.go | 4 +++- api/serverV2_test.go | 56 ++++++++++++++++++++++++++++++++++++++++++-- api/websocket.go | 14 ++++++++++- go.mod | 2 +- go.sum | 2 -- 6 files changed, 82 insertions(+), 15 deletions(-) diff --git a/api/config.go b/api/config.go index 5ac7ba3b2d..f937750c7c 100644 --- a/api/config.go +++ b/api/config.go @@ -23,16 +23,19 @@ type Config struct { Tracer tracer.Config `yaml:"tracer"` // BatchRequestLimit is the maximum number of requests in a batch. BatchRequestLimit int `yaml:"batchRequestLimit"` + // WebsocketRateLimit is the maximum number of messages per second per client. + WebsocketRateLimit int `yaml:"websocketRateLimit"` } // DefaultConfig is the default config var DefaultConfig = Config{ - UseRDS: false, - GRPCPort: 14014, - HTTPPort: 15014, - WebSocketPort: 16014, - TpsWindow: 10, - GasStation: gasstation.DefaultConfig, - RangeQueryLimit: 1000, - BatchRequestLimit: _defaultBatchRequestLimit, + UseRDS: false, + GRPCPort: 14014, + HTTPPort: 15014, + WebSocketPort: 16014, + TpsWindow: 10, + GasStation: gasstation.DefaultConfig, + RangeQueryLimit: 1000, + BatchRequestLimit: _defaultBatchRequestLimit, + WebsocketRateLimit: 5, } diff --git a/api/serverV2.go b/api/serverV2.go index c68a952739..f40e65586f 100644 --- a/api/serverV2.go +++ b/api/serverV2.go @@ -11,6 +11,7 @@ import ( "github.com/pkg/errors" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" tracesdk "go.opentelemetry.io/otel/sdk/trace" + "golang.org/x/time/rate" "github.com/iotexproject/iotex-core/action/protocol" "github.com/iotexproject/iotex-core/action/protocol/execution/evm" @@ -65,7 +66,8 @@ func NewServerV2( wrappedWeb3Handler := otelhttp.NewHandler(newHTTPHandler(web3Handler), "web3.jsonrpc") - wrappedWebsocketHandler := otelhttp.NewHandler(NewWebsocketHandler(web3Handler), "web3.websocket") + limiter := rate.NewLimiter(rate.Limit(cfg.WebsocketRateLimit), 1) + wrappedWebsocketHandler := otelhttp.NewHandler(NewWebsocketHandler(web3Handler, limiter), "web3.websocket") return &ServerV2{ core: coreAPI, diff --git a/api/serverV2_test.go b/api/serverV2_test.go index acb80c3817..00e4f338cd 100644 --- a/api/serverV2_test.go +++ b/api/serverV2_test.go @@ -2,12 +2,17 @@ package api import ( "context" + "net/http" + "net/http/httptest" + "strings" "testing" "time" "github.com/golang/mock/gomock" + "github.com/gorilla/websocket" "github.com/pkg/errors" "github.com/stretchr/testify/require" + "golang.org/x/time/rate" "github.com/iotexproject/iotex-core/test/mock/mock_apicoreservice" "github.com/iotexproject/iotex-core/testutil" @@ -18,13 +23,12 @@ func TestServerV2(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() core := mock_apicoreservice.NewMockCoreService(ctrl) - // TODO: mock web3handler web3Handler := NewWeb3Handler(core, "", _defaultBatchRequestLimit) svr := &ServerV2{ core: core, grpcServer: NewGRPCServer(core, testutil.RandomPort()), httpSvr: NewHTTPServer("", testutil.RandomPort(), newHTTPHandler(web3Handler)), - websocketSvr: NewHTTPServer("", testutil.RandomPort(), NewWebsocketHandler(web3Handler)), + websocketSvr: NewHTTPServer("", testutil.RandomPort(), NewWebsocketHandler(web3Handler, nil)), } ctx := context.Background() @@ -54,4 +58,52 @@ func TestServerV2(t *testing.T) { err := svr.Stop(ctx) require.Contains(err.Error(), expectErr.Error()) }) + + t.Run("websocket rate limit", func(t *testing.T) { + // set the limiter to 1 request per second + limiter := rate.NewLimiter(1, 1) + echo := func(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer c.Close() + for { + if err := limiter.Wait(ctx); err != nil { + return + } + mt, message, err := c.ReadMessage() + if err != nil { + break + } + err = c.WriteMessage(mt, message) + if err != nil { + break + } + } + } + s := httptest.NewServer(http.HandlerFunc(echo)) + defer s.Close() + + u := "ws" + strings.TrimPrefix(s.URL, "http") + c, _, err := websocket.DefaultDialer.Dial(u, nil) + require.NoError(err) + defer c.Close() + i := 0 + timeout := time.After(3 * time.Second) + LOOP: + for { + select { + case <-timeout: + break LOOP + default: + err := c.WriteMessage(websocket.TextMessage, []byte{0}) + require.NoError(err) + _, _, err = c.ReadMessage() + require.NoError(err) + i++ + } + } + require.Equal(4, i) + }) } diff --git a/api/websocket.go b/api/websocket.go index dec7c0278f..25170bb3da 100644 --- a/api/websocket.go +++ b/api/websocket.go @@ -2,12 +2,14 @@ package api import ( "context" + "math" "net/http" "sync" "time" "github.com/gorilla/websocket" "go.uber.org/zap" + "golang.org/x/time/rate" apitypes "github.com/iotexproject/iotex-core/api/types" "github.com/iotexproject/iotex-core/pkg/log" @@ -30,6 +32,7 @@ const ( // WebsocketHandler handles requests from websocket protocol type WebsocketHandler struct { msgHandler Web3Handler + limiter *rate.Limiter } var upgrader = websocket.Upgrader{ @@ -72,9 +75,14 @@ func (c *safeWebsocketConn) SetWriteDeadline(t time.Time) error { } // NewWebsocketHandler creates a new websocket handler -func NewWebsocketHandler(web3Handler Web3Handler) *WebsocketHandler { +func NewWebsocketHandler(web3Handler Web3Handler, limiter *rate.Limiter) *WebsocketHandler { + if limiter == nil { + // set the limiter to the maximum possible rate + limiter = rate.NewLimiter(rate.Limit(math.MaxFloat64), 1) + } return &WebsocketHandler{ msgHandler: web3Handler, + limiter: limiter, } } @@ -113,6 +121,10 @@ func (wsSvr *WebsocketHandler) handleConnection(ctx context.Context, ws *websock case <-ctx.Done(): return default: + if err := wsSvr.limiter.Wait(ctx); err != nil { + cancel() + return + } _, reader, err := ws.NextReader() if err != nil { log.Logger("api").Debug("Client Disconnected", zap.Error(err)) diff --git a/go.mod b/go.mod index 348df63d3b..6fc3d1fab4 100644 --- a/go.mod +++ b/go.mod @@ -242,7 +242,7 @@ require ( golang.org/x/oauth2 v0.8.0 // indirect golang.org/x/sys v0.16.0 // indirect golang.org/x/term v0.15.0 // indirect - golang.org/x/time v0.3.0 // indirect + golang.org/x/time v0.3.0 gopkg.in/square/go-jose.v2 v2.5.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 1acfc12791..1b5a839cbb 100644 --- a/go.sum +++ b/go.sum @@ -664,8 +664,6 @@ github.com/iotexproject/iotex-antenna-go/v2 v2.5.1/go.mod h1:8pDZcM45M0gY6jm3PoM github.com/iotexproject/iotex-election v0.3.5-0.20210611041425-20ddf674363d h1:/j1xCAC9YiG/8UKqYvycS/v3ddVsb1G7AMyLXOjeYI0= github.com/iotexproject/iotex-election v0.3.5-0.20210611041425-20ddf674363d/go.mod h1:GRWevxtqQ4gPMrd7Qxhr29/7aTgvjiTp+rFI9KMMZEo= github.com/iotexproject/iotex-proto v0.5.0/go.mod h1:Xg6REkv+nTZN+OC22xXIQuqKdTWWHwOAJEXCoMpDwtI= -github.com/iotexproject/iotex-proto v0.5.15 h1:9+6szZDQ1HhSFKyB2kVlVPXdCFAHHw72VVGcYXQ7P/w= -github.com/iotexproject/iotex-proto v0.5.15/go.mod h1:wQpCk3Df0fPID+K8ohiICGj+cWRmcQ3wanT+aSrnIPo= github.com/iotexproject/iotex-proto v0.6.0 h1:UIwPq5QuuPwR7G4OZzmyBsbvEJ+YH6oHyzRjxGk9Fow= github.com/iotexproject/iotex-proto v0.6.0/go.mod h1:wQpCk3Df0fPID+K8ohiICGj+cWRmcQ3wanT+aSrnIPo= github.com/ipfs/go-cid v0.0.1/go.mod h1:GHWU/WuQdMPmIosc4Yn1bcCT7dSeX4lBafM7iqUPQvM=