Skip to content

Commit

Permalink
gRCP improvments
Browse files Browse the repository at this point in the history
  • Loading branch information
Trojan295 committed Sep 27, 2024
1 parent 84a1cde commit 6d9a4c3
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 45 deletions.
61 changes: 30 additions & 31 deletions internal/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (c *Client) Run(ctx context.Context) error {
continue
}

err = c.run(authCtx, stream, closeStream)
err = c.run(stream, closeStream)
if err != nil {
c.log.Errorf("Restarting proxy client in %vs: due to error: %v", time.Duration(c.keepAlive.Load()).Seconds(), err)
t.Reset(time.Duration(c.keepAlive.Load()))
Expand Down Expand Up @@ -133,26 +133,21 @@ func (c *Client) sendInitialRequest(stream cloudproxyv1alpha.CloudProxyAPI_Strea
return nil
}

func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient, closeStream func()) error {
func (c *Client) run(stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient, closeStream func()) error {
defer closeStream()

err := c.sendInitialRequest(stream)
if err != nil {
return fmt.Errorf("c.Connect: %w", err)
}

keepAliveCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest)
defer close(keepAliveCh)
go c.sendKeepAlive(stream, keepAliveCh)

messageRespCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest)
defer close(messageRespCh)
sendCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest, 10)
go c.sendKeepAlive(stream, sendCh)
defer close(sendCh)

go func() {
for {
select {
case <-ctx.Done():
return
case <-stream.Context().Done():
return
default:
Expand All @@ -162,7 +157,6 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
}

c.log.Debugf("Polling stream for messages")

in, err := stream.Recv()
if err != nil {
c.log.Errorf("stream.Recv: got error: %v", err)
Expand All @@ -172,24 +166,21 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
}

c.log.Debugf("Handling message from castai")
go c.handleMessage(in, messageRespCh)
go c.handleMessage(stream.Context(), in, sendCh)
}
}()

