Skip to content

Commit

Permalink
Merge branch 'master' into mountinfo
Browse files Browse the repository at this point in the history
  • Loading branch information
pnookala-px authored Sep 14, 2023
2 parents 304e255 + 6b9d48d commit b1b978c
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 38 deletions.
137 changes: 104 additions & 33 deletions csi/csi.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ package csi

import (
"fmt"
"os"
"sync"
"time"

csi "github.com/container-storage-interface/spec/lib/go/csi"
"github.com/libopenstorage/openstorage/api"
Expand All @@ -44,7 +46,8 @@ import (
)

var (
clogger *logrus.Logger
clogger *logrus.Logger
csiSocketCheckInterval = 30 * time.Second
)

func init() {
Expand Down Expand Up @@ -93,6 +96,8 @@ type OsdCsiServer struct {
csiDriverName string
allowInlineVolumes bool
stopCleanupCh chan bool
config *OsdCsiServerConfig
autoRecoverStopCh chan struct{}
}

// NewOsdCsiServer creates a gRPC CSI complient server on the
Expand All @@ -116,37 +121,9 @@ func NewOsdCsiServer(config *OsdCsiServerConfig) (grpcserver.Server, error) {
return nil, fmt.Errorf("Unable to get driver %s info: %s", config.DriverName, err.Error())
}

// create correlation interceptor
var unaryInterceptors []grpc.UnaryServerInterceptor
correlationInterceptor := correlation.ContextInterceptor{
Origin: correlation.ComponentCSIDriver,
}
opts := make([]grpc.ServerOption, 0)
unaryInterceptors = append(unaryInterceptors, correlationInterceptor.ContextUnaryServerInterceptor)

// create scheduler interceptor
switch config.SchedulerName {
case "kubernetes":
logrus.Infof("CSI K8s filter being added for %s scheduler", config.SchedulerName)
ki := k8s.NewInterceptor()
unaryInterceptors = append(unaryInterceptors, ki.SchedUnaryInterceptor)

default:
logrus.Infof("No CSI filter being added for %s scheduler", config.SchedulerName)
}

// Add interceptors
opts = append(opts, grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(unaryInterceptors...)))

// Create server
gServer, err := grpcserver.New(&grpcserver.GrpcServerConfig{
Name: "CSI 1.7",
Net: config.Net,
Address: config.Address,
Opts: opts,
})
gServer, err := createGrpcServer(config)
if err != nil {
return nil, fmt.Errorf("Failed to create CSI server: %v", err)
return nil, err
}

return &OsdCsiServer{
Expand All @@ -160,6 +137,8 @@ func NewOsdCsiServer(config *OsdCsiServerConfig) (grpcserver.Server, error) {
allowInlineVolumes: config.EnableInlineVolumes,
roundRobinBalancer: config.RoundRobinBalancer,
cloudBackupClient: api.NewOpenStorageCloudBackupClient,
config: config,
autoRecoverStopCh: make(chan struct{}),
}, nil
}

Expand Down Expand Up @@ -257,18 +236,110 @@ func (s *OsdCsiServer) addEncryptionInfoToLabels(labels, csiSecrets map[string]s
// Start is used to start the server.
// It will return an error if the server is already running.
func (s *OsdCsiServer) Start() error {
return s.GrpcServer.Start(func(grpcServer *grpc.Server) {
if err := s.GrpcServer.Start(func(grpcServer *grpc.Server) {
csi.RegisterIdentityServer(grpcServer, s)
csi.RegisterControllerServer(grpcServer, s)
csi.RegisterNodeServer(grpcServer, s)
})
}); err != nil {
return err
}

if s.config.Net == "unix" {
go func() {
err := autoSocketRecover(s, s.autoRecoverStopCh)
if err != nil {
logrus.Errorf("failed to start CSI driver socket auto-recover watcher: %v", err)
}
}()
}

return nil
}

// Start is used to stop the server.
func (s *OsdCsiServer) Stop() {
close(s.autoRecoverStopCh)
s.GrpcServer.Stop()
}

func createGrpcServer(config *OsdCsiServerConfig) (*grpcserver.GrpcServer, error) {
// create correlation interceptor
var unaryInterceptors []grpc.UnaryServerInterceptor
correlationInterceptor := correlation.ContextInterceptor{
Origin: correlation.ComponentCSIDriver,
}
opts := make([]grpc.ServerOption, 0)
unaryInterceptors = append(unaryInterceptors, correlationInterceptor.ContextUnaryServerInterceptor)

// create scheduler interceptor
switch config.SchedulerName {
case "kubernetes":
logrus.Infof("CSI K8s filter being added for %s scheduler", config.SchedulerName)
ki := k8s.NewInterceptor()
unaryInterceptors = append(unaryInterceptors, ki.SchedUnaryInterceptor)

default:
logrus.Infof("No CSI filter being added for %s scheduler", config.SchedulerName)
}

// Add interceptors
opts = append(opts, grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(unaryInterceptors...)))

// Create server
gServer, err := grpcserver.New(&grpcserver.GrpcServerConfig{
Name: "CSI 1.7",
Net: config.Net,
Address: config.Address,
Opts: opts,
})
if err != nil {
return nil, fmt.Errorf("Failed to create CSI server: %v", err)
}

return gServer, nil
}

func autoSocketRecover(s *OsdCsiServer, stopCh chan struct{}) error {
socketPath := s.Address()
ticker := time.NewTicker(csiSocketCheckInterval)

// Start checking for CSI socket delete
for {
select {
case <-stopCh:
return nil
case <-ticker.C:
}

// Check if socket deleted
_, err := os.Stat(socketPath)
if err == nil {
continue
}

logrus.Infof("Detected CSI socket deleted at path %s. Stopping CSI gRPC server", socketPath)
s.GrpcServer.Stop()

// Re-create gRPC server
gServer, err := createGrpcServer(s.config)
if err != nil {
logrus.Errorf("failed to re-create gRPC server: %v. Retrying in %s...", err, csiSocketCheckInterval)
continue
}
s.GrpcServer = gServer

// Start server
logrus.Infof("Restarting CSI gRPC server at %s", socketPath)
if err := s.Start(); err != nil {
logrus.Errorf("CSI server failed to auto-recover after socket deletion: %v. Retrying in %s...", err, csiSocketCheckInterval)
continue
}

// Exit for next process to start
return nil
}
}

// adjustFinalErrors adjusts certain gRPC status to make CSI callers
// (csi-provisioner, kubelet, etc) retry instead of being marked as a failure.
// See https://github.com/kubernetes/kubernetes/blob/64ed9145452d2d1d324d2437566f1ea1ce76f226/pkg/volume/csi/csi_client.go#L718-L724
Expand Down
56 changes: 51 additions & 5 deletions csi/csi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ import (
)

const (
mockDriverName = "mock"
testSharedSecret = "mysecret"
fakeWithSched = "fake-sched"
mockDriverName = "mock"
testSharedSecret = "mysecret"
fakeWithSched = "fake-sched"
testSocketLocation = "/tmp/csi-ut.sock"
)

var (
Expand Down Expand Up @@ -141,6 +142,15 @@ func newTestServer(t *testing.T) *testServer {
})
}

