From 810f9814aaa60d8f58fded5b2d11f442b31e8652 Mon Sep 17 00:00:00 2001 From: Patric Vormstein Date: Wed, 6 Nov 2024 13:22:58 +0100 Subject: [PATCH] add stream analytics to ee (TT-13233) (#6671) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### **User description**
TT-13233
Summary Streams in Pump - Enable Pump to purge Stream API records
Type Story Story
Status In Dev
Points N/A
Labels -
--- This PR adds analytics to stream APIs for EE. ## Types of changes - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to change) - [ ] Refactoring or add test (improvements in base code or adds test coverage to functionality) ___ ### **PR Type** Enhancement, Tests ___ ### **Description** - Introduced stream analytics interfaces and no-op implementations to support analytics recording. - Enhanced `Manager` and `Middleware` to integrate with `StreamAnalyticsFactory`. - Implemented `DefaultStreamAnalyticsFactory` and `StreamAnalyticsResponseWriter` for detailed analytics. - Added comprehensive tests for new analytics components and functionalities. ___ ### **Changes walkthrough** 📝
Relevant files
Enhancement
7 files
analytics.go
Add interfaces and no-op implementations for stream analytics

ee/middleware/streams/analytics.go
  • Introduced StreamAnalyticsFactory and StreamAnalyticsRecorder
    interfaces.
  • Added NoopStreamAnalyticsFactory and NoopStreamAnalyticsRecorder
    implementations.
  • Defined error ErrResponseWriterNotHijackable.
  • +42/-0   
    manager.go
    Integrate analytics factory into stream manager                   

    ee/middleware/streams/manager.go
  • Added analyticsFactory to Manager struct.
  • Implemented SetAnalyticsFactory method.
  • Updated stream creation logic to use HandleFuncAdapter.
  • +21/-13 
    middleware.go
    Enhance middleware with analytics factory support               

    ee/middleware/streams/middleware.go
  • Added analyticsFactory to Middleware struct.
  • Updated NewMiddleware to accept analyticsFactory.
  • Implemented SetAnalyticsFactory and GetStreamManager methods.
  • +35/-12 
    util.go
    Refactor and enhance handle function adapter                         

    ee/middleware/streams/util.go
  • Refactored handleFuncAdapter to HandleFuncAdapter.
  • Integrated analytics recording in HandleFunc.
  • +20/-17 
    analytics_streams.go
    Implement default stream analytics factory and response writer

    gateway/analytics_streams.go
  • Added DefaultStreamAnalyticsFactory and related recorder
    implementations.
  • Implemented StreamAnalyticsResponseWriter.
  • Provided utility functions for analytics recording.
  • +230/-0 
    handler_success.go
    Refactor record detail logic for streams                                 

    gateway/handler_success.go
  • Added recordDetailUnsafe function.
  • Refactored recordDetail to use recordDetailUnsafe.
  • +6/-1     
    mw_streaming_ee.go
    Integrate analytics factory into streaming middleware       

    gateway/mw_streaming_ee.go
  • Integrated StreamAnalyticsFactory into streaming middleware
    initialization.
  • +3/-1     
    Tests
    2 files
    analytics_streams_test.go
    Add tests for stream analytics components                               

    gateway/analytics_streams_test.go
  • Added tests for websocket upgrade detection.
  • Tested DefaultStreamAnalyticsFactory and
    StreamAnalyticsResponseWriter.
  • Verified analytics recording functionality.
  • +316/-0 
    mw_streaming_test.go
    Update streaming middleware tests for analytics integration

    gateway/mw_streaming_test.go
  • Updated streaming middleware test setup to include analytics factory.
  • +1/-1     
    ___ > 💡 **PR-Agent usage**: Comment `/help "your question"` on any pull request to receive relevant information --- ee/middleware/streams/analytics.go | 41 ++++ ee/middleware/streams/manager.go | 34 +-- ee/middleware/streams/middleware.go | 47 +++-- ee/middleware/streams/util.go | 37 ++-- gateway/analytics_streams.go | 193 +++++++++++++++++ gateway/analytics_streams_test.go | 317 ++++++++++++++++++++++++++++ gateway/handler_success.go | 11 +- gateway/mw_streaming_ee.go | 4 +- gateway/mw_streaming_test.go | 2 +- 9 files changed, 639 insertions(+), 47 deletions(-) create mode 100644 ee/middleware/streams/analytics.go create mode 100644 gateway/analytics_streams.go create mode 100644 gateway/analytics_streams_test.go diff --git a/ee/middleware/streams/analytics.go b/ee/middleware/streams/analytics.go new file mode 100644 index 00000000000..3e6e77c54a0 --- /dev/null +++ b/ee/middleware/streams/analytics.go @@ -0,0 +1,41 @@ +package streams + +import ( + "errors" + "net/http" + + "github.com/TykTechnologies/tyk-pump/analytics" +) + +var ( + ErrResponseWriterNotHijackable = errors.New("ResponseWriter is not hijackable") +) + +type StreamAnalyticsFactory interface { + CreateRecorder(r *http.Request) StreamAnalyticsRecorder + CreateResponseWriter(w http.ResponseWriter, r *http.Request, streamID string, recorder StreamAnalyticsRecorder) http.ResponseWriter +} + +type NoopStreamAnalyticsFactory struct{} + +func (n *NoopStreamAnalyticsFactory) CreateRecorder(r *http.Request) StreamAnalyticsRecorder { + return &NoopStreamAnalyticsRecorder{} +} + +func (n *NoopStreamAnalyticsFactory) CreateResponseWriter(w http.ResponseWriter, r *http.Request, streamID string, recorder StreamAnalyticsRecorder) http.ResponseWriter { + return w +} + +type StreamAnalyticsRecorder interface { + PrepareRecord(r *http.Request) + RecordHit(statusCode int, latency analytics.Latency) error +} + +type NoopStreamAnalyticsRecorder struct{} + +func (n *NoopStreamAnalyticsRecorder) PrepareRecord(r *http.Request) { +} + +func (n *NoopStreamAnalyticsRecorder) RecordHit(statusCode int, latency analytics.Latency) error { + return nil +} diff --git a/ee/middleware/streams/manager.go b/ee/middleware/streams/manager.go index 5412174007e..0fcc9bf4ef7 100644 --- a/ee/middleware/streams/manager.go +++ b/ee/middleware/streams/manager.go @@ -12,13 +12,14 @@ import ( // Manager is responsible for creating a single stream. type Manager struct { - streams sync.Map - routeLock sync.Mutex - muxer *mux.Router - mw *Middleware - dryRun bool - listenPaths []string - activityCounter atomic.Int32 // Counts active subscriptions, requests. + streams sync.Map + routeLock sync.Mutex + muxer *mux.Router + mw *Middleware + dryRun bool + listenPaths []string + activityCounter atomic.Int32 // Counts active subscriptions, requests. + analyticsFactory StreamAnalyticsFactory } func (sm *Manager) initStreams(r *http.Request, config *StreamsConfig) { @@ -96,13 +97,13 @@ func (sm *Manager) createStream(streamID string, config map[string]interface{}) } stream := NewStream(sm.mw.allowedUnsafe) - err := stream.Start(config, &handleFuncAdapter{ - mw: sm.mw, - streamID: streamFullID, - muxer: sm.muxer, - sm: sm, + err := stream.Start(config, &HandleFuncAdapter{ + StreamMiddleware: sm.mw, + StreamID: streamFullID, + Muxer: sm.muxer, + StreamManager: sm, // child logger is necessary to prevent race condition - logger: sm.mw.Logger().WithField("stream", streamFullID), + Logger: sm.mw.Logger().WithField("stream", streamFullID), }) if err != nil { sm.mw.Logger().Errorf("Failed to start stream %s: %v", streamFullID, err) @@ -123,3 +124,10 @@ func (sm *Manager) hasPath(path string) bool { } return false } + +func (sm *Manager) SetAnalyticsFactory(factory StreamAnalyticsFactory) { + if factory == nil { + factory = &NoopStreamAnalyticsFactory{} + } + sm.analyticsFactory = factory +} diff --git a/ee/middleware/streams/middleware.go b/ee/middleware/streams/middleware.go index 5f96a8a0e3f..a6f20ab8ad1 100644 --- a/ee/middleware/streams/middleware.go +++ b/ee/middleware/streams/middleware.go @@ -29,21 +29,23 @@ type Middleware struct { createStreamManagerLock sync.Mutex StreamManagerCache sync.Map // Map of payload hash to Manager - ctx context.Context - cancel context.CancelFunc - allowedUnsafe []string - defaultManager *Manager + ctx context.Context + cancel context.CancelFunc + allowedUnsafe []string + defaultManager *Manager + analyticsFactory StreamAnalyticsFactory } // Middleware implements model.Middleware. var _ model.Middleware = &Middleware{} // NewMiddleware returns a new instance of Middleware. -func NewMiddleware(gw Gateway, mw BaseMiddleware, spec *APISpec) *Middleware { +func NewMiddleware(gw Gateway, mw BaseMiddleware, spec *APISpec, analyticsFactory StreamAnalyticsFactory) *Middleware { return &Middleware{ - base: mw, - Gw: gw, - Spec: spec, + base: mw, + Gw: gw, + Spec: spec, + analyticsFactory: analyticsFactory, } } @@ -89,6 +91,13 @@ func (s *Middleware) Init() { s.Logger().Debug("Initializing default stream manager") s.defaultManager = s.CreateStreamManager(nil) + s.Logger().Debug("Initializing stream analytics factory") + if s.analyticsFactory == nil { + s.SetAnalyticsFactory(&NoopStreamAnalyticsFactory{}) + } else { + s.SetAnalyticsFactory(s.analyticsFactory) + } + // Start garbage collection routine go func() { ticker := time.NewTicker(StreamGCInterval) @@ -122,10 +131,11 @@ func (s *Middleware) CreateStreamManager(r *http.Request) *Manager { } newManager := &Manager{ - muxer: mux.NewRouter(), - mw: s, - dryRun: r == nil, - activityCounter: atomic.Int32{}, + muxer: mux.NewRouter(), + mw: s, + dryRun: r == nil, + activityCounter: atomic.Int32{}, + analyticsFactory: &NoopStreamAnalyticsFactory{}, } newManager.initStreams(r, streamsConfig) @@ -233,6 +243,7 @@ func (s *Middleware) ProcessRequest(w http.ResponseWriter, r *http.Request, _ in var match mux.RouteMatch streamManager := s.CreateStreamManager(r) + streamManager.SetAnalyticsFactory(s.analyticsFactory) streamManager.routeLock.Lock() streamManager.muxer.Match(newRequest, &match) streamManager.routeLock.Unlock() @@ -279,3 +290,15 @@ func (s *Middleware) Unload() { s.StreamManagerCache = sync.Map{} s.Logger().Info("All streams successfully removed") } + +func (s *Middleware) SetAnalyticsFactory(factory StreamAnalyticsFactory) { + if factory == nil { + factory = &NoopStreamAnalyticsFactory{} + } + s.analyticsFactory = factory + s.defaultManager.SetAnalyticsFactory(factory) +} + +func (s *Middleware) GetStreamManager() *Manager { + return s.defaultManager +} diff --git a/ee/middleware/streams/util.go b/ee/middleware/streams/util.go index a3218a12b1f..e4400f2ab63 100644 --- a/ee/middleware/streams/util.go +++ b/ee/middleware/streams/util.go @@ -7,30 +7,33 @@ import ( "github.com/sirupsen/logrus" ) -type handleFuncAdapter struct { - streamID string - sm *Manager - mw *Middleware - muxer *mux.Router - logger *logrus.Entry +type HandleFuncAdapter struct { + StreamID string + StreamManager *Manager + StreamMiddleware *Middleware + Muxer *mux.Router + Logger *logrus.Entry } -func (h *handleFuncAdapter) HandleFunc(path string, f func(http.ResponseWriter, *http.Request)) { - h.logger.Debugf("Registering streaming handleFunc for path: %s", path) +func (h *HandleFuncAdapter) HandleFunc(path string, f func(http.ResponseWriter, *http.Request)) { + h.Logger.Debugf("Registering streaming handleFunc for path: %s", path) - if h.mw == nil || h.muxer == nil { - h.logger.Error("Middleware or muxer is nil") + if h.StreamMiddleware == nil || h.Muxer == nil { + h.Logger.Error("Middleware or muxer is nil") return } - h.sm.routeLock.Lock() - h.muxer.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { - h.sm.activityCounter.Add(1) - defer h.sm.activityCounter.Add(-1) - f(w, r) + h.StreamManager.routeLock.Lock() + h.Muxer.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) { + recorder := h.StreamManager.analyticsFactory.CreateRecorder(r) + analyticsResponseWriter := h.StreamManager.analyticsFactory.CreateResponseWriter(w, r, h.StreamID, recorder) + + h.StreamManager.activityCounter.Add(1) + defer h.StreamManager.activityCounter.Add(-1) + f(analyticsResponseWriter, r) }) - h.sm.routeLock.Unlock() - h.logger.Debugf("Registered handler for path: %s", path) + h.StreamManager.routeLock.Unlock() + h.Logger.Debugf("Registered handler for path: %s", path) } // Helper function to extract paths from an http_server configuration diff --git a/gateway/analytics_streams.go b/gateway/analytics_streams.go new file mode 100644 index 00000000000..766a65c8198 --- /dev/null +++ b/gateway/analytics_streams.go @@ -0,0 +1,193 @@ +//go:build ee || dev + +package gateway + +import ( + "bufio" + "io" + "net" + "net/http" + "strconv" + "strings" + "time" + + "github.com/sirupsen/logrus" + + "github.com/TykTechnologies/tyk-pump/analytics" + + "github.com/TykTechnologies/tyk/ee/middleware/streams" +) + +type DefaultStreamAnalyticsFactory struct { + Logger *logrus.Entry + Gw *Gateway + Spec *APISpec +} + +func NewStreamAnalyticsFactory(logger *logrus.Entry, gw *Gateway, spec *APISpec) streams.StreamAnalyticsFactory { + return &DefaultStreamAnalyticsFactory{ + Logger: logger, + Gw: gw, + Spec: spec, + } +} + +func (d *DefaultStreamAnalyticsFactory) CreateRecorder(r *http.Request) streams.StreamAnalyticsRecorder { + detailed := false + if recordDetailUnsafe(r, d.Spec) { + detailed = true + } + + if isWebsocketUpgrade(r) { + return NewWebSocketStreamAnalyticsRecorder(d.Gw, d.Spec, detailed) + } + + return NewDefaultStreamAnalyticsRecorder(d.Gw, d.Spec) +} + +func (d *DefaultStreamAnalyticsFactory) CreateResponseWriter(w http.ResponseWriter, r *http.Request, streamID string, recorder streams.StreamAnalyticsRecorder) http.ResponseWriter { + return NewStreamAnalyticsResponseWriter(d.Logger, w, r, streamID, recorder) +} + +type DefaultStreamAnalyticsRecorder struct { + Gw *Gateway + Spec *APISpec + reqCopy *http.Request + respCopy *http.Response +} + +func NewDefaultStreamAnalyticsRecorder(gw *Gateway, spec *APISpec) *DefaultStreamAnalyticsRecorder { + return &DefaultStreamAnalyticsRecorder{ + Gw: gw, + Spec: spec, + } +} + +func (s *DefaultStreamAnalyticsRecorder) PrepareRecord(r *http.Request) { + s.reqCopy = r.Clone(r.Context()) + s.respCopy = &http.Response{ + StatusCode: 200, + Header: make(http.Header), + } + + s.respCopy.Header.Set("Content-Length", strconv.FormatInt(0, 10)) + s.respCopy.Body = io.NopCloser(strings.NewReader("")) + s.respCopy.ContentLength = 0 +} + +func (s *DefaultStreamAnalyticsRecorder) RecordHit(statusCode int, latency analytics.Latency) error { + s.respCopy.StatusCode = statusCode + + handler := SuccessHandler{ + &BaseMiddleware{ + Spec: s.Spec, + Gw: s.Gw, + }, + } + + handler.RecordHit(s.reqCopy, latency, statusCode, s.respCopy, false) + return nil +} + +type WebSocketStreamAnalyticsRecorder struct { + Gw *Gateway + Spec *APISpec + Detailed bool + simpleStreamAnalyticsRecorder *DefaultStreamAnalyticsRecorder +} + +func NewWebSocketStreamAnalyticsRecorder(gw *Gateway, spec *APISpec, detailed bool) *WebSocketStreamAnalyticsRecorder { + return &WebSocketStreamAnalyticsRecorder{ + Gw: gw, + Spec: spec, + Detailed: detailed, + simpleStreamAnalyticsRecorder: NewDefaultStreamAnalyticsRecorder(gw, spec), + } +} + +func (d *WebSocketStreamAnalyticsRecorder) PrepareRecord(r *http.Request) { + d.simpleStreamAnalyticsRecorder.PrepareRecord(r) +} + +func (d *WebSocketStreamAnalyticsRecorder) RecordHit(statusCode int, latency analytics.Latency) error { + return d.simpleStreamAnalyticsRecorder.RecordHit(statusCode, latency) +} + +type StreamAnalyticsResponseWriter struct { + logger *logrus.Entry + w http.ResponseWriter + r *http.Request + streamID string + recorder streams.StreamAnalyticsRecorder + writtenStatusCode int +} + +func NewStreamAnalyticsResponseWriter(logger *logrus.Entry, w http.ResponseWriter, r *http.Request, streamID string, recorder streams.StreamAnalyticsRecorder) *StreamAnalyticsResponseWriter { + return &StreamAnalyticsResponseWriter{ + logger: logger, + w: w, + r: r, + streamID: streamID, + recorder: recorder, + writtenStatusCode: http.StatusOK, // implicit status code from ResponseWriter.Write + } +} + +func (s *StreamAnalyticsResponseWriter) SetStreamID(streamID string) { + s.streamID = streamID +} + +func (s *StreamAnalyticsResponseWriter) Header() http.Header { + return s.w.Header() +} + +func (s *StreamAnalyticsResponseWriter) Write(bytes []byte) (int, error) { + now := time.Now() + n, err := s.w.Write(bytes) + if err != nil { + return n, err + } + + totalMillisecond := DurationToMillisecond(time.Since(now)) + latency := analytics.Latency{ + Total: int64(totalMillisecond), + Upstream: int64(totalMillisecond), + } + + s.recorder.PrepareRecord(s.r) + recorderErr := s.recorder.RecordHit(s.writtenStatusCode, latency) + if recorderErr != nil { + s.logger.Errorf("Failed to record analytics for stream on path '%s %s', %v", s.r.Method, s.r.URL.Path, recorderErr) + } + return n, nil +} + +func (s *StreamAnalyticsResponseWriter) WriteHeader(statusCode int) { + s.writtenStatusCode = statusCode + s.w.WriteHeader(statusCode) +} + +func (s *StreamAnalyticsResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijackableWriter, ok := s.w.(http.Hijacker) + if !ok { + return nil, nil, streams.ErrResponseWriterNotHijackable + } + + s.recorder.PrepareRecord(s.r) + recorderErr := s.recorder.RecordHit(http.StatusSwitchingProtocols, analytics.Latency{}) + if recorderErr != nil { + s.logger.Errorf("Failed to record analytics for connection upgrade on path 'UPGRADE %s', %v", s.r.URL.Path, recorderErr) + } + + return hijackableWriter.Hijack() +} + +func (s *StreamAnalyticsResponseWriter) Flush() { + if flusher, ok := s.w.(http.Flusher); ok { + flusher.Flush() + } +} + +func isWebsocketUpgrade(r *http.Request) bool { + return strings.ToLower(r.Header.Get("Connection")) == "upgrade" && strings.ToLower(r.Header.Get("Upgrade")) == "websocket" +} diff --git a/gateway/analytics_streams_test.go b/gateway/analytics_streams_test.go new file mode 100644 index 00000000000..33cfa0a2fde --- /dev/null +++ b/gateway/analytics_streams_test.go @@ -0,0 +1,317 @@ +//go:build ee || dev + +package gateway + +import ( + "bufio" + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gorilla/mux" + logrus "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/TykTechnologies/tyk-pump/analytics" + + "github.com/TykTechnologies/tyk/apidef" + "github.com/TykTechnologies/tyk/ee/middleware/streams" +) + +func TestIsWebsocketUpgrade(t *testing.T) { + type testCase struct { + name string + connectionHeader string + upgradeHeader string + expectedResult bool + } + + for _, tc := range []testCase{ + { + name: "should be true for capitalized headers", + connectionHeader: "Upgrade", + upgradeHeader: "Websocket", + expectedResult: true, + }, + { + name: "should be true for lower-case headers", + connectionHeader: "upgrade", + upgradeHeader: "websocket", + expectedResult: true, + }, + { + name: "should be false for wrong connection header", + connectionHeader: "No-Upgrade", + upgradeHeader: "Websocket", + expectedResult: false, + }, + { + name: "should be false for wrong upgrade header", + connectionHeader: "No-Upgrade", + upgradeHeader: "Websocket", + expectedResult: false, + }, + { + name: "should be false for empty headers", + connectionHeader: "", + upgradeHeader: "", + expectedResult: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil) + require.NoError(t, err) + + req.Header.Set("Connection", tc.connectionHeader) + req.Header.Set("Upgrade", tc.upgradeHeader) + assert.Equal(t, tc.expectedResult, isWebsocketUpgrade(req)) + }) + } +} + +func TestDefaultStreamAnalyticsFactory_CreateRecorder(t *testing.T) { + type testCase struct { + name string + enableDetailedRecording bool + expectedDetailedRecording bool + } + + t.Run("default recorder", func(t *testing.T) { + for _, tc := range []testCase{ + { + name: "should create a non-detailed default recorder", + enableDetailedRecording: false, + expectedDetailedRecording: false, + }, + { + name: "should create a detailed default recorder", + enableDetailedRecording: true, + expectedDetailedRecording: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + spec := &APISpec{ + APIDefinition: &apidef.APIDefinition{ + EnableDetailedRecording: tc.enableDetailedRecording, + }, + } + + req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil) + require.NoError(t, err) + + factory := NewStreamAnalyticsFactory(nil, nil, spec) + recorder := factory.CreateRecorder(req) + + _, ok := recorder.(*DefaultStreamAnalyticsRecorder) + assert.True(t, ok) + }) + } + }) + + t.Run("websocket recorder", func(t *testing.T) { + for _, tc := range []testCase{ + { + name: "should create a non-detailed websocket recorder", + enableDetailedRecording: false, + expectedDetailedRecording: false, + }, + { + name: "should create a detailed websocket recorder", + enableDetailedRecording: true, + expectedDetailedRecording: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + spec := &APISpec{ + APIDefinition: &apidef.APIDefinition{ + EnableDetailedRecording: tc.enableDetailedRecording, + }, + } + + req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil) + require.NoError(t, err) + + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + + factory := NewStreamAnalyticsFactory(nil, nil, spec) + recorder := factory.CreateRecorder(req) + + websocketRecorder, ok := recorder.(*WebSocketStreamAnalyticsRecorder) + assert.True(t, ok) + assert.Equal(t, tc.expectedDetailedRecording, websocketRecorder.Detailed) + }) + } + }) +} + +func TestDefaultStreamAnalyticsRecorder_PrepareRecord(t *testing.T) { + t.Run("should prepare non-detailed record", func(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "http://localhost:8080/path", nil) + require.NoError(t, err) + + recorder := NewDefaultStreamAnalyticsRecorder(nil, &APISpec{APIDefinition: &apidef.APIDefinition{}}) + recorder.PrepareRecord(req) + + assert.NotNil(t, recorder.respCopy) + assert.NotNil(t, recorder.reqCopy) + assert.Equal(t, "/path", recorder.reqCopy.URL.Path) + assert.Equal(t, http.MethodPost, recorder.reqCopy.Method) + }) +} + +func TestHandleFuncAdapter_HandleFunc(t *testing.T) { + logger, _ := logrus.NewNullLogger() + baseMid := &BaseMiddleware{ + logger: logger.WithContext(context.Background()), + } + spec := &APISpec{ + APIDefinition: &apidef.APIDefinition{ + APIID: "test", + Name: "test-api", + IsOAS: true, + }, + } + streamSpec := streams.NewAPISpec(spec.APIID, spec.Name, spec.IsOAS, spec.OAS, spec.StripListenPath) + streamMiddleware := streams.NewMiddleware(baseMid.Gw, baseMid, streamSpec, nil) + + factory := &testStreamAnalyticsFactory{} + streamMiddleware.Init() + streamMiddleware.SetAnalyticsFactory(factory) + + router := mux.NewRouter() + testHandleFuncAdapter := streams.HandleFuncAdapter{ + StreamID: "test", + StreamManager: streamMiddleware.GetStreamManager(), + StreamMiddleware: streamMiddleware, + Muxer: router, + Logger: logger.WithContext(context.Background()), + } + + testHandleFuncAdapter.HandleFunc("/path", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusSwitchingProtocols) + w.Write(nil) + }) + + testServer := httptest.NewServer(router) + t.Cleanup(testServer.Close) + + targetURL := fmt.Sprintf("%s/%s", testServer.URL, "path") + req, err := http.NewRequest(http.MethodPost, targetURL, nil) + + client := http.Client{} + _, err = client.Do(req) + require.NoError(t, err) + + assert.Equal(t, http.StatusSwitchingProtocols, factory.responseWriter.responseRecorder.Code) +} + +func TestStreamAnalyticsResponseWriter_Write(t *testing.T) { + logger, _ := logrus.NewNullLogger() + responseRecorder := httptest.NewRecorder() + r := httptest.NewRequest("GET", "http://localhost/path", nil) + analyticsRecorder := &testStreamAnalyticsRecorder{} + + w := NewStreamAnalyticsResponseWriter(logger.WithContext(context.Background()), responseRecorder, r, "test", analyticsRecorder) + _, err := w.Write(nil) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, responseRecorder.Code) + assert.Equal(t, http.StatusOK, analyticsRecorder.actualRecord.ResponseCode) + assert.Equal(t, "GET", analyticsRecorder.actualRecord.Method) + assert.Equal(t, "localhost", analyticsRecorder.actualRecord.Host) + assert.Equal(t, "/path", analyticsRecorder.actualRecord.Path) +} + +func TestStreamAnalyticsResponseWriter_WriteHeader(t *testing.T) { + logger, _ := logrus.NewNullLogger() + responseRecorder := httptest.NewRecorder() + r := httptest.NewRequest("GET", "http://localhost/path", nil) + analyticsRecorder := &testStreamAnalyticsRecorder{} + + w := NewStreamAnalyticsResponseWriter(logger.WithContext(context.Background()), responseRecorder, r, "test", analyticsRecorder) + w.WriteHeader(http.StatusSwitchingProtocols) + assert.Equal(t, http.StatusSwitchingProtocols, responseRecorder.Code) +} + +func TestStreamAnalyticsResponseWriter_Hijack(t *testing.T) { + logger, _ := logrus.NewNullLogger() + responseRecorder := &testStreamHijackableResponseRecorder{ + responseRecorder: httptest.NewRecorder(), + } + r := httptest.NewRequest("GET", "http://localhost/path", nil) + analyticsRecorder := &testStreamAnalyticsRecorder{} + + w := NewStreamAnalyticsResponseWriter(logger.WithContext(context.Background()), responseRecorder, r, "test", analyticsRecorder) + _, _, err := w.Hijack() + require.NoError(t, err) + + assert.Equal(t, http.StatusSwitchingProtocols, analyticsRecorder.actualRecord.ResponseCode) + assert.Equal(t, "GET", analyticsRecorder.actualRecord.Method) + assert.Equal(t, "localhost", analyticsRecorder.actualRecord.Host) + assert.Equal(t, "/path", analyticsRecorder.actualRecord.Path) +} + +type testStreamAnalyticsFactory struct { + recorder *testStreamAnalyticsRecorder + responseWriter *testStreamHijackableResponseRecorder +} + +func (t *testStreamAnalyticsFactory) CreateRecorder(r *http.Request) streams.StreamAnalyticsRecorder { + t.recorder = &testStreamAnalyticsRecorder{} + return t.recorder +} + +func (t *testStreamAnalyticsFactory) CreateResponseWriter(w http.ResponseWriter, r *http.Request, streamID string, recorder streams.StreamAnalyticsRecorder) http.ResponseWriter { + httpRecorder := httptest.NewRecorder() + t.responseWriter = &testStreamHijackableResponseRecorder{ + responseRecorder: httpRecorder, + } + return t.responseWriter +} + +type testStreamAnalyticsRecorder struct { + actualRecord *analytics.AnalyticsRecord +} + +func (t *testStreamAnalyticsRecorder) PrepareRecord(r *http.Request) { + t.actualRecord = &analytics.AnalyticsRecord{ + Method: r.Method, + Host: r.Host, + Path: r.URL.Path, + } + return +} + +func (t *testStreamAnalyticsRecorder) RecordHit(statusCode int, latency analytics.Latency) error { + t.actualRecord.ResponseCode = statusCode + return nil +} + +type testStreamHijackableResponseRecorder struct { + responseRecorder *httptest.ResponseRecorder +} + +func (t *testStreamHijackableResponseRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, nil +} + +func (t *testStreamHijackableResponseRecorder) Flush() { + t.responseRecorder.Flush() +} + +func (t *testStreamHijackableResponseRecorder) Header() http.Header { + return t.responseRecorder.Header() +} + +func (t *testStreamHijackableResponseRecorder) Write(i []byte) (int, error) { + return t.responseRecorder.Write(i) +} + +func (t *testStreamHijackableResponseRecorder) WriteHeader(statusCode int) { + t.responseRecorder.WriteHeader(statusCode) +} diff --git a/gateway/handler_success.go b/gateway/handler_success.go index c149dc229ca..53dd816d95b 100644 --- a/gateway/handler_success.go +++ b/gateway/handler_success.go @@ -10,14 +10,15 @@ import ( "strings" "time" - "github.com/TykTechnologies/tyk/ctx" - "github.com/TykTechnologies/tyk/internal/httputil" - graphqlinternal "github.com/TykTechnologies/tyk/internal/graphql" "github.com/TykTechnologies/tyk-pump/analytics" + "github.com/TykTechnologies/tyk/apidef" + "github.com/TykTechnologies/tyk/internal/httputil" + "github.com/TykTechnologies/tyk/config" + "github.com/TykTechnologies/tyk/ctx" "github.com/TykTechnologies/tyk/header" "github.com/TykTechnologies/tyk/request" "github.com/TykTechnologies/tyk/user" @@ -333,6 +334,10 @@ func recordDetail(r *http.Request, spec *APISpec) bool { return false } + return recordDetailUnsafe(r, spec) +} + +func recordDetailUnsafe(r *http.Request, spec *APISpec) bool { if spec.EnableDetailedRecording { return true } diff --git a/gateway/mw_streaming_ee.go b/gateway/mw_streaming_ee.go index 64ee4ffedb7..81072acf020 100644 --- a/gateway/mw_streaming_ee.go +++ b/gateway/mw_streaming_ee.go @@ -10,6 +10,8 @@ import ( func getStreamingMiddleware(baseMid *BaseMiddleware) TykMiddleware { spec := baseMid.Spec streamSpec := streams.NewAPISpec(spec.APIID, spec.Name, spec.IsOAS, spec.OAS, spec.StripListenPath) - streamMw := streams.NewMiddleware(baseMid.Gw, baseMid, streamSpec) + + streamAnalyticsFactory := NewStreamAnalyticsFactory(baseMid.logger.Dup(), baseMid.Gw, spec) + streamMw := streams.NewMiddleware(baseMid.Gw, baseMid, streamSpec, streamAnalyticsFactory) return WrapMiddleware(baseMid, streamMw) } diff --git a/gateway/mw_streaming_test.go b/gateway/mw_streaming_test.go index 02377eaf9a2..f7cfaa8ab6d 100644 --- a/gateway/mw_streaming_test.go +++ b/gateway/mw_streaming_test.go @@ -965,7 +965,7 @@ func TestStreamingAPIGarbageCollection(t *testing.T) { apiSpec := streams.NewAPISpec(specs[0].APIID, specs[0].Name, specs[0].IsOAS, specs[0].OAS, specs[0].StripListenPath) - s := streams.NewMiddleware(ts.Gw, &DummyBase{}, apiSpec) + s := streams.NewMiddleware(ts.Gw, &DummyBase{}, apiSpec, nil) if err := setUpStreamAPI(ts, apiName, bentoHTTPServerTemplate); err != nil { t.Fatal(err)