Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gRPC improvements #18

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 30 additions & 31 deletions internal/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func (c *Client) Run(ctx context.Context) error {
continue
}

err = c.run(authCtx, stream, closeStream)
err = c.run(stream, closeStream)
if err != nil {
c.log.Errorf("Restarting proxy client in %vs: due to error: %v", time.Duration(c.keepAlive.Load()).Seconds(), err)
t.Reset(time.Duration(c.keepAlive.Load()))
Expand Down Expand Up @@ -133,26 +133,21 @@ func (c *Client) sendInitialRequest(stream cloudproxyv1alpha.CloudProxyAPI_Strea
return nil
}

func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient, closeStream func()) error {
func (c *Client) run(stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient, closeStream func()) error {
defer closeStream()

err := c.sendInitialRequest(stream)
if err != nil {
return fmt.Errorf("c.Connect: %w", err)
}

keepAliveCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest)
defer close(keepAliveCh)
go c.sendKeepAlive(stream, keepAliveCh)

messageRespCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest)
defer close(messageRespCh)
sendCh := make(chan *cloudproxyv1alpha.StreamCloudProxyRequest, 10)
go c.sendKeepAlive(stream, sendCh)
defer close(sendCh)

go func() {
for {
select {
case <-ctx.Done():
return
case <-stream.Context().Done():
return
default:
Expand All @@ -162,7 +157,6 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
}

c.log.Debugf("Polling stream for messages")

in, err := stream.Recv()
if err != nil {
c.log.Errorf("stream.Recv: got error: %v", err)
Expand All @@ -172,24 +166,21 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
}

c.log.Debugf("Handling message from castai")
go c.handleMessage(in, messageRespCh)
go c.handleMessage(stream.Context(), in, sendCh)
}
}()

