diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index a997b58..ae2de10 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -53,8 +53,9 @@ func main() { MaxDelay: 5 * time.Second, Multiplier: 1.2, }, + MinConnectTimeout: 10 * time.Second, } - dialOpts = append(dialOpts, grpc.WithConnectParams(connectParams)) + dialOpts = append(dialOpts, grpc.WithConnectParams(connectParams), grpc.WithIdleTimeout(cfg.KeepAliveTimeout)) logger.Infof( "Creating grpc channel against (%s) with connection config (%+v) and TLS enabled=%v", diff --git a/internal/proxy/client.go b/internal/proxy/client.go index 41f6669..988a60f 100644 --- a/internal/proxy/client.go +++ b/internal/proxy/client.go @@ -75,14 +75,7 @@ func (c *Client) Run(ctx context.Context) error { return authCtx.Err() case <-t.C: c.log.Info("Starting proxy client") - stream, closeStream, err := c.getStream(authCtx) - if err != nil { - c.log.Errorf("Could not get stream, restarting proxy client in %vs: %v", time.Duration(c.keepAlive.Load()).Seconds(), err) - t.Reset(time.Duration(c.keepAlive.Load())) - continue - } - - err = c.run(authCtx, stream, closeStream) + err := c.run(authCtx) if err != nil { c.log.Errorf("Restarting proxy client in %vs: due to error: %v", time.Duration(c.keepAlive.Load()).Seconds(), err) t.Reset(time.Duration(c.keepAlive.Load())) @@ -133,10 +126,17 @@ func (c *Client) sendInitialRequest(stream cloudproxyv1alpha.CloudProxyAPI_Strea return nil } -func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient, closeStream func()) error { +func (c *Client) run(ctx context.Context) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + stream, closeStream, err := c.getStream(ctx) + if err != nil { + return fmt.Errorf("c.getStream: %w", err) + } defer closeStream() - err := c.sendInitialRequest(stream) + err = c.sendInitialRequest(stream) if err != nil { return fmt.Errorf("c.Connect: %w", err) } @@ -189,6 +189,7 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI case req := <-messageRespCh: if err := stream.Send(req); err != nil { c.log.WithError(err).Warn("failed to send message response") + return fmt.Errorf("stream.Send: %w", err) } case <-time.After(time.Duration(c.keepAlive.Load())): if !c.isAlive() {