diff --git a/csi/controller.go b/csi/controller.go index 44fced79f..a9a3aa75f 100644 --- a/csi/controller.go +++ b/csi/controller.go @@ -18,7 +18,6 @@ package csi import ( "encoding/json" - "errors" "fmt" "math" "reflect" @@ -478,11 +477,6 @@ func (s *OsdCsiServer) CreateVolume( return nil, status.Error(codes.InvalidArgument, "Name must be provided") } - if req.VolumeContentSource != nil && req.VolumeContentSource.GetSnapshot() != nil { - clogger.WithContext(ctx).Infof("csi.CreateVolume restoring snapshot to Volume: %s", req.GetName()) - return s.restoreSnapshot(ctx, req) - } - // Get parameters spec, locator, source, err := s.specHandler.SpecFromOpts(req.GetParameters()) if err != nil { @@ -491,6 +485,16 @@ func (s *OsdCsiServer) CreateVolume( return nil, status.Error(codes.InvalidArgument, e) } + // Check ID is valid with the specified volume capabilities + snapshotType, ok := locator.VolumeLabels[osdSnapshotLabelsTypeKey] + if !ok { + snapshotType = DriverTypeLocal + } + if DriverTypeCloud == snapshotType && req.VolumeContentSource != nil && req.VolumeContentSource.GetSnapshot() != nil { + clogger.WithContext(ctx).Infof("csi.CreateVolume restoring snapshot to Volume: %s", req.GetName()) + return s.restoreSnapshot(ctx, req) + } + if spec.IsPureVolume() { err = validateCreateVolumeCapabilitiesPure(req.GetVolumeCapabilities(), spec.GetProxySpec()) if err != nil { @@ -641,9 +645,10 @@ func (s *OsdCsiServer) CreateVolume( func (s *OsdCsiServer) restoreSnapshot(ctx context.Context, req *csi.CreateVolumeRequest, ) (resp *csi.CreateVolumeResponse, err error) { + clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s", req.GetName(), req.VolumeContentSource.GetSnapshot()) snapshot := req.VolumeContentSource.GetSnapshot() if snapshot == nil { - return nil, errors.New("snapshot fetched is not accurate or does not exist") + return nil, status.Error(codes.NotFound, "snapshot fetched is not accurate or does not exist") } cloudBackupClient, err := s.getCloudBackupClient(ctx) @@ -655,7 +660,7 @@ func (s *OsdCsiServer) restoreSnapshot(ctx context.Context, csiSnapshotID := snapshot.GetSnapshotId() if len(csiSnapshotID) == 0 { - return nil, status.Error(codes.InvalidArgument, "Snapshot id must be provided") + return nil, status.Error(codes.InvalidArgument, "snapshot id must be provided") } var backupStatus *api.SdkCloudBackupStatusResponse @@ -672,9 +677,10 @@ func (s *OsdCsiServer) restoreSnapshot(ctx context.Context, } if (sdk.IsErrorNotFound(err) && !cloudBackupDriverDisabled && cloudBackupClientAvailable) || !isSnapshotIDPresentInCloud { - + clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s is a local backup", req.GetName(), csiSnapshotID) return } + clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s is a cloud backup", req.GetName(), csiSnapshotID) resp, err = s.restoreCloudSnapshot(ctx, req) return @@ -689,16 +695,18 @@ func (s *OsdCsiServer) restoreCloudSnapshot(ctx context.Context, return nil, err } // Get parameters - _, locator, _, err := s.specHandler.SpecFromOpts(req.GetParameters()) - if err != nil { - e := fmt.Sprintf("Unable to get parameters: %s\n", err.Error()) - clogger.WithContext(ctx).Errorln(e) - return nil, status.Error(codes.InvalidArgument, e) - } + _, locator, _, _ := s.specHandler.SpecFromOpts(req.GetParameters()) csiSnapshotID := req.VolumeContentSource.GetSnapshot().GetSnapshotId() - credentialID := locator.VolumeLabels[osdSnapshotCredentialIDKey] + clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s is a cloud backup with labels %+v", req.GetName(), csiSnapshotID, locator.VolumeLabels) + + credentialID, ok := locator.VolumeLabels[osdSnapshotCredentialIDKey] + if !ok { + e := fmt.Sprintf("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s credentials missing", req.GetName(), csiSnapshotID) + clogger.WithContext(ctx).Infof(e) + return nil, status.Error(codes.InvalidArgument, e) + } snapResp, err := cloudBackupClient.Restore(ctx, &api.SdkCloudBackupRestoreRequest{ BackupId: csiSnapshotID, @@ -708,7 +716,7 @@ func (s *OsdCsiServer) restoreCloudSnapshot(ctx context.Context, CredentialId: credentialID, }) if nil != err { - return nil, err + return nil, status.Error(codes.Internal, err.Error()) } resp = &csi.CreateVolumeResponse{ diff --git a/csi/controller_test.go b/csi/controller_test.go index 3c8d95649..a33fc9ea1 100644 --- a/csi/controller_test.go +++ b/csi/controller_test.go @@ -3734,3 +3734,185 @@ func TestOsdCsiServer_DeleteCloudSnapshot(t *testing.T) { }) } } + +func TestOsdCsiServer_RestoreCloudSnapshot(t *testing.T) { + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockCloudBackupClient := mock.NewMockOpenStorageCloudBackupClient(ctrl) + + ctx := context.Background() + + mockErr := errors.New("MOCK ERROR") + creationTime := timestamppb.Now() + + mockVolumeName := "mock-volume-id" + + mockRoundRobinBalancer := mockLoadBalancer.NewMockBalancer(ctrl) + mockRoundRobinBalancer.EXPECT().GetRemoteNodeConnection(gomock.Any()).DoAndReturn( + func(ctx context.Context) (*grpc.ClientConn, bool, error) { + var err error + var conn *grpc.ClientConn + if ctx.Value("remote-client-error").(bool) { + err = mockErr + conn = &grpc.ClientConn{} + } + return conn, true, err + }).AnyTimes() + + mockCloudBackupClient.EXPECT().Restore(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, req *api.SdkCloudBackupRestoreRequest, opts ...grpc.CallOption) (*api.SdkCloudBackupRestoreResponse, error) { + clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. SdkCloudBackupRestoreRequest: %+v", req) + if req.BackupId == "client-error" { + return nil, mockErr + } + + if req.BackupId == "snapshot-notfound" { + return nil, status.Errorf(codes.NotFound, "Snapshot not found") + } + + if req.BackupId == "ok" { + return &api.SdkCloudBackupRestoreResponse{ + RestoreVolumeId: req.BackupId, + TaskId: req.BackupId, + }, nil + } + + return &api.SdkCloudBackupRestoreResponse{}, nil + + }).AnyTimes() + + mockCloudBackupClient.EXPECT().Status(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, req *api.SdkCloudBackupStatusRequest, opts ...grpc.CallOption) (*api.SdkCloudBackupStatusResponse, error) { + if req.TaskId == "status-error" { + return nil, mockErr + } + + return &api.SdkCloudBackupStatusResponse{ + Statuses: map[string]*api.SdkCloudBackupStatus{ + req.TaskId: { + BackupId: req.TaskId, + Status: api.SdkCloudBackupStatusType_SdkCloudBackupStatusTypeDone, + StartTime: creationTime, + }, + }, + }, nil + + }).AnyTimes() + + tests := []struct { + name string + SnapshotName string + want *csi.CreateVolumeResponse + wantErr bool + }{ + + { + "snapshot not provided in volume source", + "nil", + nil, + true, + }, + { + "remote client connection failed", + "remote-client-error", + nil, + true, + }, + { + "snapshot id is blank", + "", + nil, + true, + }, + { + "failed to get credentials", + "cred-error", + nil, + true, + }, + { + "Cloud backup client not available", + "client-error", + nil, + true, + }, + { + "fail to get cloud snap not found", + "snapshot-notfound", + nil, + true, + }, + { + "Snapshot restored without error", + "ok", + &csi.CreateVolumeResponse{ + Volume: &csi.Volume{ + VolumeId: "ok", + ContentSource: &csi.VolumeContentSource{ + Type: &csi.VolumeContentSource_Snapshot{ + Snapshot: &csi.VolumeContentSource_SnapshotSource{ + SnapshotId: "ok", + }, + }, + }, + }, + }, + false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + snapshot := &csi.VolumeContentSource_SnapshotSource{ + SnapshotId: tt.SnapshotName, + } + if tt.SnapshotName == "nil" { + snapshot = nil + } + + req := &csi.CreateVolumeRequest{ + Name: mockVolumeName, + VolumeContentSource: &csi.VolumeContentSource{ + Type: &csi.VolumeContentSource_Snapshot{ + Snapshot: snapshot, + }, + }, + } + specLabels := []string{} + if tt.SnapshotName != "cred-error" { + specLabels = append(specLabels, osdSnapshotCredentialIDKey+"=mockcredid") + specLabels = append(specLabels, osdSnapshotLabelsTypeKey+"=cloud") + req.Parameters = map[string]string{ + api.SpecLabels: strings.Join(specLabels, ","), + } + } else { + req.Parameters = nil + } + + s := &OsdCsiServer{ + specHandler: spec.NewSpecHandler(), + mu: sync.Mutex{}, + cloudBackupClient: func(cc grpc.ClientConnInterface) api.OpenStorageCloudBackupClient { + return mockCloudBackupClient + }, + roundRobinBalancer: mockRoundRobinBalancer, + } + + doClientErr := tt.SnapshotName == "remote-client-error" + + ctx = context.WithValue(ctx, "remote-client-error", doClientErr) + + got, err := s.CreateVolume(ctx, req) + if (err != nil) != tt.wantErr { + t.Errorf("OsdCsiServer.CreateSnapshot() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("OsdCsiServer.CreateSnapshot() = %v, want %v", got, tt.want) + } + }) + } +}