diff --git a/csi/controller.go b/csi/controller.go index 880019ea0..a9a3aa75f 100644 --- a/csi/controller.go +++ b/csi/controller.go @@ -485,6 +485,16 @@ func (s *OsdCsiServer) CreateVolume( return nil, status.Error(codes.InvalidArgument, e) } + // Check ID is valid with the specified volume capabilities + snapshotType, ok := locator.VolumeLabels[osdSnapshotLabelsTypeKey] + if !ok { + snapshotType = DriverTypeLocal + } + if DriverTypeCloud == snapshotType && req.VolumeContentSource != nil && req.VolumeContentSource.GetSnapshot() != nil { + clogger.WithContext(ctx).Infof("csi.CreateVolume restoring snapshot to Volume: %s", req.GetName()) + return s.restoreSnapshot(ctx, req) + } + if spec.IsPureVolume() { err = validateCreateVolumeCapabilitiesPure(req.GetVolumeCapabilities(), spec.GetProxySpec()) if err != nil { @@ -632,6 +642,92 @@ func (s *OsdCsiServer) CreateVolume( }, nil } +func (s *OsdCsiServer) restoreSnapshot(ctx context.Context, + req *csi.CreateVolumeRequest, +) (resp *csi.CreateVolumeResponse, err error) { + clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s", req.GetName(), req.VolumeContentSource.GetSnapshot()) + snapshot := req.VolumeContentSource.GetSnapshot() + if snapshot == nil { + return nil, status.Error(codes.NotFound, "snapshot fetched is not accurate or does not exist") + } + + cloudBackupClient, err := s.getCloudBackupClient(ctx) + cloudBackupClientAvailable := cloudBackupClient != nil + cloudBackupDriverDisabled := sdk.IsErrorUnavailable(err) + if (err != nil && !cloudBackupDriverDisabled) || cloudBackupClient == nil { + return nil, err + } + + csiSnapshotID := snapshot.GetSnapshotId() + if len(csiSnapshotID) == 0 { + return nil, status.Error(codes.InvalidArgument, "snapshot id must be provided") + } + + var backupStatus *api.SdkCloudBackupStatusResponse + if cloudBackupClientAvailable && !cloudBackupDriverDisabled { + // Check if snapshot has been created but is in error state + backupStatus, err = cloudBackupClient.Status(ctx, &api.SdkCloudBackupStatusRequest{ + TaskId: csiSnapshotID, + }) + } + + isSnapshotIDPresentInCloud := true + if backupStatus != nil { + _, isSnapshotIDPresentInCloud = backupStatus.Statuses[csiSnapshotID] + } + + if (sdk.IsErrorNotFound(err) && !cloudBackupDriverDisabled && cloudBackupClientAvailable) || !isSnapshotIDPresentInCloud { + clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s is a local backup", req.GetName(), csiSnapshotID) + return + } + clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s is a cloud backup", req.GetName(), csiSnapshotID) + + resp, err = s.restoreCloudSnapshot(ctx, req) + return + +} + +func (s *OsdCsiServer) restoreCloudSnapshot(ctx context.Context, + req *csi.CreateVolumeRequest, +) (resp *csi.CreateVolumeResponse, err error) { + cloudBackupClient, err := s.getCloudBackupClient(ctx) + if err != nil { + return nil, err + } + // Get parameters + _, locator, _, _ := s.specHandler.SpecFromOpts(req.GetParameters()) + + csiSnapshotID := req.VolumeContentSource.GetSnapshot().GetSnapshotId() + + clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s is a cloud backup with labels %+v", req.GetName(), csiSnapshotID, locator.VolumeLabels) + + credentialID, ok := locator.VolumeLabels[osdSnapshotCredentialIDKey] + if !ok { + e := fmt.Sprintf("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s credentials missing", req.GetName(), csiSnapshotID) + clogger.WithContext(ctx).Infof(e) + return nil, status.Error(codes.InvalidArgument, e) + } + + snapResp, err := cloudBackupClient.Restore(ctx, &api.SdkCloudBackupRestoreRequest{ + BackupId: csiSnapshotID, + RestoreVolumeName: req.GetName(), + TaskId: req.GetName(), + Locator: locator, + CredentialId: credentialID, + }) + if nil != err { + return nil, status.Error(codes.Internal, err.Error()) + } + + resp = &csi.CreateVolumeResponse{ + Volume: &csi.Volume{ + VolumeId: snapResp.GetRestoreVolumeId(), + ContentSource: req.VolumeContentSource, + }, + } + return +} + func getClonedPVCMetadata(locator *api.VolumeLocator) map[string]string { metadataLabels := map[string]string{} pvcName, ok := locator.VolumeLabels[intreePvcNameKey] diff --git a/csi/controller_test.go b/csi/controller_test.go index 3c8d95649..a33fc9ea1 100644 --- a/csi/controller_test.go +++ b/csi/controller_test.go @@ -3734,3 +3734,185 @@ func TestOsdCsiServer_DeleteCloudSnapshot(t *testing.T) { }) } } + +func TestOsdCsiServer_RestoreCloudSnapshot(t *testing.T) { + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockCloudBackupClient := mock.NewMockOpenStorageCloudBackupClient(ctrl) + + ctx := context.Background() + + mockErr := errors.New("MOCK ERROR") + creationTime := timestamppb.Now() + + mockVolumeName := "mock-volume-id" + + mockRoundRobinBalancer := mockLoadBalancer.NewMockBalancer(ctrl) + mockRoundRobinBalancer.EXPECT().GetRemoteNodeConnection(gomock.Any()).DoAndReturn( + func(ctx context.Context) (*grpc.ClientConn, bool, error) { + var err error + var conn *grpc.ClientConn + if ctx.Value("remote-client-error").(bool) { + err = mockErr + conn = &grpc.ClientConn{} + } + return conn, true, err + }).AnyTimes() + + mockCloudBackupClient.EXPECT().Restore(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, req *api.SdkCloudBackupRestoreRequest, opts ...grpc.CallOption) (*api.SdkCloudBackupRestoreResponse, error) { + clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. SdkCloudBackupRestoreRequest: %+v", req) + if req.BackupId == "client-error" { + return nil, mockErr + } + + if req.BackupId == "snapshot-notfound" { + return nil, status.Errorf(codes.NotFound, "Snapshot not found") + } + + if req.BackupId == "ok" { + return &api.SdkCloudBackupRestoreResponse{ + RestoreVolumeId: req.BackupId, + TaskId: req.BackupId, + }, nil + } + + return &api.SdkCloudBackupRestoreResponse{}, nil + + }).AnyTimes() + + mockCloudBackupClient.EXPECT().Status(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, req *api.SdkCloudBackupStatusRequest, opts ...grpc.CallOption) (*api.SdkCloudBackupStatusResponse, error) { + if req.TaskId == "status-error" { + return nil, mockErr + } + + return &api.SdkCloudBackupStatusResponse{ + Statuses: map[string]*api.SdkCloudBackupStatus{ + req.TaskId: { + BackupId: req.TaskId, + Status: api.SdkCloudBackupStatusType_SdkCloudBackupStatusTypeDone, + StartTime: creationTime, + }, + }, + }, nil + + }).AnyTimes() + + tests := []struct { + name string + SnapshotName string + want *csi.CreateVolumeResponse + wantErr bool + }{ + + { + "snapshot not provided in volume source", + "nil", + nil, + true, + }, + { + "remote client connection failed", + "remote-client-error", + nil, + true, + }, + { + "snapshot id is blank", + "", + nil, + true, + }, + { + "failed to get credentials", + "cred-error", + nil, + true, + }, + { + "Cloud backup client not available", + "client-error", + nil, + true, + }, + { + "fail to get cloud snap not found", + "snapshot-notfound", + nil, + true, + }, + { + "Snapshot restored without error", + "ok", + &csi.CreateVolumeResponse{ + Volume: &csi.Volume{ + VolumeId: "ok", + ContentSource: &csi.VolumeContentSource{ + Type: &csi.VolumeContentSource_Snapshot{ + Snapshot: &csi.VolumeContentSource_SnapshotSource{ + SnapshotId: "ok", + }, + }, + }, + }, + }, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + snapshot := &csi.VolumeContentSource_SnapshotSource{ + SnapshotId: tt.SnapshotName, + } + if tt.SnapshotName == "nil" { + snapshot = nil + } + + req := &csi.CreateVolumeRequest{ + Name: mockVolumeName, + VolumeContentSource: &csi.VolumeContentSource{ + Type: &csi.VolumeContentSource_Snapshot{ + Snapshot: snapshot, + }, + }, + } + specLabels := []string{} + if tt.SnapshotName != "cred-error" { + specLabels = append(specLabels, osdSnapshotCredentialIDKey+"=mockcredid") + specLabels = append(specLabels, osdSnapshotLabelsTypeKey+"=cloud") + req.Parameters = map[string]string{ + api.SpecLabels: strings.Join(specLabels, ","), + } + } else { + req.Parameters = nil + } + + s := &OsdCsiServer{ + specHandler: spec.NewSpecHandler(), + mu: sync.Mutex{}, + cloudBackupClient: func(cc grpc.ClientConnInterface) api.OpenStorageCloudBackupClient { + return mockCloudBackupClient + }, + roundRobinBalancer: mockRoundRobinBalancer, + } + + doClientErr := tt.SnapshotName == "remote-client-error" + + ctx = context.WithValue(ctx, "remote-client-error", doClientErr) + + got, err := s.CreateVolume(ctx, req) + if (err != nil) != tt.wantErr { + t.Errorf("OsdCsiServer.CreateSnapshot() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("OsdCsiServer.CreateSnapshot() = %v, want %v", got, tt.want) + } + }) + } +}