From 6d9a4c3fbc173f6ba5e5c20566e991d77f8562ce Mon Sep 17 00:00:00 2001 From: Damian Czaja Date: Fri, 27 Sep 2024 16:27:08 +0200 Subject: [PATCH] gRCP improvments --- internal/proxy/client.go | 61 +++++++++++++++++------------------ internal/proxy/client_test.go | 20 ++++-------- 2 files changed, 36 insertions(+), 45 deletions(-) diff --git a/internal/proxy/client.go b/internal/proxy/client.go index 41f6669..c761edc 100644 --- a/internal/proxy/client.go +++ b/internal/proxy/client.go @@ -82,7 +82,7 @@ func (c *Client) Run(ctx context.Context) error { continue } - err = c.run(authCtx, stream, closeStream) + err = c.run(stream, closeStream) if err != nil { c.log.Errorf("Restarting proxy client in %vs: due to error: %v", time.Duration(c.keepAlive.Load()).Seconds(), err) t.Reset(time.Duration(c.keepAlive.Load())) @@ -133,7 +133,7 @@ func (c *Client) sendInitialRequest(stream cloudproxyv1alpha.CloudProxyAPI_Strea return nil } -func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient, closeStream func()) error { +func (c *Client) run(stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient, closeStream func()) error { defer closeStream() err := c.sendInitialRequest(stream) @@ -141,18 +141,13 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI return fmt.Errorf("c.Connect: %w", err) } - keepAliveCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest) - defer close(keepAliveCh) - go c.sendKeepAlive(stream, keepAliveCh) - - messageRespCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest) - defer close(messageRespCh) + sendCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest, 10) + go c.sendKeepAlive(stream, sendCh) + defer close(sendCh) go func() { for { select { - case <-ctx.Done(): - return case <-stream.Context().Done(): return default: @@ -162,7 +157,6 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI } c.log.Debugf("Polling stream for messages") - in, err := stream.Recv() if err != nil { c.log.Errorf("stream.Recv: got error: %v", err) @@ -172,24 +166,21 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI } c.log.Debugf("Handling message from castai") - go c.handleMessage(in, messageRespCh) + go c.handleMessage(stream.Context(), in, sendCh) } }() for { select { - case <-ctx.Done(): - return ctx.Err() case <-stream.Context().Done(): return fmt.Errorf("stream closed %w", stream.Context().Err()) - case req := <-keepAliveCh: - if err := stream.Send(req); err != nil { - c.log.WithError(err).Warn("failed to send keep alive") - } - case req := <-messageRespCh: + + case req := <-sendCh: + c.log.Debugf("Sending message to stream") if err := stream.Send(req); err != nil { - c.log.WithError(err).Warn("failed to send message response") + return fmt.Errorf("failed to send gRPC message: %w", err) } + case <-time.After(time.Duration(c.keepAlive.Load())): if !c.isAlive() { if err := c.lastSeenError.Load(); err != nil { @@ -201,16 +192,17 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI } } -func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, respCh chan<- *cloudproxyv1alpha.StreamCloudProxyRequest) { +func (c *Client) handleMessage(ctx context.Context, in *cloudproxyv1alpha.StreamCloudProxyResponse, respCh chan<- *cloudproxyv1alpha.StreamCloudProxyRequest) { if in == nil { c.log.Error("nil message") return } + + c.lastSeen.Store(time.Now().UnixNano()) c.processConfigurationRequest(in) // skip processing http request if keep alive message. if in.GetMessageId() == KeepAliveMessageID { - c.lastSeen.Store(time.Now().UnixNano()) c.log.Debugf("Received keep-alive message from castai for %s", in.GetClientMetadata().GetClusterId()) return } @@ -222,7 +214,11 @@ func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, r } else { c.log.Debugf("Proxied request msg_id=%v, sending response to castai", in.GetMessageId()) } - respCh <- &cloudproxyv1alpha.StreamCloudProxyRequest{ + + select { + case <-ctx.Done(): + return + case respCh <- &cloudproxyv1alpha.StreamCloudProxyRequest{ Request: &cloudproxyv1alpha.StreamCloudProxyRequest_Response{ Response: &cloudproxyv1alpha.ClusterResponse{ ClientMetadata: &cloudproxyv1alpha.ClientMetadata{ @@ -233,17 +229,21 @@ func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, r HttpResponse: resp, }, }, + }: + return } } func (c *Client) processConfigurationRequest(in *cloudproxyv1alpha.StreamCloudProxyResponse) { - if in.ConfigurationRequest != nil { - if in.ConfigurationRequest.GetKeepAlive() != 0 { - c.keepAlive.Store(in.ConfigurationRequest.GetKeepAlive()) - } - if in.ConfigurationRequest.GetKeepAliveTimeout() != 0 { - c.keepAliveTimeout.Store(in.ConfigurationRequest.GetKeepAliveTimeout()) - } + if in.ConfigurationRequest == nil { + return + } + + if in.ConfigurationRequest.GetKeepAlive() != 0 { + c.keepAlive.Store(in.ConfigurationRequest.GetKeepAlive()) + } + if in.ConfigurationRequest.GetKeepAliveTimeout() != 0 { + c.keepAliveTimeout.Store(in.ConfigurationRequest.GetKeepAliveTimeout()) } c.log.Debugf("Updated keep-alive configuration to %v and keep-alive timeout to %v", c.keepAlive.Load(), c.keepAliveTimeout.Load()) } @@ -274,7 +274,6 @@ func (c *Client) processHTTPRequest(req *cloudproxyv1alpha.HTTPRequest) *cloudpr func (c *Client) isAlive() bool { lastSeen := c.lastSeen.Load() - return time.Now().UnixNano()-lastSeen <= c.keepAliveTimeout.Load() } diff --git a/internal/proxy/client_test.go b/internal/proxy/client_test.go index 5635c6d..7a5d94e 100644 --- a/internal/proxy/client_test.go +++ b/internal/proxy/client_test.go @@ -266,7 +266,7 @@ func TestClient_handleMessage(t *testing.T) { <-msgStream }() - c.handleMessage(tt.args.in, msgStream) + c.handleMessage(context.Background(), tt.args.in, msgStream) require.Equal(t, tt.wantLastSeenUpdated, c.lastSeen.Load() > 0, "lastSeen: %v", c.lastSeen.Load()) require.Equal(t, tt.wantKeepAlive, c.keepAlive.Load(), "keepAlive: %v", c.keepAlive.Load()) require.Equal(t, tt.wantKeepAliveTimeout, c.keepAliveTimeout.Load(), "keepAliveTimeout: %v", c.keepAliveTimeout.Load()) @@ -438,9 +438,6 @@ func TestClient_run(t *testing.T) { { name: "send initial error", args: args{ - ctx: func() context.Context { - return context.Background() - }, tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) { m.EXPECT().Send(gomock.Any()).Return(fmt.Errorf("test error")) }, @@ -450,14 +447,12 @@ func TestClient_run(t *testing.T) { { name: "context done", args: args{ - ctx: func() context.Context { + tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) { ctx, cancel := context.WithCancel(context.Background()) cancel() - return ctx - }, - tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) { - m.EXPECT().Send(gomock.Any()).Return(nil).AnyTimes() // expected 0 or 1 times. - m.EXPECT().Context().Return(context.Background()).AnyTimes() // expected 0 or 1 times. + + m.EXPECT().Send(gomock.Any()).Return(nil).AnyTimes() // expected 0 or 1 times. + m.EXPECT().Context().Return(ctx).AnyTimes() // expected 0 or 1 times. }, }, wantLastSeenUpdated: true, @@ -466,9 +461,6 @@ func TestClient_run(t *testing.T) { { name: "stream not alive", args: args{ - ctx: func() context.Context { - return context.Background() - }, tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) { m.EXPECT().Send(gomock.Any()).Return(nil).AnyTimes() // expected 0 or 1 times. m.EXPECT().Context().Return(context.Background()).AnyTimes() // expected 0 or 1 times. @@ -491,7 +483,7 @@ func TestClient_run(t *testing.T) { if tt.args.tuneMockStream != nil { tt.args.tuneMockStream(stream) } - if err := c.run(tt.args.ctx(), stream, func() {}); (err != nil) != tt.wantErr { + if err := c.run(stream, func() {}); (err != nil) != tt.wantErr { t.Errorf("run() error = %v, wantErr %v", err, tt.wantErr) } require.Equal(t, tt.wantLastSeenUpdated, c.lastSeen.Load() > 0, "lastSeen: %v", c.lastSeen.Load())