for {
select {
case <-ctx.Done():
return ctx.Err()
case <-stream.Context().Done():
return fmt.Errorf("stream closed %w", stream.Context().Err())
case req := <-keepAliveCh:
if err := stream.Send(req); err != nil {
c.log.WithError(err).Warn("failed to send keep alive")
}
case req := <-messageRespCh:

case req := <-sendCh:
c.log.Debugf("Sending message to stream")
if err := stream.Send(req); err != nil {
c.log.WithError(err).Warn("failed to send message response")
return fmt.Errorf("failed to send gRPC message: %w", err)
}

case <-time.After(time.Duration(c.keepAlive.Load())):
if !c.isAlive() {
if err := c.lastSeenError.Load(); err != nil {
Expand All @@ -201,16 +192,17 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
}
}

func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, respCh chan<- *cloudproxyv1alpha.StreamCloudProxyRequest) {
func (c *Client) handleMessage(ctx context.Context, in *cloudproxyv1alpha.StreamCloudProxyResponse, respCh chan<- *cloudproxyv1alpha.StreamCloudProxyRequest) {
if in == nil {
c.log.Error("nil message")
return
}

c.lastSeen.Store(time.Now().UnixNano())
c.processConfigurationRequest(in)

// skip processing http request if keep alive message.
if in.GetMessageId() == KeepAliveMessageID {
c.lastSeen.Store(time.Now().UnixNano())
c.log.Debugf("Received keep-alive message from castai for %s", in.GetClientMetadata().GetClusterId())
return
}
Expand All @@ -222,7 +214,11 @@ func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, r
} else {
c.log.Debugf("Proxied request msg_id=%v, sending response to castai", in.GetMessageId())
}
respCh <- &cloudproxyv1alpha.StreamCloudProxyRequest{

select {
case <-ctx.Done():
return
case respCh <- &cloudproxyv1alpha.StreamCloudProxyRequest{
Request: &cloudproxyv1alpha.StreamCloudProxyRequest_Response{
Response: &cloudproxyv1alpha.ClusterResponse{
ClientMetadata: &cloudproxyv1alpha.ClientMetadata{
Expand All @@ -233,17 +229,21 @@ func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, r
HttpResponse: resp,
},
},
}:
return
}
}

func (c *Client) processConfigurationRequest(in *cloudproxyv1alpha.StreamCloudProxyResponse) {
if in.ConfigurationRequest != nil {
if in.ConfigurationRequest.GetKeepAlive() != 0 {
c.keepAlive.Store(in.ConfigurationRequest.GetKeepAlive())
}
if in.ConfigurationRequest.GetKeepAliveTimeout() != 0 {
c.keepAliveTimeout.Store(in.ConfigurationRequest.GetKeepAliveTimeout())
}
if in.ConfigurationRequest == nil {
return
}

if in.ConfigurationRequest.GetKeepAlive() != 0 {
c.keepAlive.Store(in.ConfigurationRequest.GetKeepAlive())
}
if in.ConfigurationRequest.GetKeepAliveTimeout() != 0 {
c.keepAliveTimeout.Store(in.ConfigurationRequest.GetKeepAliveTimeout())
}
c.log.Debugf("Updated keep-alive configuration to %v and keep-alive timeout to %v", c.keepAlive.Load(), c.keepAliveTimeout.Load())
}
Expand Down Expand Up @@ -274,7 +274,6 @@ func (c *Client) processHTTPRequest(req *cloudproxyv1alpha.HTTPRequest) *cloudpr

func (c *Client) isAlive() bool {
lastSeen := c.lastSeen.Load()

return time.Now().UnixNano()-lastSeen <= c.keepAliveTimeout.Load()
}

Expand Down
20 changes: 6 additions & 14 deletions internal/proxy/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ func TestClient_handleMessage(t *testing.T) {
<-msgStream
}()

c.handleMessage(tt.args.in, msgStream)
c.handleMessage(context.Background(), tt.args.in, msgStream)
require.Equal(t, tt.wantLastSeenUpdated, c.lastSeen.Load() > 0, "lastSeen: %v", c.lastSeen.Load())
require.Equal(t, tt.wantKeepAlive, c.keepAlive.Load(), "keepAlive: %v", c.keepAlive.Load())
require.Equal(t, tt.wantKeepAliveTimeout, c.keepAliveTimeout.Load(), "keepAliveTimeout: %v", c.keepAliveTimeout.Load())
Expand Down Expand Up @@ -438,9 +438,6 @@ func TestClient_run(t *testing.T) {
{
name: "send initial error",
args: args{
ctx: func() context.Context {
return context.Background()
},
tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) {
m.EXPECT().Send(gomock.Any()).Return(fmt.Errorf("test error"))
},
Expand All @@ -450,14 +447,12 @@ func TestClient_run(t *testing.T) {
{
name: "context done",
args: args{
ctx: func() context.Context {
tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
return ctx
},
tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) {
m.EXPECT().Send(gomock.Any()).Return(nil).AnyTimes() // expected 0 or 1 times.
m.EXPECT().Context().Return(context.Background()).AnyTimes() // expected 0 or 1 times.

m.EXPECT().Send(gomock.Any()).Return(nil).AnyTimes() // expected 0 or 1 times.
m.EXPECT().Context().Return(ctx).AnyTimes() // expected 0 or 1 times.
},
},
wantLastSeenUpdated: true,
Expand All @@ -466,9 +461,6 @@ func TestClient_run(t *testing.T) {
{
name: "stream not alive",
args: args{
ctx: func() context.Context {
return context.Background()
},
tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) {
m.EXPECT().Send(gomock.Any()).Return(nil).AnyTimes() // expected 0 or 1 times.
m.EXPECT().Context().Return(context.Background()).AnyTimes() // expected 0 or 1 times.
Expand All @@ -491,7 +483,7 @@ func TestClient_run(t *testing.T) {
if tt.args.tuneMockStream != nil {
tt.args.tuneMockStream(stream)
}
if err := c.run(tt.args.ctx(), stream, func() {}); (err != nil) != tt.wantErr {
if err := c.run(stream, func() {}); (err != nil) != tt.wantErr {
t.Errorf("run() error = %v, wantErr %v", err, tt.wantErr)
}
require.Equal(t, tt.wantLastSeenUpdated, c.lastSeen.Load() > 0, "lastSeen: %v", c.lastSeen.Load())
Expand Down

0 comments on commit 6d9a4c3

Please sign in to comment.