Skip to content

Commit

Permalink
ensure Send to stream is done only from 1 goroutine
Browse files Browse the repository at this point in the history
  • Loading branch information
Trojan295 committed Sep 23, 2024
1 parent 0ea1e47 commit b46d150
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
8 changes: 2 additions & 6 deletions cmd/proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,11 @@ func main() {
}(conn)

client := proxy.New(conn, gcp.New(tokenSource), logger,
cfg.GetPodName(), cfg.ClusterID, GetVersion(), cfg.KeepAlive, cfg.KeepAliveTimeout)
cfg.GetPodName(), cfg.ClusterID, GetVersion(), cfg.CastAI.APIKey, cfg.KeepAlive, cfg.KeepAliveTimeout)

go startHealthServer(logger, cfg.HealthAddress)

proxyCtx := metadata.NewOutgoingContext(ctx, metadata.Pairs(
"authorization", fmt.Sprintf("Token %s", cfg.CastAI.APIKey),
))

err = client.Run(proxyCtx)
err = client.Run(ctx)
if err != nil {
logger.Panicf("Failed to run client: %v", err)
panic(err)
Expand Down
52 changes: 33 additions & 19 deletions internal/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/samber/lo"
"github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/metadata"

cloudproxyv1alpha "cloud-proxy/proto/gen/proto/v1alpha"
)
Expand All @@ -29,6 +30,7 @@ type CloudClient interface {

type Client struct {
grpcConn *grpc.ClientConn
apiKey string
cloudClient CloudClient
log *logrus.Logger
podName string
Expand All @@ -44,9 +46,10 @@ type Client struct {
version string
}

func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logger, podName, clusterID, version string, keepalive, keepaliveTimeout time.Duration) *Client {
func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logger, podName, clusterID, version, apiKey string, keepalive, keepaliveTimeout time.Duration) *Client {
c := &Client{
grpcConn: grpcConn,
apiKey: apiKey,
cloudClient: cloudClient,
log: logger,
podName: podName,
Expand All @@ -60,22 +63,26 @@ func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logg
}

func (c *Client) Run(ctx context.Context) error {
authCtx := metadata.NewOutgoingContext(ctx, metadata.Pairs(
"authorization", fmt.Sprintf("Token %s", c.apiKey),
))

t := time.NewTimer(time.Millisecond)

for {
select {
case <-ctx.Done():
return ctx.Err()
case <-authCtx.Done():
return authCtx.Err()
case <-t.C:
c.log.Info("Starting proxy client")
stream, closeStream, err := c.getStream(ctx)
stream, closeStream, err := c.getStream(authCtx)
if err != nil {
c.log.Errorf("Could not get stream, restarting proxy client in %vs: %v", time.Duration(c.keepAlive.Load()).Seconds(), err)
t.Reset(time.Duration(c.keepAlive.Load()))
continue
}

err = c.run(ctx, stream, closeStream)
err = c.run(authCtx, stream, closeStream)
if err != nil {
c.log.Errorf("Restarting proxy client in %vs: due to error: %v", time.Duration(c.keepAlive.Load()).Seconds(), err)
t.Reset(time.Duration(c.keepAlive.Load()))
Expand Down Expand Up @@ -134,7 +141,12 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
return fmt.Errorf("c.Connect: %w", err)
}

go c.sendKeepAlive(stream)
keepAliveCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest)
defer close(keepAliveCh)
go c.sendKeepAlive(stream, keepAliveCh)

messageRespCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest)
defer close(messageRespCh)

go func() {
for {
Expand All @@ -160,7 +172,7 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
}

c.log.Debugf("Handling message from castai")
go c.handleMessage(in, stream)
go c.handleMessage(in, messageRespCh)
}
}()

Expand All @@ -170,6 +182,14 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
return ctx.Err()
case <-stream.Context().Done():
return fmt.Errorf("stream closed %w", stream.Context().Err())
case req := <-keepAliveCh:
if err := stream.Send(req); err != nil {
c.log.WithError(err).Warn("failed to send keep alive")
}
case req := <-messageRespCh:
if err := stream.Send(req); err != nil {
c.log.WithError(err).Warn("failed to send message response")
}
case <-time.After(time.Duration(c.keepAlive.Load())):
if !c.isAlive() {
if err := c.lastSeenError.Load(); err != nil {
Expand All @@ -181,7 +201,7 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
}
}

func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient) {
func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, respCh chan<- *cloudproxyv1alpha.StreamCloudProxyRequest) {
if in == nil {
c.log.Error("nil message")
return
Expand All @@ -202,7 +222,7 @@ func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, s
} else {
c.log.Debugf("Proxied request msg_id=%v, sending response to castai", in.GetMessageId())
}
err := stream.Send(&cloudproxyv1alpha.StreamCloudProxyRequest{
respCh <- &cloudproxyv1alpha.StreamCloudProxyRequest{
Request: &cloudproxyv1alpha.StreamCloudProxyRequest_Response{
Response: &cloudproxyv1alpha.ClusterResponse{
ClientMetadata: &cloudproxyv1alpha.ClientMetadata{
Expand All @@ -213,9 +233,6 @@ func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, s
HttpResponse: resp,
},
},
})
if err != nil {
c.log.Errorf("error sending response for msg_id=%v %v", in.GetMessageId(), err)
}
}

Expand Down Expand Up @@ -261,7 +278,7 @@ func (c *Client) isAlive() bool {
return time.Now().UnixNano()-lastSeen <= c.keepAliveTimeout.Load()
}

func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient) {
func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient, sendCh chan<- *cloudproxyv1alpha.StreamCloudProxyRequest) {
ticker := time.NewTimer(time.Duration(c.keepAlive.Load()))
defer ticker.Stop()

Expand All @@ -277,7 +294,8 @@ func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamClou
return
}
c.log.Debug("Sending keep-alive to castai")
err := stream.Send(&cloudproxyv1alpha.StreamCloudProxyRequest{

sendCh <- &cloudproxyv1alpha.StreamCloudProxyRequest{
Request: &cloudproxyv1alpha.StreamCloudProxyRequest_ClientStats{
ClientStats: &cloudproxyv1alpha.ClientStats{
ClientMetadata: &cloudproxyv1alpha.ClientMetadata{
Expand All @@ -290,12 +308,8 @@ func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamClou
},
},
},
})
if err != nil {
c.lastSeen.Store(0)
c.log.Errorf("error sending keep alive message: %v", err)
return
}

ticker.Reset(time.Duration(c.keepAlive.Load()))
}
}
Expand Down

0 comments on commit b46d150

Please sign in to comment.