diff --git a/.gitignore b/.gitignore index 8694b17..b02a102 100644 --- a/.gitignore +++ b/.gitignore @@ -24,7 +24,7 @@ go.work go.work.sum # env file -.env +*.env # goland .idea diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 8a63ee6..d4dab82 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -1,8 +1,6 @@ package main import ( - "cloud-proxy/internal/cloud/gcp" - "cloud-proxy/internal/cloud/gcp/gcpauth" "context" "fmt" "net/http" @@ -10,15 +8,17 @@ import ( "runtime" "time" + "cloud-proxy/internal/cloud/gcp" + "cloud-proxy/internal/cloud/gcp/gcpauth" + "cloud-proxy/internal/config" + "cloud-proxy/internal/proxy" + + "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/backoff" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" - - "cloud-proxy/internal/config" - "cloud-proxy/internal/proxy" - "github.com/sirupsen/logrus" ) var ( @@ -30,22 +30,7 @@ var ( func main() { logrus.Info("Starting proxy") cfg := config.Get() - - logger := logrus.New() - logger.SetLevel(logrus.Level(cfg.Log.Level)) - logger.SetReportCaller(true) - logger.Formatter = &logrus.TextFormatter{ - CallerPrettyfier: func(f *runtime.Frame) (function string, file string) { - filename := path.Base(f.File) - return fmt.Sprintf("%s()", f.Function), fmt.Sprintf("%s:%d", filename, f.Line) - }, - } - - logger.WithFields(logrus.Fields{ - "GitCommit": GitCommit, - "GitRef": GitRef, - "Version": Version, - }).Info("Starting cloud-proxy") + logger := setupLogger(cfg) dialOpts := make([]grpc.DialOption, 0) if cfg.CastAI.DisableGRPCTLS { @@ -66,7 +51,7 @@ func main() { dialOpts = append(dialOpts, grpc.WithConnectParams(connectParams)) logger.Infof( - "Creating grpc channel against (%s) with connection config (%v) and TLS enabled=%v", + "Creating grpc channel against (%s) with connection config (%+v) and TLS enabled=%v", cfg.CastAI.GrpcURL, connectParams, !cfg.CastAI.DisableGRPCTLS, @@ -90,7 +75,8 @@ func main() { "authorization", fmt.Sprintf("Token %s", cfg.CastAI.ApiKey), )) - client := proxy.New(conn, gcp.New(gcpauth.NewCredentialsSource(), http.DefaultClient), logger, cfg.ClusterID, GetVersion()) + client := proxy.New(conn, gcp.New(gcpauth.NewCredentialsSource(), http.DefaultClient), logger, + cfg.ClusterID, GetVersion(), cfg.KeepAlive, cfg.KeepAliveTimeout) err = client.Run(ctx) if err != nil { logger.Panicf("Failed to run client: %v", err) @@ -101,3 +87,25 @@ func main() { func GetVersion() string { return fmt.Sprintf("GitCommit=%q GitRef=%q Version=%q", GitCommit, GitRef, Version) } + +func setupLogger(cfg config.Config) *logrus.Logger { + logger := logrus.New() + logger.SetLevel(logrus.Level(cfg.Log.Level)) + logger.SetReportCaller(true) + logger.Formatter = &logrus.TextFormatter{ + CallerPrettyfier: func(f *runtime.Frame) (function string, file string) { + filename := path.Base(f.File) + return fmt.Sprintf("%s()", f.Function), fmt.Sprintf("%s:%d", filename, f.Line) + }, + TimestampFormat: time.RFC3339, + FullTimestamp: true, + } + + logger.WithFields(logrus.Fields{ + "GitCommit": GitCommit, + "GitRef": GitRef, + "Version": Version, + }).Infof("Starting cloud-proxy: %+v", cfg) + + return logger +} diff --git a/go.mod b/go.mod index ffe942a..640be46 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( google.golang.org/api v0.197.0 google.golang.org/grpc v1.66.2 google.golang.org/protobuf v1.34.2 + sigs.k8s.io/controller-runtime v0.19.0 ) require ( @@ -49,8 +50,7 @@ require ( go.opentelemetry.io/otel v1.29.0 // indirect go.opentelemetry.io/otel/metric v1.29.0 // indirect go.opentelemetry.io/otel/trace v1.29.0 // indirect - go.uber.org/atomic v1.9.0 // indirect - go.uber.org/multierr v1.9.0 // indirect + go.uber.org/multierr v1.11.0 // indirect golang.org/x/crypto v0.27.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/net v0.29.0 // indirect diff --git a/go.sum b/go.sum index 5211c9f..f621643 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= @@ -60,6 +62,8 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af h1:kmjWCqn2qkEml422C2Rrd27c3VGxi6a/6HNq8QmHRKM= +github.com/google/pprof v0.0.0-20240525223248-4bfdf5a9a2af/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo= github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM= github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -79,6 +83,10 @@ github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0V github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/onsi/ginkgo/v2 v2.19.0 h1:9Cnnf7UHo57Hy3k6/m5k3dRfGTMXGvxhHFvkDTCTpvA= +github.com/onsi/ginkgo/v2 v2.19.0/go.mod h1:rlwLi9PilAFJ8jCg9UE1QP6VBpd6/xj3SRC0d6TU0To= +github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk= +github.com/onsi/gomega v1.33.1/go.mod h1:U4R44UsT+9eLIaYRB2a5qajjtQYn0hauxvRm16AVYg0= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -109,7 +117,6 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -134,10 +141,8 @@ go.opentelemetry.io/otel/sdk v1.28.0 h1:b9d7hIry8yZsgtbmM0DKyPWMMUMlK9NEKuIG4aBq go.opentelemetry.io/otel/sdk v1.28.0/go.mod h1:oYj7ClPUA7Iw3m+r7GeEjz0qckQRJK2B8zjcZEfu7Pg= go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4= go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ= -go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= -go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= -go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= +go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= @@ -193,6 +198,8 @@ golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3 golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -238,3 +245,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +sigs.k8s.io/controller-runtime v0.19.0 h1:nWVM7aq+Il2ABxwiCizrVDSlmDcshi9llbaFbC0ji/Q= +sigs.k8s.io/controller-runtime v0.19.0/go.mod h1:iRmWllt8IlaLjvTTDLhRBXIEtkCK6hwVBJJsYS9Ajf4= diff --git a/internal/config/config.go b/internal/config/config.go index 97257d3..6a5af2d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,14 +2,23 @@ package config import ( "fmt" + "time" "github.com/sirupsen/logrus" "github.com/spf13/viper" ) +const ( + KeepAliveDefault = 10 * time.Second + KeepAliveTimeoutDefault = time.Minute +) + type Config struct { - CastAI CastAPI `mapstructure:"cast"` - ClusterID string `mapstructure:"clusterid"` + CastAI CastAPI `mapstructure:"cast"` + ClusterID string `mapstructure:"clusterid"` + KeepAlive time.Duration `mapstructure:"keepalive"` + KeepAliveTimeout time.Duration `mapstructure:"keepalivetimeout"` + PodMetadata PodMetadata `mapstructure:"podmetadata"` //MetricsAddress string `mapstructure:"metricsaddress"` @@ -71,6 +80,9 @@ func Get() Config { v.MustBindEnv("podmetadata.nodename", "NODE_NAME") v.MustBindEnv("podmetadata.podname", "POD_NAME") + _ = v.BindEnv("keepalive", "KEEP_ALIVE") + _ = v.BindEnv("keepalivetimeout", "KEEP_ALIVE_TIMEOUT") + _ = v.BindEnv("log.level", "LOG_LEVEL") cfg = &Config{} @@ -95,6 +107,17 @@ func Get() Config { cfg.Log.Level = int(logrus.InfoLevel) } + if cfg.KeepAlive == 0 { + cfg.KeepAlive = KeepAliveDefault + } + if cfg.KeepAliveTimeout == 0 { + if cfg.KeepAlive < KeepAliveTimeoutDefault { + cfg.KeepAliveTimeout = KeepAliveTimeoutDefault + } else { + cfg.KeepAliveTimeout = cfg.KeepAlive * 4 + } + } + return *cfg } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index ab76c60..69df649 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -45,6 +45,8 @@ func TestConfig(t *testing.T) { Log: Log{ Level: 3, }, + KeepAlive: KeepAliveDefault, + KeepAliveTimeout: KeepAliveTimeoutDefault, } got := Get() diff --git a/internal/e2etest/roundtripper.go b/internal/e2etest/roundtripper.go index 7914d77..80c041b 100644 --- a/internal/e2etest/roundtripper.go +++ b/internal/e2etest/roundtripper.go @@ -2,12 +2,13 @@ package e2etest import ( "bytes" - cloudproxyv1alpha "cloud-proxy/proto/gen/proto/v1alpha" "fmt" "io" "log" "net/http" + cloudproxyv1alpha "cloud-proxy/proto/gen/proto/v1alpha" + "github.com/google/uuid" ) diff --git a/internal/proxy/client.go b/internal/proxy/client.go index ac5b98d..2309028 100644 --- a/internal/proxy/client.go +++ b/internal/proxy/client.go @@ -5,22 +5,22 @@ package proxy import ( "bytes" - cloudproxyv1alpha "cloud-proxy/proto/gen/proto/v1alpha" "context" "fmt" - "github.com/samber/lo" - "github.com/sirupsen/logrus" - "google.golang.org/grpc" "io" "net/http" "sync/atomic" "time" + + "github.com/samber/lo" + "github.com/sirupsen/logrus" + "google.golang.org/grpc" + + cloudproxyv1alpha "cloud-proxy/proto/gen/proto/v1alpha" ) const ( - KeepAliveMessageID = "keep-alive" - KeepAliveDefault = 10 * time.Second - KeepAliveTimeoutDefault = time.Minute + KeepAliveMessageID = "keep-alive" ) type CloudClient interface { @@ -37,12 +37,13 @@ type Client struct { processedCount atomic.Int64 lastSeen atomic.Int64 + lastSeenError atomic.Pointer[error] keepAlive atomic.Int64 keepAliveTimeout atomic.Int64 version string } -func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logger, clusterID, version string) *Client { +func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logger, clusterID, version string, keepalive, keepaliveTimeout time.Duration) *Client { c := &Client{ grpcConn: grpcConn, cloudClient: cloudClient, @@ -50,27 +51,32 @@ func New(grpcConn *grpc.ClientConn, cloudClient CloudClient, logger *logrus.Logg clusterID: clusterID, version: version, } - c.keepAlive.Store(int64(KeepAliveDefault)) - c.keepAliveTimeout.Store(int64(KeepAliveTimeoutDefault)) + c.keepAlive.Store(int64(keepalive)) + c.keepAliveTimeout.Store(int64(keepaliveTimeout)) return c } func (c *Client) Run(ctx context.Context) error { + t := time.NewTimer(time.Millisecond) + for { - if ctx.Err() != nil { - return nil - } - c.log.Info("Starting proxy client") - stream, closeStream, err := c.getStream() - if err != nil { - c.log.Errorf("c.getStream: %v", err) - time.Sleep(time.Second) - continue - } - err = c.run(ctx, stream, closeStream) - if err != nil { - c.log.Errorf("c.run exited: %v", err) + select { + case <-ctx.Done(): + return ctx.Err() + case <-t.C: + c.log.Info("Starting proxy client") + stream, closeStream, err := c.getStream() + if err != nil { + c.log.Errorf("Could not get stream, restarting proxy client in %vs: %v", time.Duration(c.keepAlive.Load()).Seconds(), err) + t.Reset(time.Duration(c.keepAlive.Load())) + continue + } + err = c.run(ctx, stream, closeStream) + if err != nil { + c.log.Errorf("Restarting proxy client in %vs: due to error: %v", time.Duration(c.keepAlive.Load()).Seconds(), err) + t.Reset(time.Duration(c.keepAlive.Load())) + } } } } @@ -109,6 +115,7 @@ func (c *Client) sendInitialRequest(stream cloudproxyv1alpha.CloudProxyAPI_Strea return fmt.Errorf("stream.Send: initial request %w", err) } c.lastSeen.Store(time.Now().UnixNano()) + c.lastSeenError.Store(nil) c.log.Info("Stream to castai started successfully") @@ -123,25 +130,50 @@ func (c *Client) run(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI return fmt.Errorf("c.Connect: %w", err) } - ctxWithCancel, cancel := context.WithCancel(ctx) - defer cancel() - go c.sendKeepAlive(ctxWithCancel, stream) + go c.sendKeepAlive(stream) + + go func() { + for { + select { + case <-ctx.Done(): + return + case <-stream.Context().Done(): + return + default: + if !c.isAlive() { + return + } + } + + c.log.Debugf("Polling stream for messages") + + in, err := stream.Recv() + if err != nil { + c.log.Errorf("stream.Recv: got error: %v", err) + c.lastSeen.Store(0) + c.lastSeenError.Store(&err) + return + } + + c.log.Debugf("Handling message from castai") + go c.handleMessage(in, stream) + } + }() for { - if ctx.Err() != nil { + select { + case <-ctx.Done(): return ctx.Err() + case <-stream.Context().Done(): + return fmt.Errorf("stream closed %w", stream.Context().Err()) + case <-time.After(time.Duration(c.keepAlive.Load())): + if !c.isAlive() { + if err := c.lastSeenError.Load(); err != nil { + return fmt.Errorf("recived error: %w", *err) + } + return fmt.Errorf("last seen too old, closing stream") + } } - if !c.isAlive() { - return fmt.Errorf("last seen too old, closing stream") - } - c.log.Info("Polling stream for messages") - in, err := stream.Recv() - if err != nil { - return fmt.Errorf("stream.Recv: %w", err) - } - - c.log.Info("Handling message from castai") - go c.handleMessage(in, stream) } } @@ -226,21 +258,21 @@ func (c *Client) isAlive() bool { return true } -func (c *Client) sendKeepAlive(ctx context.Context, stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient) { +func (c *Client) sendKeepAlive(stream cloudproxyv1alpha.CloudProxyAPI_StreamCloudProxyClient) { ticker := time.NewTimer(time.Duration(c.keepAlive.Load())) defer ticker.Stop() c.log.Info("Starting keep-alive loop") for { - if !c.isAlive() { - c.log.Info("Stopping keep-alive loop: client connection is not alive") - return - } select { - case <-ctx.Done(): - c.log.Infof("Stopping keep-alive loop: context ended with %v", context.Cause(ctx)) + case <-stream.Context().Done(): + c.log.Infof("Stopping keep-alive loop: stream ended with %v", stream.Context().Err()) return case <-ticker.C: + if !c.isAlive() { + c.log.Info("Stopping keep-alive loop: client connection is not alive") + return + } c.log.Debug("Sending keep-alive to castai") err := stream.Send(&cloudproxyv1alpha.StreamCloudProxyRequest{ Request: &cloudproxyv1alpha.StreamCloudProxyRequest_ClientStats{ diff --git a/internal/proxy/client_test.go b/internal/proxy/client_test.go index 3d573d3..0a1c446 100644 --- a/internal/proxy/client_test.go +++ b/internal/proxy/client_test.go @@ -2,20 +2,23 @@ package proxy import ( "bytes" - mock_proxy "cloud-proxy/internal/proxy/mock" - cloudproxyv1alpha "cloud-proxy/proto/gen/proto/v1alpha" "context" "fmt" - "github.com/golang/mock/gomock" - "github.com/samber/lo" - "github.com/sirupsen/logrus" - "github.com/stretchr/testify/require" "io" "net/http" "net/url" "reflect" "testing" "time" + + "github.com/golang/mock/gomock" + "github.com/samber/lo" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "cloud-proxy/internal/config" + mock_proxy "cloud-proxy/internal/proxy/mock" + cloudproxyv1alpha "cloud-proxy/proto/gen/proto/v1alpha" ) type mockReadCloserErr struct{} @@ -79,7 +82,7 @@ func TestClient_toResponse(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - c := New(nil, nil, nil, "clusterID", "version") + c := New(nil, nil, nil, "clusterID", "version", time.Second, time.Minute) got := c.toResponse(tt.args.resp) //diff := cmp.Diff(got, tt.want, protocmp.Transform()) //require.Empty(t, diff) @@ -142,7 +145,7 @@ func TestClient_toHTTPRequest(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - c := New(nil, nil, nil, "clusterID", "version") + c := New(nil, nil, nil, "clusterID", "version", time.Second, time.Minute) got, err := c.toHTTPRequest(tt.args.req) require.Equal(t, tt.wantErr, err != nil, err) if err != nil { @@ -177,8 +180,8 @@ func TestClient_handleMessage(t *testing.T) { { name: "nil response", wantLastSeenUpdated: false, - wantKeepAlive: int64(KeepAliveDefault), - wantKeepAliveTimeout: int64(KeepAliveTimeoutDefault), + wantKeepAlive: int64(config.KeepAliveDefault), + wantKeepAliveTimeout: int64(config.KeepAliveTimeoutDefault), }, { name: "keep alive", @@ -188,8 +191,8 @@ func TestClient_handleMessage(t *testing.T) { }, }, wantLastSeenUpdated: true, - wantKeepAlive: int64(KeepAliveDefault), - wantKeepAliveTimeout: int64(KeepAliveTimeoutDefault), + wantKeepAlive: int64(config.KeepAliveDefault), + wantKeepAliveTimeout: int64(config.KeepAliveTimeoutDefault), }, { name: "keep alive timeout and keepalive", @@ -235,8 +238,8 @@ func TestClient_handleMessage(t *testing.T) { }, }, wantLastSeenUpdated: false, - wantKeepAlive: int64(KeepAliveDefault), - wantKeepAliveTimeout: int64(KeepAliveTimeoutDefault), + wantKeepAlive: int64(config.KeepAliveDefault), + wantKeepAliveTimeout: int64(config.KeepAliveTimeoutDefault), wantErrCount: 1, }, } @@ -250,7 +253,7 @@ func TestClient_handleMessage(t *testing.T) { if tt.fields.tuneMockCloudClient != nil { tt.fields.tuneMockCloudClient(cloudClient) } - c := New(nil, cloudClient, logrus.New(), "clusterID", "version") + c := New(nil, cloudClient, logrus.New(), "clusterID", "version", config.KeepAliveDefault, config.KeepAliveTimeoutDefault) stream := mock_proxy.NewMockCloudProxyAPI_StreamCloudProxyClient(ctrl) if tt.args.tuneMockStream != nil { tt.args.tuneMockStream(stream) @@ -335,7 +338,7 @@ func TestClient_processHttpRequest(t *testing.T) { if tt.fields.tuneMockCloudClient != nil { tt.fields.tuneMockCloudClient(cloudClient) } - c := New(nil, cloudClient, logrus.New(), "clusterID", "version") + c := New(nil, cloudClient, logrus.New(), "clusterID", "version", time.Second, time.Minute) if got := c.processHttpRequest(tt.args.req); !reflect.DeepEqual(got, tt.want) { t.Errorf("processHttpRequest() = %v, want %v", got, tt.want) } @@ -348,7 +351,6 @@ func TestClient_sendKeepAlive(t *testing.T) { t.Parallel() type args struct { - ctx func() context.Context tuneMockStream func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) keepAlive int64 keepAliveTimeout int64 @@ -361,32 +363,19 @@ func TestClient_sendKeepAlive(t *testing.T) { { name: "end of ticker", args: args{ - ctx: func() context.Context { - return context.Background() - }, keepAlive: 0, - }, - }, - { - name: "context done", - args: args{ - ctx: func() context.Context { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - return ctx + tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) { + m.EXPECT().Send(gomock.Any()).Return(nil).AnyTimes() + m.EXPECT().Context().Return(context.Background()).AnyTimes() }, - keepAlive: int64(time.Second), - keepAliveTimeout: int64(10 * time.Minute), }, }, { name: "send returned error, should exit", args: args{ - ctx: func() context.Context { - return context.Background() - }, tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) { m.EXPECT().Send(gomock.Any()).Return(fmt.Errorf("error")) + m.EXPECT().Context().Return(context.Background()).AnyTimes() }, keepAlive: int64(time.Second), keepAliveTimeout: int64(10 * time.Minute), @@ -401,7 +390,7 @@ func TestClient_sendKeepAlive(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - c := New(nil, nil, logrus.New(), "clusterID", "version") + c := New(nil, nil, logrus.New(), "clusterID", "version", config.KeepAliveDefault, config.KeepAliveTimeoutDefault) c.keepAlive.Store(tt.args.keepAlive) c.keepAliveTimeout.Store(tt.args.keepAliveTimeout) @@ -411,7 +400,7 @@ func TestClient_sendKeepAlive(t *testing.T) { } c.lastSeen.Store(time.Now().UnixNano()) - c.sendKeepAlive(tt.args.ctx(), stream) + c.sendKeepAlive(stream) require.Equal(t, tt.isLastSeenZero, c.lastSeen.Load() == 0, "lastSeen: %v", c.lastSeen.Load()) }) } @@ -419,15 +408,13 @@ func TestClient_sendKeepAlive(t *testing.T) { func TestClient_run(t *testing.T) { t.Parallel() - type fields struct { - } + type args struct { ctx func() context.Context tuneMockStream func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) } tests := []struct { name string - fields fields args args wantErr bool wantLastSeenUpdated bool @@ -453,12 +440,28 @@ func TestClient_run(t *testing.T) { return ctx }, tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) { - m.EXPECT().Send(gomock.Any()).Return(nil) + m.EXPECT().Send(gomock.Any()).Return(nil).AnyTimes() // expected 0 or 1 times + m.EXPECT().Context().Return(context.Background()).AnyTimes() // expected 0 or 1 times }, }, wantLastSeenUpdated: true, wantErr: true, }, + { + name: "stream not alive", + args: args{ + ctx: func() context.Context { + return context.Background() + }, + tuneMockStream: func(m *mock_proxy.MockCloudProxyAPI_StreamCloudProxyClient) { + m.EXPECT().Send(gomock.Any()).Return(nil).AnyTimes() // expected 0 or 1 times + m.EXPECT().Context().Return(context.Background()).AnyTimes() // expected 0 or 1 times + m.EXPECT().Recv().Return(nil, fmt.Errorf("test error")) + }, + }, + wantLastSeenUpdated: false, + wantErr: true, + }, } for _, tt := range tests { tt := tt @@ -467,7 +470,7 @@ func TestClient_run(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - c := New(nil, nil, logrus.New(), "clusterID", "version") + c := New(nil, nil, logrus.New(), "clusterID", "version", time.Second, time.Second) stream := mock_proxy.NewMockCloudProxyAPI_StreamCloudProxyClient(ctrl) if tt.args.tuneMockStream != nil { tt.args.tuneMockStream(stream)