func newUDSTestServer(t *testing.T) *testServer {
os.Remove(testSocketLocation)
return newTestServerWithConfig(t, &OsdCsiServerConfig{
DriverName: mockDriverName,
Address: testSocketLocation,
Net: "unix",
})
}

func newTestServerWithConfig(t *testing.T, config *OsdCsiServerConfig) *testServer {
tester := &testServer{}
tester.setPorts()
Expand Down Expand Up @@ -205,8 +215,13 @@ func newTestServerWithConfig(t *testing.T, config *OsdCsiServerConfig) *testServ
})

// Setup CSI simple driver
config.Net = "tcp"
config.Address = "127.0.0.1:0"
// Allow for net and address to be overwritten
if config.Net == "" {
config.Net = "tcp"
}
if config.Address == "" {
config.Address = "127.0.0.1:0"
}
config.SdkUds = tester.uds
config.SdkPort = tester.port
tester.server, err = NewOsdCsiServer(config)
Expand Down Expand Up @@ -446,3 +461,34 @@ func TestCSIServerStartContextInterceptor(t *testing.T) {
expectedInfoLog = "csi-driver"
assert.Contains(t, logStr, expectedInfoLog)
}

func TestCSISocketAutoRecover(t *testing.T) {
csiSocketCheckInterval = 1 * time.Second

// Start server and wait for socket to be up and running
s := newUDSTestServer(t)
assert.True(t, s.Server().IsRunning())
defer func() {
s.Stop()
}()
assert.Eventually(t, s.server.IsRunning, 30*time.Second, time.Second)
assert.Eventually(t, func() bool {
_, err := os.Stat(testSocketLocation)
return err == nil
}, 30*time.Second, time.Second)
_, err := os.Stat(testSocketLocation)
assert.NoError(t, err, "UDS should exist after startup")

// Delete socket and wait for it to be gone
err = os.Remove(testSocketLocation)
assert.NoError(t, err)

// Wait for auto-recover
assert.Eventually(t, func() bool {
_, err := os.Stat(testSocketLocation)
return err == nil
}, 30*time.Second, time.Second)
assert.True(t, s.server.IsRunning(), "Server should be running after autorecover")
_, err = os.Stat(testSocketLocation)
assert.NoError(t, err, "UDS should exist after autorecover")
}
5 changes: 5 additions & 0 deletions pkg/grpcserver/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ func (s *GrpcServer) IsRunning() bool {
return s.running
}

// Listener returns the listener used for this gRPC server
func (s *GrpcServer) Listener() net.Listener {
return s.listener
}

func (s *GrpcServer) goServe(started chan<- bool) {
s.wg.Add(1)
go func() {
Expand Down

0 comments on commit b1b978c

Please sign in to comment.