diff --git a/csi/controller.go b/csi/controller.go index 880019ea0..44fced79f 100644 --- a/csi/controller.go +++ b/csi/controller.go @@ -18,6 +18,7 @@ package csi import ( "encoding/json" + "errors" "fmt" "math" "reflect" @@ -477,6 +478,11 @@ func (s *OsdCsiServer) CreateVolume( return nil, status.Error(codes.InvalidArgument, "Name must be provided") } + if req.VolumeContentSource != nil && req.VolumeContentSource.GetSnapshot() != nil { + clogger.WithContext(ctx).Infof("csi.CreateVolume restoring snapshot to Volume: %s", req.GetName()) + return s.restoreSnapshot(ctx, req) + } + // Get parameters spec, locator, source, err := s.specHandler.SpecFromOpts(req.GetParameters()) if err != nil { @@ -632,6 +638,88 @@ func (s *OsdCsiServer) CreateVolume( }, nil } +func (s *OsdCsiServer) restoreSnapshot(ctx context.Context, + req *csi.CreateVolumeRequest, +) (resp *csi.CreateVolumeResponse, err error) { + snapshot := req.VolumeContentSource.GetSnapshot() + if snapshot == nil { + return nil, errors.New("snapshot fetched is not accurate or does not exist") + } + + cloudBackupClient, err := s.getCloudBackupClient(ctx) + cloudBackupClientAvailable := cloudBackupClient != nil + cloudBackupDriverDisabled := sdk.IsErrorUnavailable(err) + if (err != nil && !cloudBackupDriverDisabled) || cloudBackupClient == nil { + return nil, err + } + + csiSnapshotID := snapshot.GetSnapshotId() + if len(csiSnapshotID) == 0 { + return nil, status.Error(codes.InvalidArgument, "Snapshot id must be provided") + } + + var backupStatus *api.SdkCloudBackupStatusResponse + if cloudBackupClientAvailable && !cloudBackupDriverDisabled { + // Check if snapshot has been created but is in error state + backupStatus, err = cloudBackupClient.Status(ctx, &api.SdkCloudBackupStatusRequest{ + TaskId: csiSnapshotID, + }) + } + + isSnapshotIDPresentInCloud := true + if backupStatus != nil { + _, isSnapshotIDPresentInCloud = backupStatus.Statuses[csiSnapshotID] + } + + if (sdk.IsErrorNotFound(err) && !cloudBackupDriverDisabled && cloudBackupClientAvailable) || !isSnapshotIDPresentInCloud { + + return + } + + resp, err = s.restoreCloudSnapshot(ctx, req) + return + +} + +func (s *OsdCsiServer) restoreCloudSnapshot(ctx context.Context, + req *csi.CreateVolumeRequest, +) (resp *csi.CreateVolumeResponse, err error) { + cloudBackupClient, err := s.getCloudBackupClient(ctx) + if err != nil { + return nil, err + } + // Get parameters + _, locator, _, err := s.specHandler.SpecFromOpts(req.GetParameters()) + if err != nil { + e := fmt.Sprintf("Unable to get parameters: %s\n", err.Error()) + clogger.WithContext(ctx).Errorln(e) + return nil, status.Error(codes.InvalidArgument, e) + } + + csiSnapshotID := req.VolumeContentSource.GetSnapshot().GetSnapshotId() + + credentialID := locator.VolumeLabels[osdSnapshotCredentialIDKey] + + snapResp, err := cloudBackupClient.Restore(ctx, &api.SdkCloudBackupRestoreRequest{ + BackupId: csiSnapshotID, + RestoreVolumeName: req.GetName(), + TaskId: req.GetName(), + Locator: locator, + CredentialId: credentialID, + }) + if nil != err { + return nil, err + } + + resp = &csi.CreateVolumeResponse{ + Volume: &csi.Volume{ + VolumeId: snapResp.GetRestoreVolumeId(), + ContentSource: req.VolumeContentSource, + }, + } + return +} + func getClonedPVCMetadata(locator *api.VolumeLocator) map[string]string { metadataLabels := map[string]string{} pvcName, ok := locator.VolumeLabels[intreePvcNameKey]