diff --git a/internal/dataprotection/action/action_create_vs.go b/internal/dataprotection/action/action_create_vs.go index b8ed5ab7c78..1d3fc340f8c 100644 --- a/internal/dataprotection/action/action_create_vs.go +++ b/internal/dataprotection/action/action_create_vs.go @@ -20,12 +20,14 @@ along with this program. If not, see . package action import ( + "context" "fmt" "strings" vsv1 "github.com/kubernetes-csi/external-snapshotter/client/v6/apis/volumesnapshot/v1" "github.com/pkg/errors" corev1 "k8s.io/api/core/v1" + storagev1 "k8s.io/api/storage/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" @@ -139,8 +141,8 @@ func (c *CreateVolumeSnapshotAction) createVolumeSnapshotIfNotExist(ctx Context, pvc *corev1.PersistentVolumeClaim, key client.ObjectKey) error { var ( - err error - vsc *vsv1.VolumeSnapshotClass + err error + vscName string ) snap := &vsv1.VolumeSnapshot{} @@ -167,8 +169,14 @@ func (c *CreateVolumeSnapshotAction) createVolumeSnapshotIfNotExist(ctx Context, }, } - if vsc != nil { - snap.Spec.VolumeSnapshotClassName = &vsc.Name + if pvc.Spec.StorageClassName != nil && *pvc.Spec.StorageClassName != "" { + if vscName, err = c.getVolumeSnapshotClassName(ctx.Ctx, ctx.Client, vsCli, *pvc.Spec.StorageClassName); err != nil { + return err + } + } + + if vscName != "" { + snap.Spec.VolumeSnapshotClassName = &vscName } controllerutil.AddFinalizer(snap, dptypes.DataProtectionFinalizerName) @@ -184,6 +192,29 @@ func (c *CreateVolumeSnapshotAction) createVolumeSnapshotIfNotExist(ctx Context, return nil } +func (c *CreateVolumeSnapshotAction) getVolumeSnapshotClassName( + ctx context.Context, + cli client.Client, + vsCli intctrlutil.VolumeSnapshotCompatClient, + scName string) (string, error) { + scObj := storagev1.StorageClass{} + // ignore if not found storage class, use the default volume snapshot class + if err := cli.Get(ctx, client.ObjectKey{Name: scName}, &scObj); client.IgnoreNotFound(err) != nil { + return "", err + } + + vscList := vsv1.VolumeSnapshotClassList{} + if err := vsCli.List(&vscList); err != nil { + return "", err + } + for _, item := range vscList.Items { + if item.Driver == scObj.Provisioner { + return item.Name, nil + } + } + return "", nil +} + func ensureVolumeSnapshotReady( vsCli intctrlutil.VolumeSnapshotCompatClient, key client.ObjectKey) (bool, *vsv1.VolumeSnapshot, error) {