Skip to content

Commit

Permalink
test: add unit test for restore cloud snapshot
Browse files Browse the repository at this point in the history
Signed-off-by: Shivanjan Chakravorty <[email protected]>
  • Loading branch information
Glitchfix committed Jun 7, 2024
1 parent 7cab17a commit d32c997
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 17 deletions.
42 changes: 25 additions & 17 deletions csi/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package csi

import (
"encoding/json"
"errors"
"fmt"
"math"
"reflect"
Expand Down Expand Up @@ -478,11 +477,6 @@ func (s *OsdCsiServer) CreateVolume(
return nil, status.Error(codes.InvalidArgument, "Name must be provided")
}

if req.VolumeContentSource != nil && req.VolumeContentSource.GetSnapshot() != nil {
clogger.WithContext(ctx).Infof("csi.CreateVolume restoring snapshot to Volume: %s", req.GetName())
return s.restoreSnapshot(ctx, req)
}

// Get parameters
spec, locator, source, err := s.specHandler.SpecFromOpts(req.GetParameters())
if err != nil {
Expand All @@ -491,6 +485,16 @@ func (s *OsdCsiServer) CreateVolume(
return nil, status.Error(codes.InvalidArgument, e)
}

// Check ID is valid with the specified volume capabilities
snapshotType, ok := locator.VolumeLabels[osdSnapshotLabelsTypeKey]
if !ok {
snapshotType = DriverTypeLocal
}
if DriverTypeCloud == snapshotType && req.VolumeContentSource != nil && req.VolumeContentSource.GetSnapshot() != nil {
clogger.WithContext(ctx).Infof("csi.CreateVolume restoring snapshot to Volume: %s", req.GetName())
return s.restoreSnapshot(ctx, req)
}

if spec.IsPureVolume() {
err = validateCreateVolumeCapabilitiesPure(req.GetVolumeCapabilities(), spec.GetProxySpec())
if err != nil {
Expand Down Expand Up @@ -641,9 +645,10 @@ func (s *OsdCsiServer) CreateVolume(
func (s *OsdCsiServer) restoreSnapshot(ctx context.Context,
req *csi.CreateVolumeRequest,
) (resp *csi.CreateVolumeResponse, err error) {
clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s", req.GetName(), req.VolumeContentSource.GetSnapshot())
snapshot := req.VolumeContentSource.GetSnapshot()
if snapshot == nil {
return nil, errors.New("snapshot fetched is not accurate or does not exist")
return nil, status.Error(codes.NotFound, "snapshot fetched is not accurate or does not exist")
}

cloudBackupClient, err := s.getCloudBackupClient(ctx)
Expand All @@ -655,7 +660,7 @@ func (s *OsdCsiServer) restoreSnapshot(ctx context.Context,

csiSnapshotID := snapshot.GetSnapshotId()
if len(csiSnapshotID) == 0 {
return nil, status.Error(codes.InvalidArgument, "Snapshot id must be provided")
return nil, status.Error(codes.InvalidArgument, "snapshot id must be provided")
}

var backupStatus *api.SdkCloudBackupStatusResponse
Expand All @@ -672,9 +677,10 @@ func (s *OsdCsiServer) restoreSnapshot(ctx context.Context,
}

if (sdk.IsErrorNotFound(err) && !cloudBackupDriverDisabled && cloudBackupClientAvailable) || !isSnapshotIDPresentInCloud {

clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s is a local backup", req.GetName(), csiSnapshotID)
return
}
clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s is a cloud backup", req.GetName(), csiSnapshotID)

resp, err = s.restoreCloudSnapshot(ctx, req)
return
Expand All @@ -689,16 +695,18 @@ func (s *OsdCsiServer) restoreCloudSnapshot(ctx context.Context,
return nil, err
}
// Get parameters
_, locator, _, err := s.specHandler.SpecFromOpts(req.GetParameters())
if err != nil {
e := fmt.Sprintf("Unable to get parameters: %s\n", err.Error())
clogger.WithContext(ctx).Errorln(e)
return nil, status.Error(codes.InvalidArgument, e)
}
_, locator, _, _ := s.specHandler.SpecFromOpts(req.GetParameters())

csiSnapshotID := req.VolumeContentSource.GetSnapshot().GetSnapshotId()

credentialID := locator.VolumeLabels[osdSnapshotCredentialIDKey]
clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s is a cloud backup with labels %+v", req.GetName(), csiSnapshotID, locator.VolumeLabels)

credentialID, ok := locator.VolumeLabels[osdSnapshotCredentialIDKey]
if !ok {
e := fmt.Sprintf("csi.CreateVolume is restoring snapshot. Volume: %s Snapshot: %s credentials missing", req.GetName(), csiSnapshotID)
clogger.WithContext(ctx).Infof(e)
return nil, status.Error(codes.InvalidArgument, e)
}

snapResp, err := cloudBackupClient.Restore(ctx, &api.SdkCloudBackupRestoreRequest{
BackupId: csiSnapshotID,
Expand All @@ -708,7 +716,7 @@ func (s *OsdCsiServer) restoreCloudSnapshot(ctx context.Context,
CredentialId: credentialID,
})
if nil != err {
return nil, err
return nil, status.Error(codes.Internal, err.Error())
}

resp = &csi.CreateVolumeResponse{
Expand Down
182 changes: 182 additions & 0 deletions csi/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3734,3 +3734,185 @@ func TestOsdCsiServer_DeleteCloudSnapshot(t *testing.T) {
})
}
}

func TestOsdCsiServer_RestoreCloudSnapshot(t *testing.T) {

ctrl := gomock.NewController(t)
defer ctrl.Finish()

mockCloudBackupClient := mock.NewMockOpenStorageCloudBackupClient(ctrl)

ctx := context.Background()

mockErr := errors.New("MOCK ERROR")
creationTime := timestamppb.Now()

mockVolumeName := "mock-volume-id"

mockRoundRobinBalancer := mockLoadBalancer.NewMockBalancer(ctrl)
mockRoundRobinBalancer.EXPECT().GetRemoteNodeConnection(gomock.Any()).DoAndReturn(
func(ctx context.Context) (*grpc.ClientConn, bool, error) {
var err error
var conn *grpc.ClientConn
if ctx.Value("remote-client-error").(bool) {
err = mockErr
conn = &grpc.ClientConn{}
}
return conn, true, err
}).AnyTimes()

mockCloudBackupClient.EXPECT().Restore(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, req *api.SdkCloudBackupRestoreRequest, opts ...grpc.CallOption) (*api.SdkCloudBackupRestoreResponse, error) {
clogger.WithContext(ctx).Infof("csi.CreateVolume is restoring snapshot. SdkCloudBackupRestoreRequest: %+v", req)
if req.BackupId == "client-error" {
return nil, mockErr
}

if req.BackupId == "snapshot-notfound" {
return nil, status.Errorf(codes.NotFound, "Snapshot not found")
}

if req.BackupId == "ok" {
return &api.SdkCloudBackupRestoreResponse{
RestoreVolumeId: req.BackupId,
TaskId: req.BackupId,
}, nil
}

return &api.SdkCloudBackupRestoreResponse{}, nil

}).AnyTimes()

mockCloudBackupClient.EXPECT().Status(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, req *api.SdkCloudBackupStatusRequest, opts ...grpc.CallOption) (*api.SdkCloudBackupStatusResponse, error) {
if req.TaskId == "status-error" {
return nil, mockErr
}

return &api.SdkCloudBackupStatusResponse{
Statuses: map[string]*api.SdkCloudBackupStatus{
req.TaskId: {
BackupId: req.TaskId,
Status: api.SdkCloudBackupStatusType_SdkCloudBackupStatusTypeDone,
StartTime: creationTime,
},
},
}, nil

}).AnyTimes()

tests := []struct {
name string
SnapshotName string
want *csi.CreateVolumeResponse
wantErr bool
}{

{
"snapshot not provided in volume source",
"nil",
nil,
true,
},
{
"remote client connection failed",
"remote-client-error",
nil,
true,
},
{
"snapshot id is blank",
"",
nil,
true,
},
{
"failed to get credentials",
"cred-error",
nil,
true,
},
{
"Cloud backup client not available",
"client-error",
nil,
true,
},
{
"fail to get cloud snap not found",
"snapshot-notfound",
nil,
true,
},
{
"Snapshot restored without error",
"ok",
&csi.CreateVolumeResponse{
Volume: &csi.Volume{
VolumeId: "ok",
ContentSource: &csi.VolumeContentSource{
Type: &csi.VolumeContentSource_Snapshot{
Snapshot: &csi.VolumeContentSource_SnapshotSource{
SnapshotId: "ok",
},
},
},
},
},
false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

snapshot := &csi.VolumeContentSource_SnapshotSource{
SnapshotId: tt.SnapshotName,
}
if tt.SnapshotName == "nil" {
snapshot = nil
}

req := &csi.CreateVolumeRequest{
Name: mockVolumeName,
VolumeContentSource: &csi.VolumeContentSource{
Type: &csi.VolumeContentSource_Snapshot{
Snapshot: snapshot,
},
},
}
specLabels := []string{}
if tt.SnapshotName != "cred-error" {
specLabels = append(specLabels, osdSnapshotCredentialIDKey+"=mockcredid")
specLabels = append(specLabels, osdSnapshotLabelsTypeKey+"=cloud")
req.Parameters = map[string]string{
api.SpecLabels: strings.Join(specLabels, ","),
}
} else {
req.Parameters = nil
}

s := &OsdCsiServer{
specHandler: spec.NewSpecHandler(),
mu: sync.Mutex{},
cloudBackupClient: func(cc grpc.ClientConnInterface) api.OpenStorageCloudBackupClient {
return mockCloudBackupClient
},
roundRobinBalancer: mockRoundRobinBalancer,
}

doClientErr := tt.SnapshotName == "remote-client-error"

ctx = context.WithValue(ctx, "remote-client-error", doClientErr)

got, err := s.CreateVolume(ctx, req)
if (err != nil) != tt.wantErr {
t.Errorf("OsdCsiServer.CreateSnapshot() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("OsdCsiServer.CreateSnapshot() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit d32c997

Please sign in to comment.