From 509c508f84fc00b7a09911628fc795e3db94c2ec Mon Sep 17 00:00:00 2001
From: Edgar Gomes <talktoedgar@gmail.com>
Date: Wed, 20 Dec 2023 13:41:24 -0300
Subject: [PATCH] fix: websocket keep-alive (#1993)

## Description:
Fix the websocket keep-alive (ping/pong) process and make the pump exit
when it detects that the connection has been closed.

## Is this change user facing?
NO
---
 .../engine/server/websocket_api_handler.go    | 28 ++----
 .../server/engine/streaming/websocket_pump.go | 91 ++++++++++++++++---
 2 files changed, 84 insertions(+), 35 deletions(-)

diff --git a/engine/server/engine/server/websocket_api_handler.go b/engine/server/engine/server/websocket_api_handler.go
index 1656f78c81..fee9cd7d19 100644
--- a/engine/server/engine/server/websocket_api_handler.go
+++ b/engine/server/engine/server/websocket_api_handler.go
@@ -180,32 +180,25 @@ func streamStarlarkLogsWithWebsocket[T any](ctx echo.Context, cors cors.Cors, st
 		return
 	}
 	defer wsPump.Close()
-	go wsPump.StartPumping()
 
 	found, err := streamerPool.Consume(streaming.StreamerUUID(streamerUUID), func(logline *rpc_api.StarlarkRunResponseLine) error {
 		response, err := to_http.ToHttpStarlarkRunResponseLine(logline)
 		if err != nil {
 			return stacktrace.Propagate(err, "Failed to convert value of type `%T` to http", logline)
 		}
-		wsPump.PumpMessage(response)
-		return nil
+		return wsPump.PumpMessage(response)
 	})
 
 	if !found {
-		wsPump.PumpResponseInfo(&notFoundErr)
+		if err := wsPump.PumpResponseInfo(&notFoundErr); err != nil {
+			logrus.WithError(err).Warn("Failed to send response.")
+		}
 	}
 
 	if err != nil {
 		logrus.WithError(err).WithFields(logrus.Fields{
 			"streamerUUID": streamerUUID,
-			"stacktrace":   fmt.Sprintf("%+v", err),
 		}).Error("Failed to stream all data")
-		streamingErr := api_type.ResponseInfo{
-			Type:    api_type.ERROR,
-			Message: fmt.Sprintf("Log streaming '%s' failed while sending the data", streamerUUID),
-			Code:    http.StatusInternalServerError,
-		}
-		wsPump.PumpResponseInfo(&streamingErr)
 	}
 }
 
@@ -270,24 +263,15 @@ func streamServiceLogsWithWebsocket(ctx echo.Context, cors cors.Cors, streamer s
 		return
 	}
 	defer wsPump.Close()
-	go wsPump.StartPumping()
 
 	err = streamer.Consume(func(logline *api_type.ServiceLogs) error {
-		wsPump.PumpMessage(logline)
-		return nil
+		return wsPump.PumpMessage(logline)
 	})
 
 	if err != nil {
 		logrus.WithError(err).WithFields(logrus.Fields{
-			"stacktrace": fmt.Sprintf("%+v", err),
-			"services":   streamer.GetRequestedServiceUuids(),
+			"services": streamer.GetRequestedServiceUuids(),
 		}).Error("Failed to stream all data")
-		streamingErr := api_type.ResponseInfo{
-			Type:    api_type.ERROR,
-			Message: "Log streaming failed while sending the data",
-			Code:    http.StatusInternalServerError,
-		}
-		wsPump.PumpResponseInfo(&streamingErr)
 	}
 }
 
