diff --git a/csi/csi.go b/csi/csi.go index a19931cf4..5c5bf739f 100644 --- a/csi/csi.go +++ b/csi/csi.go @@ -18,7 +18,9 @@ package csi import ( "fmt" + "os" "sync" + "time" csi "github.com/container-storage-interface/spec/lib/go/csi" "github.com/libopenstorage/openstorage/api" @@ -44,7 +46,8 @@ import ( ) var ( - clogger *logrus.Logger + clogger *logrus.Logger + csiSocketCheckInterval = 30 * time.Second ) func init() { @@ -93,6 +96,8 @@ type OsdCsiServer struct { csiDriverName string allowInlineVolumes bool stopCleanupCh chan bool + config *OsdCsiServerConfig + autoRecoverStopCh chan struct{} } // NewOsdCsiServer creates a gRPC CSI complient server on the @@ -116,37 +121,9 @@ func NewOsdCsiServer(config *OsdCsiServerConfig) (grpcserver.Server, error) { return nil, fmt.Errorf("Unable to get driver %s info: %s", config.DriverName, err.Error()) } - // create correlation interceptor - var unaryInterceptors []grpc.UnaryServerInterceptor - correlationInterceptor := correlation.ContextInterceptor{ - Origin: correlation.ComponentCSIDriver, - } - opts := make([]grpc.ServerOption, 0) - unaryInterceptors = append(unaryInterceptors, correlationInterceptor.ContextUnaryServerInterceptor) - - // create scheduler interceptor - switch config.SchedulerName { - case "kubernetes": - logrus.Infof("CSI K8s filter being added for %s scheduler", config.SchedulerName) - ki := k8s.NewInterceptor() - unaryInterceptors = append(unaryInterceptors, ki.SchedUnaryInterceptor) - - default: - logrus.Infof("No CSI filter being added for %s scheduler", config.SchedulerName) - } - - // Add interceptors - opts = append(opts, grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(unaryInterceptors...))) - - // Create server - gServer, err := grpcserver.New(&grpcserver.GrpcServerConfig{ - Name: "CSI 1.7", - Net: config.Net, - Address: config.Address, - Opts: opts, - }) + gServer, err := createGrpcServer(config) if err != nil { - return nil, fmt.Errorf("Failed to create CSI server: %v", err) + return nil, err } return &OsdCsiServer{ @@ -160,6 +137,8 @@ func NewOsdCsiServer(config *OsdCsiServerConfig) (grpcserver.Server, error) { allowInlineVolumes: config.EnableInlineVolumes, roundRobinBalancer: config.RoundRobinBalancer, cloudBackupClient: api.NewOpenStorageCloudBackupClient, + config: config, + autoRecoverStopCh: make(chan struct{}), }, nil } @@ -257,18 +236,110 @@ func (s *OsdCsiServer) addEncryptionInfoToLabels(labels, csiSecrets map[string]s // Start is used to start the server. // It will return an error if the server is already running. func (s *OsdCsiServer) Start() error { - return s.GrpcServer.Start(func(grpcServer *grpc.Server) { + if err := s.GrpcServer.Start(func(grpcServer *grpc.Server) { csi.RegisterIdentityServer(grpcServer, s) csi.RegisterControllerServer(grpcServer, s) csi.RegisterNodeServer(grpcServer, s) - }) + }); err != nil { + return err + } + + if s.config.Net == "unix" { + go func() { + err := autoSocketRecover(s, s.autoRecoverStopCh) + if err != nil { + logrus.Errorf("failed to start CSI driver socket auto-recover watcher: %v", err) + } + }() + } + + return nil } // Start is used to stop the server. func (s *OsdCsiServer) Stop() { + close(s.autoRecoverStopCh) s.GrpcServer.Stop() } +func createGrpcServer(config *OsdCsiServerConfig) (*grpcserver.GrpcServer, error) { + // create correlation interceptor + var unaryInterceptors []grpc.UnaryServerInterceptor + correlationInterceptor := correlation.ContextInterceptor{ + Origin: correlation.ComponentCSIDriver, + } + opts := make([]grpc.ServerOption, 0) + unaryInterceptors = append(unaryInterceptors, correlationInterceptor.ContextUnaryServerInterceptor) + + // create scheduler interceptor + switch config.SchedulerName { + case "kubernetes": + logrus.Infof("CSI K8s filter being added for %s scheduler", config.SchedulerName) + ki := k8s.NewInterceptor() + unaryInterceptors = append(unaryInterceptors, ki.SchedUnaryInterceptor) + + default: + logrus.Infof("No CSI filter being added for %s scheduler", config.SchedulerName) + } + + // Add interceptors + opts = append(opts, grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(unaryInterceptors...))) + + // Create server + gServer, err := grpcserver.New(&grpcserver.GrpcServerConfig{ + Name: "CSI 1.7", + Net: config.Net, + Address: config.Address, + Opts: opts, + }) + if err != nil { + return nil, fmt.Errorf("Failed to create CSI server: %v", err) + } + + return gServer, nil +} + +func autoSocketRecover(s *OsdCsiServer, stopCh chan struct{}) error { + socketPath := s.Address() + ticker := time.NewTicker(csiSocketCheckInterval) + + // Start checking for CSI socket delete + for { + select { + case <-stopCh: + return nil + case <-ticker.C: + } + + // Check if socket deleted + _, err := os.Stat(socketPath) + if err == nil { + continue + } + + logrus.Infof("Detected CSI socket deleted at path %s. Stopping CSI gRPC server", socketPath) + s.GrpcServer.Stop() + + // Re-create gRPC server + gServer, err := createGrpcServer(s.config) + if err != nil { + logrus.Errorf("failed to re-create gRPC server: %v. Retrying in %s...", err, csiSocketCheckInterval) + continue + } + s.GrpcServer = gServer + + // Start server + logrus.Infof("Restarting CSI gRPC server at %s", socketPath) + if err := s.Start(); err != nil { + logrus.Errorf("CSI server failed to auto-recover after socket deletion: %v. Retrying in %s...", err, csiSocketCheckInterval) + continue + } + + // Exit for next process to start + return nil + } +} + // adjustFinalErrors adjusts certain gRPC status to make CSI callers // (csi-provisioner, kubelet, etc) retry instead of being marked as a failure. // See https://github.com/kubernetes/kubernetes/blob/64ed9145452d2d1d324d2437566f1ea1ce76f226/pkg/volume/csi/csi_client.go#L718-L724 diff --git a/csi/csi_test.go b/csi/csi_test.go index 4a1400d14..6a1e7fc8a 100644 --- a/csi/csi_test.go +++ b/csi/csi_test.go @@ -52,9 +52,10 @@ import ( ) const ( - mockDriverName = "mock" - testSharedSecret = "mysecret" - fakeWithSched = "fake-sched" + mockDriverName = "mock" + testSharedSecret = "mysecret" + fakeWithSched = "fake-sched" + testSocketLocation = "/tmp/csi-ut.sock" ) var ( @@ -141,6 +142,15 @@ func newTestServer(t *testing.T) *testServer { }) } +func newUDSTestServer(t *testing.T) *testServer { + os.Remove(testSocketLocation) + return newTestServerWithConfig(t, &OsdCsiServerConfig{ + DriverName: mockDriverName, + Address: testSocketLocation, + Net: "unix", + }) +} + func newTestServerWithConfig(t *testing.T, config *OsdCsiServerConfig) *testServer { tester := &testServer{} tester.setPorts() @@ -205,8 +215,13 @@ func newTestServerWithConfig(t *testing.T, config *OsdCsiServerConfig) *testServ }) // Setup CSI simple driver - config.Net = "tcp" - config.Address = "127.0.0.1:0" + // Allow for net and address to be overwritten + if config.Net == "" { + config.Net = "tcp" + } + if config.Address == "" { + config.Address = "127.0.0.1:0" + } config.SdkUds = tester.uds config.SdkPort = tester.port tester.server, err = NewOsdCsiServer(config) @@ -446,3 +461,34 @@ func TestCSIServerStartContextInterceptor(t *testing.T) { expectedInfoLog = "csi-driver" assert.Contains(t, logStr, expectedInfoLog) } + +func TestCSISocketAutoRecover(t *testing.T) { + csiSocketCheckInterval = 1 * time.Second + + // Start server and wait for socket to be up and running + s := newUDSTestServer(t) + assert.True(t, s.Server().IsRunning()) + defer func() { + s.Stop() + }() + assert.Eventually(t, s.server.IsRunning, 30*time.Second, time.Second) + assert.Eventually(t, func() bool { + _, err := os.Stat(testSocketLocation) + return err == nil + }, 30*time.Second, time.Second) + _, err := os.Stat(testSocketLocation) + assert.NoError(t, err, "UDS should exist after startup") + + // Delete socket and wait for it to be gone + err = os.Remove(testSocketLocation) + assert.NoError(t, err) + + // Wait for auto-recover + assert.Eventually(t, func() bool { + _, err := os.Stat(testSocketLocation) + return err == nil + }, 30*time.Second, time.Second) + assert.True(t, s.server.IsRunning(), "Server should be running after autorecover") + _, err = os.Stat(testSocketLocation) + assert.NoError(t, err, "UDS should exist after autorecover") +} diff --git a/pkg/grpcserver/grpcserver.go b/pkg/grpcserver/grpcserver.go index 6c753d300..df03aa3d8 100644 --- a/pkg/grpcserver/grpcserver.go +++ b/pkg/grpcserver/grpcserver.go @@ -153,6 +153,11 @@ func (s *GrpcServer) IsRunning() bool { return s.running } +// Listener returns the listener used for this gRPC server +func (s *GrpcServer) Listener() net.Listener { + return s.listener +} + func (s *GrpcServer) goServe(started chan<- bool) { s.wg.Add(1) go func() {