for {
select {
case <-ctx.Done():
return ctx.Err()
case <-stream.Context().Done():
return fmt.Errorf("stream closed %w", stream.Context().Err())
case req := <-keepAliveCh:
if err := stream.Send(req); err != nil {
c.log.WithError(err).Warn("failed to send keep alive")
}
case req := <-messageRespCh:

case req := <-sendCh:
c.log.Debugf("Sending message to stream")
if err := stream.Send(req); err != nil {
c.log.WithError(err).Warn("failed to send message response")
return fmt.Errorf("failed to send gRPC message: %w", err)
}

case <-time.After(time.Duration(c.keepAlive.Load())):
if !c.isAlive() {
if err := c.lastSeenError.Load(); err != nil {
Expand All @@ -201,16 +192,17 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI
}
}

func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, respCh chan<- *cloudproxyv1alpha.StreamCloudProxyRequest) {
func (c *Client) handleMessage(ctx context.Context, in *cloudproxyv1alpha.StreamCloudProxyResponse, respCh chan<- *cloudproxyv1alpha.StreamCloudProxyRequest) {
if in == nil {
c.log.Error("nil message")
return
}

c.lastSeen.Store(time.Now().UnixNano())
c.processConfigurationRequest(in)

// skip processing http request if keep alive message.
if in.GetMessageId() == KeepAliveMessageID {
c.lastSeen.Store(time.Now().UnixNano())
c.log.Debugf("Received keep-alive message from castai for %s", in.GetClientMetadata().GetClusterId())
return
}
Expand All @@ -222,7 +214,11 @@ func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, r
} else {
c.log.Debugf("Proxied request msg_id=%v, sending response to castai", in.GetMessageId())
}
respCh <- &cloudproxyv1alpha.StreamCloudProxyRequest{

select {
case <-ctx.Done():
return
case respCh <- &cloudproxyv1alpha.StreamCloudProxyRequest{
Request: &cloudproxyv1alpha.StreamCloudProxyRequest_Response{
Response: &cloudproxyv1alpha.ClusterResponse{
ClientMetadata: &cloudproxyv1alpha.ClientMetadata{
Expand All @@ -233,17 +229,21 @@ func (c *Client) handleMessage(in *cloudproxyv1alpha.StreamCloudProxyResponse, r
HttpResponse: resp,
},
},
}:
return
}
}

func (c *Client) processConfigurationRequest(in *cloudproxyv1alpha.StreamCloudProxyResponse) {
if in.ConfigurationRequest != nil {
if in.ConfigurationRequest.GetKeepAlive() != 0 {
c.keepAlive.Store(in.ConfigurationRequest.GetKeepAlive())
}
if in.ConfigurationRequest.GetKeepAliveTimeout() != 0 {
c.keepAliveTimeout.Store(in.ConfigurationRequest.GetKeepAliveTimeout())
}
if in.ConfigurationRequest == nil {
return
}

if in.ConfigurationRequest.GetKeepAlive() != 0 {
c.keepAlive.Store(in.ConfigurationRequest.GetKeepAlive())
}
if in.ConfigurationRequest.GetKeepAliveTimeout() != 0 {
c.keepAliveTimeout.Store(in.ConfigurationRequest.GetKeepAliveTimeout())
}
c.log.Debugf("Updated keep-alive configuration to %v and keep-alive timeout to %v", c.keepAlive.Load(), c.keepAliveTimeout.Load())
}
Expand Down Expand Up @@ -274,7 +274,6 @@ func (c *Client) processHTTPRequest(req *cloudproxyv1alpha.HTTPRequest) *cloudpr

// isAlive reports whether the server has been heard from within the
// configured keep-alive timeout window. Both lastSeen and keepAliveTimeout
// are stored as nanoseconds and read atomically, so this is safe to call
// concurrently with the receive loop.
func (c *Client) isAlive() bool {
	elapsed := time.Now().UnixNano() - c.lastSeen.Load()
	return elapsed <= c.keepAliveTimeout.Load()
}

Expand Down
20 changes: 6 additions & 14 deletions internal/proxy/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@
<-msgStream
}()

c.handleMessage(tt.args.in, msgStream)
c.handleMessage(context.Background(), tt.args.in, msgStream)
require.Equal(t, tt.wantLastSeenUpdated, c.lastSeen.Load() > 0, "lastSeen: %v", c.lastSeen.Load())
require.Equal(t, tt.wantKeepAlive, c.keepAlive.Load(), "keepAlive: %v", c.keepAlive.Load())
require.Equal(t, tt.wantKeepAliveTimeout, c.keepAliveTimeout.Load(), "keepAliveTimeout: %v", c.keepAliveTimeout.Load())
Expand Down Expand Up @@ -426,7 +426,7 @@
t.Parallel()

type args struct {
ctx func() context.Context

Check failure on line 429 in internal/proxy/client_test.go

View workflow job for this annotation

GitHub Actions / lint

field `ctx` is unused (unused)
tuneMockStream func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient)
}
tests := []struct {
Expand All @@ -438,9 +438,6 @@
{
name: "send initial error",
args: args{
ctx: func() context.Context {
return context.Background()
},
tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) {
m.EXPECT().Send(gomock.Any()).Return(fmt.Errorf("test error"))
},
Expand All @@ -450,14 +447,12 @@
{
name: "context done",
args: args{
ctx: func() context.Context {
tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
return ctx
},
tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) {
m.EXPECT().Send(gomock.Any()).Return(nil).AnyTimes() // expected 0 or 1 times.
m.EXPECT().Context().Return(context.Background()).AnyTimes() // expected 0 or 1 times.

m.EXPECT().Send(gomock.Any()).Return(nil).AnyTimes() // expected 0 or 1 times.
m.EXPECT().Context().Return(ctx).AnyTimes() // expected 0 or 1 times.
},
},
wantLastSeenUpdated: true,
Expand All @@ -466,9 +461,6 @@
{
name: "stream not alive",
args: args{
ctx: func() context.Context {
return context.Background()
},
tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) {
m.EXPECT().Send(gomock.Any()).Return(nil).AnyTimes() // expected 0 or 1 times.
m.EXPECT().Context().Return(context.Background()).AnyTimes() // expected 0 or 1 times.
Expand All @@ -491,7 +483,7 @@
if tt.args.tuneMockStream != nil {
tt.args.tuneMockStream(stream)
}
if err := c.run(tt.args.ctx(), stream, func() {}); (err != nil) != tt.wantErr {
if err := c.run(stream, func() {}); (err != nil) != tt.wantErr {
t.Errorf("run() error = %v, wantErr %v", err, tt.wantErr)
}
require.Equal(t, tt.wantLastSeenUpdated, c.lastSeen.Load() > 0, "lastSeen: %v", c.lastSeen.Load())
Expand Down
Loading