diff --git a/api/client/volume/client.go b/api/client/volume/client.go index a89f80861..69a5057fd 100644 --- a/api/client/volume/client.go +++ b/api/client/volume/client.go @@ -75,10 +75,18 @@ func (v *volumeClient) GraphDriverRemove(id string) error { return nil } +func (v *volumeClient) StartVolumeWatcher() { + return +} + func (v *volumeClient) GetVolumeWatcher(locator *api.VolumeLocator, labels map[string]string) (chan *api.Volume, error) { return nil, nil } +func (v *volumeClient) StopVolumeWatcher() { + return +} + func (v *volumeClient) GraphDriverGet(id string, mountLabel string) (string, error) { response := "" if err := v.c.Get().Resource(graphPath + "/inspect").Instance(id).Do().Unmarshal(&response); err != nil { diff --git a/api/server/sdk/sdk_test.go b/api/server/sdk/sdk_test.go index 96e2781a1..f29e2dea8 100644 --- a/api/server/sdk/sdk_test.go +++ b/api/server/sdk/sdk_test.go @@ -147,6 +147,7 @@ func newTestServer(t *testing.T) *testServer { assert.Nil(t, err) + tester.m.EXPECT().StartVolumeWatcher().Return().Times(1) tester.m.EXPECT().GetVolumeWatcher(&api.VolumeLocator{}, make(map[string]string)).DoAndReturn(func(a *api.VolumeLocator, l map[string]string) (chan *api.Volume, error) { ch := make(chan *api.Volume, 1) tester.server.watcherCtxCancel() @@ -233,7 +234,7 @@ func newTestServerAuth(t *testing.T) *testServer { }, }) assert.Nil(t, err) - + tester.m.EXPECT().StartVolumeWatcher().Return().Times(1) tester.m.EXPECT().GetVolumeWatcher(&api.VolumeLocator{}, make(map[string]string)).DoAndReturn(func(a *api.VolumeLocator, l map[string]string) (chan *api.Volume, error) { ch := make(chan *api.Volume, 1) tester.server.watcherCtxCancel() @@ -294,6 +295,7 @@ func (s *testServer) Stop() { // Shutdown servers s.conn.Close() + s.m.EXPECT().StopVolumeWatcher().Return().AnyTimes() s.server.Stop() s.gw.Close() diff --git a/api/server/sdk/server.go b/api/server/sdk/server.go index 7525eb31c..ea95dcf5a 100644 --- a/api/server/sdk/server.go +++ b/api/server/sdk/server.go @@ -314,6 +314,7 @@ func (s *Server) Stop() { s.netServer.Stop() s.udsServer.Stop() s.restGateway.Stop() + s.netServer.watcherServer.stopWatcher(s.watcherCtx) s.watcherCtxCancel() if s.accessLog != nil { diff --git a/api/server/sdk/watcher.go b/api/server/sdk/watcher.go index 3b2172cad..9713b0180 100644 --- a/api/server/sdk/watcher.go +++ b/api/server/sdk/watcher.go @@ -84,6 +84,10 @@ func (s *WatcherServer) removeWatcher(name string, eventType string) { s.watchConnections[eventType] = newWatchers } +func (s *WatcherServer) stopWatcher(ctx context.Context) { + s.volumeServer.driver(ctx).StopVolumeWatcher() +} + func (s *WatcherServer) startWatcher(ctx context.Context) error { group, _ := errgroup.WithContext(ctx) errChan := make(chan error) @@ -121,7 +125,7 @@ func (s *WatcherServer) startVolumeWatcher(ctx context.Context) error { time.Sleep(2 * time.Second) continue } - + s.volumeServer.driver(ctx).StartVolumeWatcher() volumeChannel, err := s.volumeServer.driver(ctx).GetVolumeWatcher(&api.VolumeLocator{}, make(map[string]string)) if err != nil { logrus.Warnf("Error getting volume watcher %v", err) diff --git a/csi/csi_test.go b/csi/csi_test.go index a6b0b390a..4a1400d14 100644 --- a/csi/csi_test.go +++ b/csi/csi_test.go @@ -270,6 +270,7 @@ func (s *testServer) Stop() { // Shutdown servers s.conn.Close() + s.m.EXPECT().StopVolumeWatcher().Return().AnyTimes() s.server.Stop() s.sdk.Stop() diff --git a/volume/drivers/buse/buse.go b/volume/drivers/buse/buse.go index deeea82d5..6d48ca079 100644 --- a/volume/drivers/buse/buse.go +++ b/volume/drivers/buse/buse.go @@ -146,10 +146,18 @@ func Init(params map[string]string) (volume.VolumeDriver, error) { // These functions below implement the volume driver interface. // +func (d *driver) StartVolumeWatcher() { + return +} + func (d *driver) GetVolumeWatcher(locator *api.VolumeLocator, labels map[string]string) (chan *api.Volume, error) { return nil, nil } +func (d *driver) StopVolumeWatcher() { + return +} + func (d *driver) String() string { return Name } diff --git a/volume/drivers/fake/fake.go b/volume/drivers/fake/fake.go index 0be3f22d4..c99315ade 100644 --- a/volume/drivers/fake/fake.go +++ b/volume/drivers/fake/fake.go @@ -142,6 +142,10 @@ func volumeGenerator(d *driver) { } } +func (d *driver) StartVolumeWatcher() { + return +} + func (d *driver) GetVolumeWatcher(locator *api.VolumeLocator, labels map[string]string) (chan *api.Volume, error) { go volumeGenerator(d) if d.volumeChannel == nil { @@ -150,6 +154,10 @@ func (d *driver) GetVolumeWatcher(locator *api.VolumeLocator, labels map[string] return d.volumeChannel, nil } +func (d *driver) StopVolumeWatcher() { + return +} + func (d *driver) Name() string { return Name } diff --git a/volume/drivers/fuse/volume_driver.go b/volume/drivers/fuse/volume_driver.go index f6026c82a..104f468b3 100644 --- a/volume/drivers/fuse/volume_driver.go +++ b/volume/drivers/fuse/volume_driver.go @@ -59,11 +59,18 @@ func newVolumeDriver( provider, } } +func (v *volumeDriver) StartVolumeWatcher() { + return +} func (v *volumeDriver) GetVolumeWatcher(locator *api.VolumeLocator, labels map[string]string) (chan *api.Volume, error) { return nil, nil } +func (v *volumeDriver) StopVolumeWatcher() { + return +} + func (v *volumeDriver) Name() string { return v.name } diff --git a/volume/drivers/mock/driver.mock.go b/volume/drivers/mock/driver.mock.go index d271eef2c..8e625515a 100644 --- a/volume/drivers/mock/driver.mock.go +++ b/volume/drivers/mock/driver.mock.go @@ -897,6 +897,18 @@ func (mr *MockVolumeDriverMockRecorder) SnapshotGroup(arg0, arg1, arg2, arg3 int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SnapshotGroup", reflect.TypeOf((*MockVolumeDriver)(nil).SnapshotGroup), arg0, arg1, arg2, arg3) } +// StartVolumeWatcher mocks base method. +func (m *MockVolumeDriver) StartVolumeWatcher() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StartVolumeWatcher") +} + +// StartVolumeWatcher indicates an expected call of StartVolumeWatcher. +func (mr *MockVolumeDriverMockRecorder) StartVolumeWatcher() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartVolumeWatcher", reflect.TypeOf((*MockVolumeDriver)(nil).StartVolumeWatcher)) +} + // Stats mocks base method. func (m *MockVolumeDriver) Stats(arg0 string, arg1 bool) (*api.Stats, error) { m.ctrl.T.Helper() @@ -926,6 +938,18 @@ func (mr *MockVolumeDriverMockRecorder) Status() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Status", reflect.TypeOf((*MockVolumeDriver)(nil).Status)) } +// StopVolumeWatcher mocks base method. +func (m *MockVolumeDriver) StopVolumeWatcher() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "StopVolumeWatcher") +} + +// StopVolumeWatcher indicates an expected call of StopVolumeWatcher. +func (mr *MockVolumeDriverMockRecorder) StopVolumeWatcher() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StopVolumeWatcher", reflect.TypeOf((*MockVolumeDriver)(nil).StopVolumeWatcher)) +} + // Type mocks base method. func (m *MockVolumeDriver) Type() api.DriverType { m.ctrl.T.Helper() diff --git a/volume/drivers/nfs/nfs.go b/volume/drivers/nfs/nfs.go index c38b53da3..31833ff07 100644 --- a/volume/drivers/nfs/nfs.go +++ b/volume/drivers/nfs/nfs.go @@ -183,10 +183,18 @@ func Init(params map[string]string) (volume.VolumeDriver, error) { return inst, nil } +func (d *driver) StartVolumeWatcher() { + return +} + func (d *driver) GetVolumeWatcher(locator *api.VolumeLocator, labels map[string]string) (chan *api.Volume, error) { return nil, nil } +func (d *driver) StopVolumeWatcher() { + return +} + func (d *driver) Name() string { return Name } diff --git a/volume/drivers/vfs/vfs.go b/volume/drivers/vfs/vfs.go index 4c8c250c7..267c3cdc6 100644 --- a/volume/drivers/vfs/vfs.go +++ b/volume/drivers/vfs/vfs.go @@ -57,10 +57,18 @@ func Init(params map[string]string) (volume.VolumeDriver, error) { }, nil } +func (d *driver) StartVolumeWatcher() { + return +} + func (d *driver) GetVolumeWatcher(locator *api.VolumeLocator, labels map[string]string) (chan *api.Volume, error) { return nil, nil } +func (d *driver) StopVolumeWatcher() { + return +} + func (d *driver) Name() string { return Name } diff --git a/volume/volume.go b/volume/volume.go index 1827fe269..ef7872b4e 100644 --- a/volume/volume.go +++ b/volume/volume.go @@ -311,8 +311,12 @@ type Enumerator interface { // Water provides a set of function to get volume type Watcher interface { + // Stop Volume notifier + StartVolumeWatcher() // Gets Volume notifier GetVolumeWatcher(locator *api.VolumeLocator, labels map[string]string) (chan *api.Volume, error) + // Stop Volume notifier + StopVolumeWatcher() } // StoreEnumerator combines Store and Enumerator capabilities