diff --git a/csi/csi.go b/csi/csi.go index 140cd8efe..165e34f47 100644 --- a/csi/csi.go +++ b/csi/csi.go @@ -19,7 +19,6 @@ package csi import ( "fmt" "os" - "strings" "sync" "time" @@ -98,6 +97,7 @@ type OsdCsiServer struct { allowInlineVolumes bool stopCleanupCh chan bool config *OsdCsiServerConfig + autoRecoverStopCh chan struct{} } // NewOsdCsiServer creates a gRPC CSI complient server on the @@ -244,8 +244,10 @@ func (s *OsdCsiServer) Start() error { } if s.config.Net == "unix" { + stopCh := make(chan struct{}) + s.autoRecoverStopCh = stopCh go func() { - err := autoSocketRecover(s) + err := autoSocketRecover(s, stopCh) if err != nil { logrus.Errorf("failed to start CSI driver socket auto-recover watcher: %v", err) } @@ -257,6 +259,7 @@ func (s *OsdCsiServer) Start() error { // Start is used to stop the server. func (s *OsdCsiServer) Stop() { + close(s.autoRecoverStopCh) s.GrpcServer.Stop() } @@ -297,12 +300,17 @@ func createGrpcServer(config *OsdCsiServerConfig) (*grpcserver.GrpcServer, error return gServer, nil } -func autoSocketRecover(s *OsdCsiServer) error { - socketPath := strings.TrimPrefix(s.Address(), "unix://") +func autoSocketRecover(s *OsdCsiServer, stopCh chan struct{}) error { + socketPath := s.Address() + ticker := time.NewTicker(csiSocketCheckInterval) // Start checking for CSI socket delete for { - time.Sleep(csiSocketCheckInterval) + select { + case <-stopCh: + return nil + case <-ticker.C: + } // Check if socket deleted _, err := os.Stat(socketPath)