diff --git a/engine/server/engine/streaming/websocket_pump.go b/engine/server/engine/streaming/websocket_pump.go
index ad5fa703f5..103ed4f4d0 100644
--- a/engine/server/engine/streaming/websocket_pump.go
+++ b/engine/server/engine/streaming/websocket_pump.go
@@ -27,11 +27,13 @@ const (
 )
 
 type WebsocketPump[T interface{}] struct {
-	websocket  *websocket.Conn
-	inputChan  chan *T
-	infoChan   chan *api_type.ResponseInfo
-	ctx        context.Context
-	cancelFunc context.CancelFunc
+	websocket       *websocket.Conn
+	inputChan       chan *T
+	infoChan        chan *api_type.ResponseInfo
+	ctx             context.Context
+	cancelFunc      context.CancelFunc
+	closed          bool
+	connectionError *error
 }
 
 func NewWebsocketPump[T interface{}](ctx echo.Context, cors cors.Cors) (*WebsocketPump[T], error) {
@@ -49,22 +51,37 @@ func NewWebsocketPump[T interface{}](ctx echo.Context, cors cors.Cors) (*Websock
 
 	ctxWithCancel, cancelFunc := context.WithCancel(context.Background())
 
-	return &WebsocketPump[T]{
+	pump := &WebsocketPump[T]{
 		websocket:  conn,
 		inputChan:  make(chan *T),
 		infoChan:   make(chan *api_type.ResponseInfo),
 		ctx:        ctxWithCancel,
 		cancelFunc: cancelFunc,
-	}, nil
+		closed:     false,
+	}
+
+	go pump.startPumping()
+
+	return pump, nil
 }
 
-func (pump WebsocketPump[T]) StartPumping() {
+func (pump *WebsocketPump[T]) readLoop() {
+	for {
+		_, _, err := pump.websocket.ReadMessage()
+		if err != nil {
+			break
+		}
+	}
+}
+
+func (pump *WebsocketPump[T]) startPumping() {
 	ticker := time.NewTicker(pingPeriod)
 	defer func() {
 		ticker.Stop()
 		pump.websocket.Close()
 		close(pump.inputChan)
 		close(pump.infoChan)
+		pump.closed = true
 	}()
 
 	logrus.WithFields(logrus.Fields{
@@ -76,47 +93,95 @@ func (pump WebsocketPump[T]) StartPumping() {
 	pump.websocket.SetReadLimit(maxMessageSize)
 	if err := pump.websocket.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
 		logrus.WithError(err).Error("Failed to set Pong wait time")
+		pump.connectionError = &err
 		return
 	}
 	// nolint:errcheck
-	pump.websocket.SetPongHandler(func(string) error { return pump.websocket.SetReadDeadline(time.Now().Add(pongWait)) })
+	pump.websocket.SetPongHandler(func(string) error {
+		logrus.Debug("Client is connected, got pong")
+		return pump.websocket.SetReadDeadline(time.Now().Add(pongWait))
+	})
+
+	pump.websocket.SetCloseHandler(func(code int, text string) error {
+		logrus.Infof("Websocket connection closed by the client - code: %d, msg: %s", code, text)
+		pump.cancelFunc()
+		return nil
+	})
+
+	// The read callbacks (handlers) are triggered from the ReadMessage calls, so
+	// we also need a dummy reader loop.
+	go pump.readLoop()
 
 	for {
 		select {
 		case <-ticker.C:
 			if err := pump.websocket.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
-				logrus.Debug("Websocket connection is likely closed, exiting keep alive process")
+				logrus.Debug("Websocket connection did not meet the write deadline")
+				pump.connectionError = &err
 				return
 			}
 			if err := pump.websocket.WriteMessage(websocket.PingMessage, nil); err != nil {
 				logrus.Debug("Websocket connection is likely closed, exiting keep alive process")
+				pump.connectionError = &err
 				return
 			}
 		case msg := <-pump.inputChan:
+			if err := pump.websocket.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
+				logrus.Debug("Websocket connection did not meet the write deadline")
+				pump.connectionError = &err
+				return
+			}
 			if err := pump.websocket.WriteJSON(msg); err != nil {
 				logrus.WithError(stacktrace.Propagate(err, "Failed to send value of type `%T` via websocket", msg)).Errorf("Failed to write message to websocket, closing it.")
+				pump.connectionError = &err
 				return
 			}
 		case msg := <-pump.infoChan:
+			if err := pump.websocket.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
+				logrus.Debug("Websocket connection did not meet the write deadline")
+				pump.connectionError = &err
+				return
+			}
 			if err := pump.websocket.WriteJSON(msg); err != nil {
 				logrus.WithError(stacktrace.Propagate(err, "Failed to send value of type `%T` via websocket", msg)).Errorf("Failed to write message to websocket, closing it.")
+				pump.connectionError = &err
 				return
 			}
 		case <-pump.ctx.Done():
-			logrus.Debug("Websocket pumper has been asked to close, closing it.")
+			logrus.Debug("Websocket pump has been asked to close, closing it.")
 			return
 		}
 	}
 }
 
-func (pump *WebsocketPump[T]) PumpResponseInfo(msg *api_type.ResponseInfo) {
+func (pump *WebsocketPump[T]) PumpResponseInfo(msg *api_type.ResponseInfo) error {
+	if pump.closed {
+		if pump.connectionError != nil {
+			return stacktrace.Propagate(*pump.connectionError, "Websocket has been closed due connection error")
+
+		}
+		return nil
+	}
 	pump.infoChan <- msg
+	return nil
 }
 
-func (pump *WebsocketPump[T]) PumpMessage(msg *T) {
+func (pump *WebsocketPump[T]) PumpMessage(msg *T) error {
+	if pump.closed {
+		if pump.connectionError != nil {
+			return stacktrace.Propagate(*pump.connectionError, "Websocket has been closed due connection error")
+
+		}
+		return nil
+	}
 	pump.inputChan <- msg
+	return nil
 }
 
 func (pump *WebsocketPump[T]) Close() {
 	pump.cancelFunc()
 }
+
+func (pump *WebsocketPump[T]) IsClosed() (bool, *error) {
+	return pump.closed, pump.connectionError
+}