From b46d150dac28ace7b1a4f8203e2127bda317ac5d Mon Sep 17 00:00:00 2001 From: Damian Czaja Date: Mon, 23 Sep 2024 17:26:34 +0200 Subject: [PATCH] ensure Send to stream is done only from 1 goroutine --- cmd/proxy/main.go | 8 ++----- internal/proxy/client.go | 52 +++++++++++++++++++++++++--------------- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index f6166ec..cad59cc 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -79,15 +79,11 @@ func main() { }(conn) client := proxy.New(conn, gcp.New(tokenSource), logger, - cfg.GetPodName(), cfg.ClusterID, GetVersion(), cfg.KeepAlive, cfg.KeepAliveTimeout) + cfg.GetPodName(), cfg.ClusterID, GetVersion(), cfg.CastAI.APIKey, cfg.KeepAlive, cfg.KeepAliveTimeout) go startHealthServer(logger, cfg.HealthAddress) - proxyCtx := metadata.NewOutgoingContext(ctx, metadata.Pairs( - "authorization", fmt.Sprintf("Token %s", cfg.CastAI.APIKey), - )) - - err = client.Run(proxyCtx) + err = client.Run(ctx) if err != nil { logger.Panicf("Failed to run client: %v", err) panic(err) diff --git a/internal/proxy/client.go b/internal/proxy/client.go index a3956fc..41f6669 100644 --- a/internal/proxy/client.go +++ b/internal/proxy/client.go @@ -15,6 +15,7 @@ import ( "github.com/samber/lo" "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/metadata" cloudproxyv1alpha "cloud-proxy/proto/gen/proto/v1alpha" ) @@ -29,6 +30,7 @@ type CloudClient interface { type Client struct { grpcConn *grpc.ClientConn + apiKey string cloudClient CloudClient log *logrus.Logger podName string @@ -44,9 +46,10 @@ type Client struct { version string } -func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logger, podName, clusterID, version string, keepalive, keepaliveTimeout time.Duration) *Client { +func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logger, podName, clusterID, version, apiKey string, keepalive, keepaliveTimeout time.Duration) *Client { c := &Client{ grpcConn: grpcConn, + apiKey: apiKey, cloudClient: cloudClient, log: logger, podName: podName, @@ -60,22 +63,26 @@ func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logg } func (c *Client) Run(ctx context.Context) error { + authCtx := metadata.NewOutgoingContext(ctx, metadata.Pairs( + "authorization", fmt.Sprintf("Token %s", c.apiKey), + )) + t := time.NewTimer(time.Millisecond) for { select { - case <-ctx.Done(): - return ctx.Err() + case <-authCtx.Done(): + return authCtx.Err() case <-t.C: c.log.Info("Starting proxy client") - stream, closeStream, err := c.getStream(ctx) + stream, closeStream, err := c.getStream(authCtx) if err != nil { c.log.Errorf("Could not get stream, restarting proxy client in %vs: %v", time.Duration(c.keepAlive.Load()).Seconds(), err) t.Reset(time.Duration(c.keepAlive.Load())) continue } - err = c.run(ctx, stream, closeStream) + err = c.run(authCtx, stream, closeStream) if err != nil { c.log.Errorf("Restarting proxy client in %vs: due to error: %v", time.Duration(c.keepAlive.Load()).Seconds(), err) t.Reset(time.Duration(c.keepAlive.Load())) @@ -134,7 +141,12 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI return fmt.Errorf("c.Connect: %w", err) } - go c.sendKeepAlive(stream) + keepAliveCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest) + defer close(keepAliveCh) + go c.sendKeepAlive(stream, keepAliveCh) + + messageRespCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest) + defer close(messageRespCh) go func() { for { @@ -160,7 +172,7 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI } c.log.Debugf("Handling message from castai") - go c.handleMessage(in, stream) + go c.handleMessage(in, messageRespCh) } }() @@ -170,6 +182,14 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI return ctx.Err() case <-stream.Context().Done(): return fmt.Errorf("stream closed %w", stream.Context().Err()) + case req := <-keepAliveCh: + if err := stream.Send(req); err != nil { + c.log.WithError(err).Warn("failed to send keep alive") + } + case req := <-messageRespCh: + if err := stream.Send(req); err != nil { + c.log.WithError(err).Warn("failed to send message response") + } case <-time.After(time.Duration(c.keepAlive.Load())): if !c.isAlive() { if err := c.lastSeenError.Load(); err != nil { @@ -181,7 +201,7 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI } } -func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient) { +func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, respCh chan<- *cloudproxyv1alpha.StreamCloudProxyRequest) { if in == nil { c.log.Error("nil message") return @@ -202,7 +222,7 @@ func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, s } else { c.log.Debugf("Proxied request msg_id=%v, sending response to castai", in.GetMessageId()) } - err := stream.Send(&cloudproxyv1alpha.StreamCloudProxyRequest{ + respCh <- &cloudproxyv1alpha.StreamCloudProxyRequest{ Request: &cloudproxyv1alpha.StreamCloudProxyRequest_Response{ Response: &cloudproxyv1alpha.ClusterResponse{ ClientMetadata: &cloudproxyv1alpha.ClientMetadata{ @@ -213,9 +233,6 @@ func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, s HttpResponse: resp, }, }, - }) - if err != nil { - c.log.Errorf("error sending response for msg_id=%v %v", in.GetMessageId(), err) } } @@ -261,7 +278,7 @@ func (c *Client) isAlive() bool { return time.Now().UnixNano()-lastSeen <= c.keepAliveTimeout.Load() } -func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient) { +func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient, sendCh chan<- *cloudproxyv1alpha.StreamCloudProxyRequest) { ticker := time.NewTimer(time.Duration(c.keepAlive.Load())) defer ticker.Stop() @@ -277,7 +294,8 @@ func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamClou return } c.log.Debug("Sending keep-alive to castai") - err := stream.Send(&cloudproxyv1alpha.StreamCloudProxyRequest{ + + sendCh <- &cloudproxyv1alpha.StreamCloudProxyRequest{ Request: &cloudproxyv1alpha.StreamCloudProxyRequest_ClientStats{ ClientStats: &cloudproxyv1alpha.ClientStats{ ClientMetadata: &cloudproxyv1alpha.ClientMetadata{ @@ -290,12 +308,8 @@ func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamClou }, }, }, - }) - if err != nil { - c.lastSeen.Store(0) - c.log.Errorf("error sending keep alive message: %v", err) - return } + ticker.Reset(time.Duration(c.keepAlive.Load())) } }