diff --git a/internal/csi-addons/rbd/replication.go b/internal/csi-addons/rbd/replication.go index a949a383c86..f263d3e2b1e 100644 --- a/internal/csi-addons/rbd/replication.go +++ b/internal/csi-addons/rbd/replication.go @@ -26,6 +26,7 @@ import ( "strings" "time" + csicommon "github.com/ceph/ceph-csi/internal/csi-common" corerbd "github.com/ceph/ceph-csi/internal/rbd" "github.com/ceph/ceph-csi/internal/util" "github.com/ceph/ceph-csi/internal/util/log" @@ -247,7 +248,7 @@ func validateSchedulingInterval(interval string) error { func (rs *ReplicationServer) EnableVolumeReplication(ctx context.Context, req *replication.EnableVolumeReplicationRequest, ) (*replication.EnableVolumeReplicationResponse, error) { - volumeID := req.GetVolumeId() + volumeID := csicommon.GetIDFromReplication(req) if volumeID == "" { return nil, status.Error(codes.InvalidArgument, "empty volume ID in request") } @@ -329,7 +330,7 @@ func (rs *ReplicationServer) EnableVolumeReplication(ctx context.Context, func (rs *ReplicationServer) DisableVolumeReplication(ctx context.Context, req *replication.DisableVolumeReplicationRequest, ) (*replication.DisableVolumeReplicationResponse, error) { - volumeID := req.GetVolumeId() + volumeID := csicommon.GetIDFromReplication(req) if volumeID == "" { return nil, status.Error(codes.InvalidArgument, "empty volume ID in request") } @@ -404,7 +405,7 @@ func (rs *ReplicationServer) DisableVolumeReplication(ctx context.Context, func (rs *ReplicationServer) PromoteVolume(ctx context.Context, req *replication.PromoteVolumeRequest, ) (*replication.PromoteVolumeResponse, error) { - volumeID := req.GetVolumeId() + volumeID := csicommon.GetIDFromReplication(req) if volumeID == "" { return nil, status.Error(codes.InvalidArgument, "empty volume ID in request") } @@ -504,7 +505,7 @@ func (rs *ReplicationServer) PromoteVolume(ctx context.Context, func (rs *ReplicationServer) DemoteVolume(ctx context.Context, req *replication.DemoteVolumeRequest, ) (*replication.DemoteVolumeResponse, error) { - volumeID := req.GetVolumeId() + volumeID := csicommon.GetIDFromReplication(req) if volumeID == "" { return nil, status.Error(codes.InvalidArgument, "empty volume ID in request") } @@ -622,7 +623,7 @@ func checkRemoteSiteStatus(ctx context.Context, mirrorStatus *librbd.GlobalMirro func (rs *ReplicationServer) ResyncVolume(ctx context.Context, req *replication.ResyncVolumeRequest, ) (*replication.ResyncVolumeResponse, error) { - volumeID := req.GetVolumeId() + volumeID := csicommon.GetIDFromReplication(req) if volumeID == "" { return nil, status.Error(codes.InvalidArgument, "empty volume ID in request") } @@ -836,7 +837,7 @@ func getGRPCError(err error) error { func (rs *ReplicationServer) GetVolumeReplicationInfo(ctx context.Context, req *replication.GetVolumeReplicationInfoRequest, ) (*replication.GetVolumeReplicationInfoResponse, error) { - volumeID := req.GetVolumeId() + volumeID := csicommon.GetIDFromReplication(req) if volumeID == "" { return nil, status.Error(codes.InvalidArgument, "empty volume ID in request") } diff --git a/internal/csi-common/utils.go b/internal/csi-common/utils.go index f844abca74e..3f2dadd126a 100644 --- a/internal/csi-common/utils.go +++ b/internal/csi-common/utils.go @@ -116,6 +116,43 @@ func NewMiddlewareServerOption() grpc.ServerOption { return grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(middleWare...)) } +// GetIDFromReplication returns the volumeID for Replication. +func GetIDFromReplication(req interface{}) string { + getID := func(r interface { + GetVolumeId() string + GetReplicationSource() *replication.ReplicationSource + }, + ) string { + reqID := "" + src := r.GetReplicationSource() + if src != nil && src.GetVolume() != nil { + reqID = src.GetVolume().GetVolumeId() + } + if reqID == "" { + reqID = r.GetVolumeId() //nolint:nolintlint,staticcheck // req.VolumeId is deprecated + } + + return reqID + } + + switch r := req.(type) { + case *replication.EnableVolumeReplicationRequest: + return getID(r) + case *replication.DisableVolumeReplicationRequest: + return getID(r) + case *replication.PromoteVolumeRequest: + return getID(r) + case *replication.DemoteVolumeRequest: + return getID(r) + case *replication.ResyncVolumeRequest: + return getID(r) + case *replication.GetVolumeReplicationInfoRequest: + return getID(r) + default: + return "" + } +} + func getReqID(req interface{}) string { // if req is nil empty string will be returned reqID := "" @@ -156,17 +193,17 @@ func getReqID(req interface{}) string { // Replication case *replication.EnableVolumeReplicationRequest: - reqID = r.GetVolumeId() + reqID = GetIDFromReplication(r) case *replication.DisableVolumeReplicationRequest: - reqID = r.GetVolumeId() + reqID = GetIDFromReplication(r) case *replication.PromoteVolumeRequest: - reqID = r.GetVolumeId() + reqID = GetIDFromReplication(r) case *replication.DemoteVolumeRequest: - reqID = r.GetVolumeId() + reqID = GetIDFromReplication(r) case *replication.ResyncVolumeRequest: - reqID = r.GetVolumeId() + reqID = GetIDFromReplication(r) case *replication.GetVolumeReplicationInfoRequest: - reqID = r.GetVolumeId() + reqID = GetIDFromReplication(r) } return reqID diff --git a/internal/csi-common/utils_test.go b/internal/csi-common/utils_test.go index e6991fc3e31..a3c230d0d1a 100644 --- a/internal/csi-common/utils_test.go +++ b/internal/csi-common/utils_test.go @@ -94,6 +94,62 @@ func TestGetReqID(t *testing.T) { &replication.GetVolumeReplicationInfoRequest{ VolumeId: fakeID, }, + + // volumeId is set in ReplicationSource + &replication.EnableVolumeReplicationRequest{ + ReplicationSource: &replication.ReplicationSource{ + Type: &replication.ReplicationSource_Volume{ + Volume: &replication.ReplicationSource_VolumeSource{ + VolumeId: fakeID, + }, + }, + }, + }, + &replication.DisableVolumeReplicationRequest{ + ReplicationSource: &replication.ReplicationSource{ + Type: &replication.ReplicationSource_Volume{ + Volume: &replication.ReplicationSource_VolumeSource{ + VolumeId: fakeID, + }, + }, + }, + }, + &replication.PromoteVolumeRequest{ + ReplicationSource: &replication.ReplicationSource{ + Type: &replication.ReplicationSource_Volume{ + Volume: &replication.ReplicationSource_VolumeSource{ + VolumeId: fakeID, + }, + }, + }, + }, + &replication.DemoteVolumeRequest{ + ReplicationSource: &replication.ReplicationSource{ + Type: &replication.ReplicationSource_Volume{ + Volume: &replication.ReplicationSource_VolumeSource{ + VolumeId: fakeID, + }, + }, + }, + }, + &replication.ResyncVolumeRequest{ + ReplicationSource: &replication.ReplicationSource{ + Type: &replication.ReplicationSource_Volume{ + Volume: &replication.ReplicationSource_VolumeSource{ + VolumeId: fakeID, + }, + }, + }, + }, + &replication.GetVolumeReplicationInfoRequest{ + ReplicationSource: &replication.ReplicationSource{ + Type: &replication.ReplicationSource_Volume{ + Volume: &replication.ReplicationSource_VolumeSource{ + VolumeId: fakeID, + }, + }, + }, + }, } for _, r := range req { if got := getReqID(r); got != fakeID {