diff --git a/pkg/azuredisk/azuredisk.go b/pkg/azuredisk/azuredisk.go index e7f7bd12d7..db8072353e 100644 --- a/pkg/azuredisk/azuredisk.go +++ b/pkg/azuredisk/azuredisk.go @@ -125,7 +125,9 @@ type DriverCore struct { removeNotReadyTaint bool kubeClient kubernetes.Interface // a timed cache storing volume stats - volStatsCache azcache.Resource + volStatsCache azcache.Resource + maxConcurrentFormat int64 + concurrentFormatTimeout int64 } // Driver is the v1 implementation of the Azure Disk CSI Driver. @@ -176,6 +178,8 @@ func newDriverV1(options *DriverOptions) *Driver { driver.endpoint = options.Endpoint driver.disableAVSetNodes = options.DisableAVSetNodes driver.removeNotReadyTaint = options.RemoveNotReadyTaint + driver.maxConcurrentFormat = options.MaxConcurrentFormat + driver.concurrentFormatTimeout = options.ConcurrentFormatTimeout driver.volumeLocks = volumehelper.NewVolumeLocks() driver.ioHandler = azureutils.NewOSIOHandler() driver.hostUtil = hostutil.NewHostUtil() @@ -263,7 +267,7 @@ func newDriverV1(options *DriverOptions) *Driver { } } - driver.mounter, err = mounter.NewSafeMounter(driver.enableWindowsHostProcess, driver.useCSIProxyGAInterface) + driver.mounter, err = mounter.NewSafeMounter(driver.enableWindowsHostProcess, driver.useCSIProxyGAInterface, int(driver.maxConcurrentFormat), time.Duration(driver.concurrentFormatTimeout)*time.Second) if err != nil { klog.Fatalf("Failed to get safe mounter. Error: %v", err) } diff --git a/pkg/azuredisk/azuredisk_option.go b/pkg/azuredisk/azuredisk_option.go index ef13816c8e..3556f341a8 100644 --- a/pkg/azuredisk/azuredisk_option.go +++ b/pkg/azuredisk/azuredisk_option.go @@ -61,6 +61,8 @@ type DriverOptions struct { Endpoint string DisableAVSetNodes bool RemoveNotReadyTaint bool + MaxConcurrentFormat int64 + ConcurrentFormatTimeout int64 } func (o *DriverOptions) AddFlags() *flag.FlagSet { @@ -103,6 +105,8 @@ func (o *DriverOptions) AddFlags() *flag.FlagSet { fs.BoolVar(&o.DisableAVSetNodes, "disable-avset-nodes", false, "disable DisableAvailabilitySetNodes in cloud config for controller") fs.BoolVar(&o.RemoveNotReadyTaint, "remove-not-ready-taint", true, "remove NotReady taint from node when node is ready") fs.StringVar(&o.Endpoint, "endpoint", "unix://tmp/csi.sock", "CSI endpoint") + fs.Int64Var(&o.MaxConcurrentFormat, "max-concurrent-format", 2, "maximum number of concurrent format exec calls") + fs.Int64Var(&o.ConcurrentFormatTimeout, "concurrent-format-timeout", 120, "maximum time in seconds duration of a format operation before its concurrency token is released") return fs } diff --git a/pkg/azuredisk/azuredisk_v2.go b/pkg/azuredisk/azuredisk_v2.go index e8b2eebc50..cad50b0850 100644 --- a/pkg/azuredisk/azuredisk_v2.go +++ b/pkg/azuredisk/azuredisk_v2.go @@ -26,6 +26,7 @@ import ( "fmt" "os" "reflect" + "time" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v6" "github.com/container-storage-interface/spec/lib/go/csi" @@ -143,7 +144,7 @@ func newDriverV2(options *DriverOptions) *DriverV2 { } } - driver.mounter, err = mounter.NewSafeMounter(driver.enableWindowsHostProcess, driver.useCSIProxyGAInterface) + driver.mounter, err = mounter.NewSafeMounter(driver.enableWindowsHostProcess, driver.useCSIProxyGAInterface, int(driver.maxConcurrentFormat), time.Duration(driver.concurrentFormatTimeout)*time.Second) if err != nil { klog.Fatalf("Failed to get safe mounter. Error: %v", err) } diff --git a/pkg/azuredisk/fake_azuredisk.go b/pkg/azuredisk/fake_azuredisk.go index 8670ff6c89..e3fe666e6d 100644 --- a/pkg/azuredisk/fake_azuredisk.go +++ b/pkg/azuredisk/fake_azuredisk.go @@ -125,7 +125,7 @@ func newFakeDriverV1(ctrl *gomock.Controller) (*fakeDriverV1, error) { driver.diskController = NewManagedDiskController(driver.cloud) driver.clientFactory = driver.cloud.ComputeClientFactory - mounter, err := mounter.NewSafeMounter(true, driver.useCSIProxyGAInterface) + mounter, err := mounter.NewSafeMounter(true, driver.useCSIProxyGAInterface, int(driver.maxConcurrentFormat), time.Duration(driver.concurrentFormatTimeout)*time.Second) if err != nil { return nil, err } diff --git a/pkg/azuredisk/fake_azuredisk_v2.go b/pkg/azuredisk/fake_azuredisk_v2.go index 93bb4409c8..ece5748b07 100644 --- a/pkg/azuredisk/fake_azuredisk_v2.go +++ b/pkg/azuredisk/fake_azuredisk_v2.go @@ -20,6 +20,8 @@ limitations under the License. package azuredisk import ( + "time" + "github.com/container-storage-interface/spec/lib/go/csi" "go.uber.org/mock/gomock" "k8s.io/client-go/kubernetes/fake" @@ -74,7 +76,7 @@ func newFakeDriverV2(ctrl *gomock.Controller) (*fakeDriverV2, error) { driver.diskController = NewManagedDiskController(driver.cloud) driver.clientFactory = driver.cloud.ComputeClientFactory - mounter, err := mounter.NewSafeMounter(true, driver.useCSIProxyGAInterface) + mounter, err := mounter.NewSafeMounter(true, driver.useCSIProxyGAInterface, int(driver.maxConcurrentFormat), time.Duration(driver.concurrentFormatTimeout)*time.Second) if err != nil { return nil, err } diff --git a/pkg/mounter/fake_safe_mounter.go b/pkg/mounter/fake_safe_mounter.go index 4d392e790c..e5d079b616 100644 --- a/pkg/mounter/fake_safe_mounter.go +++ b/pkg/mounter/fake_safe_mounter.go @@ -20,6 +20,7 @@ import ( "fmt" "runtime" "strings" + "time" "k8s.io/mount-utils" "k8s.io/utils/exec" @@ -35,7 +36,7 @@ type FakeSafeMounter struct { // NewFakeSafeMounter creates a mount.SafeFormatAndMount instance suitable for use in unit tests. func NewFakeSafeMounter() (*mount.SafeFormatAndMount, error) { if runtime.GOOS == "windows" { - return NewSafeMounter(true, true) + return NewSafeMounter(true, true, 2, time.Duration(120)*time.Second) } fakeSafeMounter := FakeSafeMounter{} diff --git a/pkg/mounter/safe_mounter_unix.go b/pkg/mounter/safe_mounter_unix.go index 22189c28c2..2fe95f9c75 100644 --- a/pkg/mounter/safe_mounter_unix.go +++ b/pkg/mounter/safe_mounter_unix.go @@ -20,13 +20,13 @@ limitations under the License. package mounter import ( + "time" + "k8s.io/mount-utils" utilexec "k8s.io/utils/exec" ) -func NewSafeMounter(_, _ bool) (*mount.SafeFormatAndMount, error) { - return &mount.SafeFormatAndMount{ - Interface: mount.New(""), - Exec: utilexec.New(), - }, nil +func NewSafeMounter(_, _ bool, maxConcurrentFormat int, concurrentFormatTimeout time.Duration) (*mount.SafeFormatAndMount, error) { + opt := mount.WithMaxConcurrentFormat(maxConcurrentFormat, concurrentFormatTimeout) + return mount.NewSafeFormatAndMount(mount.New(""), utilexec.New(), opt), nil } diff --git a/pkg/mounter/safe_mounter_unix_test.go b/pkg/mounter/safe_mounter_unix_test.go index e620101f63..2ae13cccef 100644 --- a/pkg/mounter/safe_mounter_unix_test.go +++ b/pkg/mounter/safe_mounter_unix_test.go @@ -18,12 +18,13 @@ package mounter import ( "testing" + "time" "github.com/stretchr/testify/assert" ) func TestNewSafeMounter(t *testing.T) { - resp, err := NewSafeMounter(true, true) + resp, err := NewSafeMounter(true, true, 2, time.Duration(120)*time.Second) assert.NotNil(t, resp) assert.Nil(t, err) } diff --git a/pkg/mounter/safe_mounter_windows.go b/pkg/mounter/safe_mounter_windows.go index 7496324364..f0832f9f8a 100644 --- a/pkg/mounter/safe_mounter_windows.go +++ b/pkg/mounter/safe_mounter_windows.go @@ -25,6 +25,7 @@ import ( "os" "strconv" "strings" + "time" "github.com/container-storage-interface/spec/lib/go/csi" disk "github.com/kubernetes-csi/csi-proxy/client/api/disk/v1" @@ -411,13 +412,11 @@ func newCSIProxyMounter() (*csiProxyMounter, error) { }, nil } -func NewSafeMounter(enableWindowsHostProcess, useCSIProxyGAInterface bool) (*mount.SafeFormatAndMount, error) { +func NewSafeMounter(enableWindowsHostProcess, useCSIProxyGAInterface bool, maxConcurrentFormat int, concurrentFormatTimeout time.Duration) (*mount.SafeFormatAndMount, error) { if enableWindowsHostProcess { klog.V(2).Infof("using windows host process mounter") - return &mount.SafeFormatAndMount{ - Interface: NewWinMounter(), - Exec: utilexec.New(), - }, nil + opt := mount.WithMaxConcurrentFormat(maxConcurrentFormat, concurrentFormatTimeout) + return mount.NewSafeFormatAndMount(NewWinMounter(), utilexec.New(), opt), nil } else { if useCSIProxyGAInterface { csiProxyMounter, err := newCSIProxyMounter()