From 1270096999c7ab8737a4e8a19aa3d7a68bea4bab Mon Sep 17 00:00:00 2001 From: Rodney Osodo <28790446+rodneyosodo@users.noreply.github.com> Date: Wed, 15 Nov 2023 15:21:53 +0300 Subject: [PATCH] Fix failing tests in SDK Fix failing tests in `sdk` package. --- internal/api/common.go | 4 +- internal/groups/api/requests.go | 3 + internal/groups/service.go | 4 +- pkg/sdk/go/certs_test.go | 71 ++++++-- pkg/sdk/go/channels_test.go | 283 +++++++++++++---------------- pkg/sdk/go/consumers_test.go | 35 +++- pkg/sdk/go/groups_test.go | 201 ++++++++++++-------- pkg/sdk/go/health_test.go | 2 +- pkg/sdk/go/setup_test.go | 2 +- pkg/sdk/go/things_test.go | 313 +++++++++++++++++--------------- pkg/sdk/go/tokens_test.go | 54 +++--- pkg/sdk/go/users_test.go | 254 +++++++++++++++++++------- things/service.go | 2 +- 13 files changed, 733 insertions(+), 495 deletions(-) diff --git a/internal/api/common.go b/internal/api/common.go index 3cf71522a..a84ef6d8a 100644 --- a/internal/api/common.go +++ b/internal/api/common.go @@ -109,9 +109,11 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) { errors.Contains(err, apiutil.ErrEmptyList), errors.Contains(err, apiutil.ErrMissingMemberType), errors.Contains(err, apiutil.ErrMissingMemberKind), + errors.Contains(err, apiutil.ErrLimitSize), errors.Contains(err, apiutil.ErrNameSize): w.WriteHeader(http.StatusBadRequest) - case errors.Contains(err, errors.ErrAuthentication): + case errors.Contains(err, errors.ErrAuthentication), + errors.Contains(err, apiutil.ErrBearerToken): w.WriteHeader(http.StatusUnauthorized) case errors.Contains(err, errors.ErrNotFound): w.WriteHeader(http.StatusNotFound) diff --git a/internal/groups/api/requests.go b/internal/groups/api/requests.go index b4c0d6f89..0907dc067 100644 --- a/internal/groups/api/requests.go +++ b/internal/groups/api/requests.go @@ -73,6 +73,9 @@ func (req listGroupsReq) validate() error { if req.Level < mggroups.MinLevel || req.Level > mggroups.MaxLevel { return apiutil.ErrInvalidLevel } + if req.Limit > api.MaxLimitSize || req.Limit < 1 { + return apiutil.ErrLimitSize + } return nil } diff --git a/internal/groups/service.go b/internal/groups/service.go index 911d7755d..5dc64fa6a 100644 --- a/internal/groups/service.go +++ b/internal/groups/service.go @@ -588,7 +588,7 @@ func (svc service) authorize(ctx context.Context, subjectType, subject, permissi } res, err := svc.auth.Authorize(ctx, req) if err != nil { - return "", errors.Wrap(errors.ErrAuthorization, err) + return "", err } if !res.GetAuthorized() { return "", errors.ErrAuthorization @@ -607,7 +607,7 @@ func (svc service) authorizeKind(ctx context.Context, subjectType, subjectKind, } res, err := svc.auth.Authorize(ctx, req) if err != nil { - return "", errors.Wrap(errors.ErrAuthorization, err) + return "", err } if !res.GetAuthorized() { return "", errors.ErrAuthorization diff --git a/pkg/sdk/go/certs_test.go b/pkg/sdk/go/certs_test.go index a8c7c1d2c..4d022dec7 100644 --- a/pkg/sdk/go/certs_test.go +++ b/pkg/sdk/go/certs_test.go @@ -10,14 +10,19 @@ import ( "testing" "time" + "github.com/absmach/magistrala" + authmocks "github.com/absmach/magistrala/auth/mocks" "github.com/absmach/magistrala/certs" httpapi "github.com/absmach/magistrala/certs/api" "github.com/absmach/magistrala/certs/mocks" "github.com/absmach/magistrala/internal/apiutil" "github.com/absmach/magistrala/logger" + "github.com/absmach/magistrala/pkg/clients" "github.com/absmach/magistrala/pkg/errors" sdk "github.com/absmach/magistrala/pkg/sdk/go" + thmocks "github.com/absmach/magistrala/things/mocks" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -31,8 +36,8 @@ var ( cfgSignHoursValid = "24h" ) -func newCertService() (certs.Service, error) { - server, _, _, auth := newThingsServer() +func newCertService() (certs.Service, *authmocks.Service, *thmocks.Repository, error) { + server, trepo, _, auth := newThingsServer() config := sdk.Config{ ThingsURL: server.URL, } @@ -42,17 +47,17 @@ func newCertService() (certs.Service, error) { tlsCert, caCert, err := certs.LoadCertificates(caPath, caKeyPath) if err != nil { - return nil, err + return nil, auth, trepo, err } authTimeout, err := time.ParseDuration(cfgAuthTimeout) if err != nil { - return nil, err + return nil, auth, trepo, err } pki := mocks.NewPkiAgent(tlsCert, caCert, cfgSignHoursValid, authTimeout) - return certs.New(auth, repo, mgsdk, pki), nil + return certs.New(auth, repo, mgsdk, pki), auth, trepo, nil } func newCertServer(svc certs.Service) *httptest.Server { @@ -62,7 +67,7 @@ func newCertServer(svc certs.Service) *httptest.Server { } func TestIssueCert(t *testing.T) { - svc, err := newCertService() + svc, auth, trepo, err := newCertService() require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err)) ts := newCertServer(svc) defer ts.Close() @@ -86,21 +91,21 @@ func TestIssueCert(t *testing.T) { desc: "create new cert with thing id and duration", thingID: thingID, duration: "10h", - token: adminToken, + token: validToken, err: nil, }, { desc: "create new cert with empty thing id and duration", thingID: "", duration: "10h", - token: adminToken, + token: validToken, err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest), }, { desc: "create new cert with invalid thing id and duration", thingID: "ah", duration: "10h", - token: adminToken, + token: validToken, err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, certs.ErrFailedCertCreation), http.StatusInternalServerError), }, { @@ -128,29 +133,35 @@ func TestIssueCert(t *testing.T) { desc: "create new cert with invalid token", thingID: thingID, duration: "10h", - token: wrongValue, + token: authmocks.InvalidValue, err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, errors.ErrAuthentication), http.StatusUnauthorized), }, { desc: "create new empty cert", thingID: "", duration: "", - token: adminToken, + token: validToken, err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingID), http.StatusBadRequest), }, } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: tc.thingID}, tc.err) cert, err := mgsdk.IssueCert(tc.thingID, tc.duration, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) if err == nil { assert.NotEmpty(t, cert, fmt.Sprintf("%s: got empty cert", tc.desc)) } + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() } } func TestViewCert(t *testing.T) { - svc, err := newCertService() + svc, auth, trepo, err := newCertService() require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err)) ts := newCertServer(svc) defer ts.Close() @@ -163,8 +174,14 @@ func TestViewCert(t *testing.T) { mgsdk := sdk.NewSDK(sdkConf) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID}, nil) cert, err := mgsdk.IssueCert(thingID, "10h", token) require.Nil(t, err, fmt.Sprintf("unexpected error during creating cert: %s", err)) + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() cases := []struct { desc string @@ -197,16 +214,18 @@ func TestViewCert(t *testing.T) { } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) cert, err := mgsdk.ViewCert(tc.certID, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) if err == nil { assert.NotEmpty(t, cert, fmt.Sprintf("%s: got empty cert", tc.desc)) } + repoCall.Unset() } } func TestViewCertByThing(t *testing.T) { - svc, err := newCertService() + svc, auth, trepo, err := newCertService() require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err)) ts := newCertServer(svc) defer ts.Close() @@ -219,8 +238,14 @@ func TestViewCertByThing(t *testing.T) { mgsdk := sdk.NewSDK(sdkConf) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID}, nil) _, err = mgsdk.IssueCert(thingID, "10h", token) require.Nil(t, err, fmt.Sprintf("unexpected error during creating cert: %s", err)) + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() cases := []struct { desc string @@ -253,16 +278,18 @@ func TestViewCertByThing(t *testing.T) { } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) cert, err := mgsdk.ViewCertByThing(tc.thingID, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) if err == nil { assert.NotEmpty(t, cert, fmt.Sprintf("%s: got empty cert", tc.desc)) } + repoCall.Unset() } } func TestRevokeCert(t *testing.T) { - svc, err := newCertService() + svc, auth, trepo, err := newCertService() require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err)) ts := newCertServer(svc) defer ts.Close() @@ -275,8 +302,14 @@ func TestRevokeCert(t *testing.T) { mgsdk := sdk.NewSDK(sdkConf) - _, err = mgsdk.IssueCert(thingID, "10h", adminToken) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID}, nil) + _, err = mgsdk.IssueCert(thingID, "10h", validToken) require.Nil(t, err, fmt.Sprintf("unexpected error during creating cert: %s", err)) + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() cases := []struct { desc string @@ -287,7 +320,7 @@ func TestRevokeCert(t *testing.T) { { desc: "revoke cert with invalid token", thingID: thingID, - token: wrongValue, + token: authmocks.InvalidValue, err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), }, { @@ -323,10 +356,16 @@ func TestRevokeCert(t *testing.T) { } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: tc.thingID}, nil) response, err := mgsdk.RevokeCert(tc.thingID, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) if err == nil { assert.NotEmpty(t, response, fmt.Sprintf("%s: got empty revocation time", tc.desc)) } + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() } } diff --git a/pkg/sdk/go/channels_test.go b/pkg/sdk/go/channels_test.go index 292e6ae36..6fc5bacac 100644 --- a/pkg/sdk/go/channels_test.go +++ b/pkg/sdk/go/channels_test.go @@ -7,13 +7,15 @@ import ( "fmt" "net/http" "net/http/httptest" + "reflect" "testing" "time" + "github.com/absmach/magistrala" authmocks "github.com/absmach/magistrala/auth/mocks" "github.com/absmach/magistrala/internal/apiutil" "github.com/absmach/magistrala/internal/groups" - gmocks "github.com/absmach/magistrala/internal/groups/mocks" + "github.com/absmach/magistrala/internal/groups/mocks" "github.com/absmach/magistrala/internal/testsutil" mglog "github.com/absmach/magistrala/logger" mgclients "github.com/absmach/magistrala/pkg/clients" @@ -22,30 +24,30 @@ import ( sdk "github.com/absmach/magistrala/pkg/sdk/go" "github.com/absmach/magistrala/things" api "github.com/absmach/magistrala/things/api/http" - "github.com/absmach/magistrala/things/mocks" + thmocks "github.com/absmach/magistrala/things/mocks" "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) -func newChannelsServer() (*httptest.Server, *mocks.Repository, *gmocks.Repository, *authmocks.Service) { - cRepo := new(mocks.Repository) - gRepo := new(gmocks.Repository) - thingCache := mocks.NewCache() +func newChannelsServer() (*httptest.Server, *mocks.Repository, *authmocks.Service) { + cRepo := new(thmocks.Repository) + grepo := new(mocks.Repository) + thingCache := thmocks.NewCache() auth := new(authmocks.Service) - csvc := things.NewService(auth, cRepo, gRepo, thingCache, idProvider) - gsvc := groups.NewService(gRepo, idProvider, auth) + csvc := things.NewService(auth, cRepo, grepo, thingCache, idProvider) + gsvc := groups.NewService(grepo, idProvider, auth) logger := mglog.NewMock() mux := chi.NewRouter() api.MakeHandler(csvc, gsvc, mux, logger, "") - return httptest.NewServer(mux), cRepo, gRepo, auth + return httptest.NewServer(mux), grepo, auth } func TestCreateChannel(t *testing.T) { - ts, _, gRepo, _ := newChannelsServer() + ts, grepo, auth := newChannelsServer() defer ts.Close() channel := sdk.Channel{ @@ -73,6 +75,7 @@ func TestCreateChannel(t *testing.T) { { desc: "create channel with existing name", channel: channel, + token: token, err: nil, }, { @@ -93,32 +96,36 @@ func TestCreateChannel(t *testing.T) { ParentID: testsutil.GenerateUUID(t), Status: mgclients.EnabledStatus.String(), }, - err: nil, + token: token, + err: nil, }, { desc: "create channel with invalid parent", channel: sdk.Channel{ Name: gName, - ParentID: gmocks.WrongID, + ParentID: mocks.WrongID, Status: mgclients.EnabledStatus.String(), }, - err: errors.NewSDKErrorWithStatus(errors.ErrCreateEntity, http.StatusInternalServerError), + token: token, + err: errors.NewSDKErrorWithStatus(errors.ErrCreateEntity, http.StatusInternalServerError), }, { desc: "create channel with invalid owner", channel: sdk.Channel{ Name: gName, - OwnerID: gmocks.WrongID, + OwnerID: mocks.WrongID, Status: mgclients.EnabledStatus.String(), }, - err: errors.NewSDKErrorWithStatus(sdk.ErrFailedCreation, http.StatusInternalServerError), + token: token, + err: errors.NewSDKErrorWithStatus(sdk.ErrFailedCreation, http.StatusInternalServerError), }, { desc: "create channel with missing name", channel: sdk.Channel{ Status: mgclients.EnabledStatus.String(), }, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest), + token: token, + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest), }, { desc: "create a channel with every field defined", @@ -138,90 +145,26 @@ func TestCreateChannel(t *testing.T) { }, } for _, tc := range cases { - repoCall := gRepo.On("Save", mock.Anything, mock.Anything).Return(convertChannel(sdk.Channel{}), tc.err) - rChannel, err := mgsdk.CreateChannel(tc.channel, adminToken) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) - if err == nil { - assert.NotEmpty(t, rChannel, fmt.Sprintf("%s: expected not nil on client ID", tc.desc)) - ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) - } - repoCall.Unset() - } -} - -func TestCreateChannels(t *testing.T) { - ts, _, gRepo, _ := newClientServer() - defer ts.Close() - - channels := []sdk.Channel{ - { - Name: "channelName", - Metadata: validMetadata, - Status: mgclients.EnabledStatus.String(), - }, - { - Name: "channelName2", - Metadata: validMetadata, - Status: mgclients.EnabledStatus.String(), - }, - } - - conf := sdk.Config{ - ThingsURL: ts.URL, - } - mgsdk := sdk.NewSDK(conf) - cases := []struct { - desc string - channels []sdk.Channel - response []sdk.Channel - token string - err errors.SDKError - }{ - { - desc: "create channels successfully", - channels: channels, - response: channels, - token: token, - err: nil, - }, - { - desc: "register empty channels", - channels: []sdk.Channel{}, - response: []sdk.Channel{}, - token: token, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrEmptyList), http.StatusBadRequest), - }, - { - desc: "register channels that can't be marshalled", - channels: []sdk.Channel{ - { - Name: "test", - Metadata: map[string]interface{}{ - "test": make(chan int), - }, - }, - }, - response: []sdk.Channel{}, - token: token, - err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")), - }, - } - for _, tc := range cases { - repoCall := gRepo.On("Save", mock.Anything, mock.Anything).Return(convertChannels(tc.response), tc.err) - rChannel, err := mgsdk.CreateChannels(tc.channels, adminToken) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("AddPolicy", mock.Anything, mock.Anything).Return(&magistrala.AddPolicyRes{Authorized: true}, nil) + repoCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall3 := grepo.On("Save", mock.Anything, mock.Anything).Return(convertChannel(sdk.Channel{}), tc.err) + rChannel, err := mgsdk.CreateChannel(tc.channel, validToken) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) if err == nil { assert.NotEmpty(t, rChannel, fmt.Sprintf("%s: expected not nil on client ID", tc.desc)) - ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) + ok := repoCall3.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) } repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() + repoCall3.Unset() } } func TestListChannels(t *testing.T) { - ts, _, gRepo, _ := newClientServer() + ts, grepo, auth := newChannelsServer() defer ts.Close() var chs []sdk.Channel @@ -268,7 +211,7 @@ func TestListChannels(t *testing.T) { token: invalidToken, offset: offset, limit: limit, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), response: nil, }, { @@ -276,7 +219,7 @@ func TestListChannels(t *testing.T) { token: "", offset: offset, limit: limit, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized), response: nil, }, { @@ -284,7 +227,7 @@ func TestListChannels(t *testing.T) { token: token, offset: offset, limit: 0, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: nil, response: nil, }, { @@ -292,7 +235,7 @@ func TestListChannels(t *testing.T) { token: token, offset: offset, limit: 110, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest), response: []sdk.Channel(nil), }, { @@ -325,21 +268,29 @@ func TestListChannels(t *testing.T) { } for _, tc := range cases { - repoCall1 := gRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(mggroups.Page{Groups: convertChannels(tc.response)}, tc.err) - pm := sdk.PageMetadata{} - page, err := mgsdk.Channels(pm, adminToken) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{Policies: toIDs(tc.response)}, nil) + repoCall2 := grepo.On("RetrieveByIDs", mock.Anything, mock.Anything).Return(mggroups.Page{Groups: convertChannels(tc.response)}, tc.err) + pm := sdk.PageMetadata{ + Offset: tc.offset, + Limit: tc.limit, + Level: uint64(tc.level), + } + page, err := mgsdk.Channels(pm, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, len(tc.response), len(page.Channels), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page)) if tc.err == nil { - ok := repoCall1.Parent.AssertCalled(t, "RetrieveAll", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc)) + ok := repoCall2.Parent.AssertCalled(t, "RetrieveByIDs", mock.Anything, mock.Anything) + assert.True(t, ok, fmt.Sprintf("RetrieveByIDs was not called on %s", tc.desc)) } + repoCall.Unset() repoCall1.Unset() + repoCall2.Unset() } } func TestViewChannel(t *testing.T) { - ts, _, gRepo, _ := newClientServer() + ts, grepo, auth := newChannelsServer() defer ts.Close() channel := sdk.Channel{ @@ -365,7 +316,7 @@ func TestViewChannel(t *testing.T) { }{ { desc: "view channel", - token: adminToken, + token: validToken, channelID: channel.ID, response: channel, err: nil, @@ -375,19 +326,20 @@ func TestViewChannel(t *testing.T) { token: "wrongtoken", channelID: channel.ID, response: sdk.Channel{Children: []*sdk.Channel{}}, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthorization, errors.ErrAuthentication), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "view channel for wrong id", - token: adminToken, - channelID: gmocks.WrongID, + token: validToken, + channelID: mocks.WrongID, response: sdk.Channel{Children: []*sdk.Channel{}}, err: errors.NewSDKErrorWithStatus(errors.ErrNotFound, http.StatusNotFound), }, } for _, tc := range cases { - repoCall1 := gRepo.On("RetrieveByID", mock.Anything, tc.channelID).Return(convertChannel(tc.response), tc.err) + repoCall := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall1 := grepo.On("RetrieveByID", mock.Anything, tc.channelID).Return(convertChannel(tc.response), tc.err) grp, err := mgsdk.Channel(tc.channelID, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) if len(tc.response.Children) == 0 { @@ -401,12 +353,13 @@ func TestViewChannel(t *testing.T) { ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, tc.channelID) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) } + repoCall.Unset() repoCall1.Unset() } } func TestUpdateChannel(t *testing.T) { - ts, _, gRepo, _ := newClientServer() + ts, grepo, auth := newChannelsServer() defer ts.Close() channel := sdk.Channel{ @@ -440,7 +393,7 @@ func TestUpdateChannel(t *testing.T) { ID: channel.ID, Name: "NewName", }, - token: adminToken, + token: validToken, err: nil, }, { @@ -453,7 +406,7 @@ func TestUpdateChannel(t *testing.T) { ID: channel.ID, Description: "NewDescription", }, - token: adminToken, + token: validToken, err: nil, }, { @@ -470,39 +423,39 @@ func TestUpdateChannel(t *testing.T) { "field": "value2", }, }, - token: adminToken, + token: validToken, err: nil, }, { desc: "update channel name with invalid channel id", channel: sdk.Channel{ - ID: gmocks.WrongID, + ID: mocks.WrongID, Name: "NewName", }, response: sdk.Channel{}, - token: adminToken, + token: validToken, err: errors.NewSDKErrorWithStatus(errors.ErrNotFound, http.StatusNotFound), }, { desc: "update channel description with invalid channel id", channel: sdk.Channel{ - ID: gmocks.WrongID, + ID: mocks.WrongID, Description: "NewDescription", }, response: sdk.Channel{}, - token: adminToken, + token: validToken, err: errors.NewSDKErrorWithStatus(errors.ErrNotFound, http.StatusNotFound), }, { desc: "update channel metadata with invalid channel id", channel: sdk.Channel{ - ID: gmocks.WrongID, + ID: mocks.WrongID, Metadata: sdk.Metadata{ "field": "value2", }, }, response: sdk.Channel{}, - token: adminToken, + token: validToken, err: errors.NewSDKErrorWithStatus(errors.ErrNotFound, http.StatusNotFound), }, { @@ -513,7 +466,7 @@ func TestUpdateChannel(t *testing.T) { }, response: sdk.Channel{}, token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthorization, errors.ErrAuthentication), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "update channel description with invalid token", @@ -523,7 +476,7 @@ func TestUpdateChannel(t *testing.T) { }, response: sdk.Channel{}, token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthorization, errors.ErrAuthentication), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "update channel metadata with invalid token", @@ -535,7 +488,7 @@ func TestUpdateChannel(t *testing.T) { }, response: sdk.Channel{}, token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthorization, errors.ErrAuthentication), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "update channel that can't be marshalled", @@ -552,19 +505,21 @@ func TestUpdateChannel(t *testing.T) { } for _, tc := range cases { - repoCall1 := gRepo.On("Update", mock.Anything, mock.Anything).Return(convertChannel(tc.response), tc.err) + repoCall := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall1 := grepo.On("Update", mock.Anything, mock.Anything).Return(convertChannel(tc.response), tc.err) _, err := mgsdk.UpdateChannel(tc.channel, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) if tc.err == nil { ok := repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) } + repoCall.Unset() repoCall1.Unset() } } func TestListChannelsByThing(t *testing.T) { - ts, _, _, auth := newClientServer() + ts, grepo, auth := newChannelsServer() auth.Test(t) defer ts.Close() @@ -573,11 +528,12 @@ func TestListChannelsByThing(t *testing.T) { } mgsdk := sdk.NewSDK(conf) - nChannels := uint64(100) + nChannels := uint64(10) aChannels := []sdk.Channel{} for i := uint64(1); i < nChannels; i++ { channel := sdk.Channel{ + ID: generateUUID(t), Name: fmt.Sprintf("membership_%d@example.com", i), Metadata: sdk.Metadata{"role": "channel"}, Status: mgclients.EnabledStatus.String(), @@ -595,7 +551,7 @@ func TestListChannelsByThing(t *testing.T) { }{ { desc: "list channel with authorized token", - token: adminToken, + token: validToken, clientID: testsutil.GenerateUUID(t), page: sdk.PageMetadata{}, response: aChannels, @@ -603,7 +559,7 @@ func TestListChannelsByThing(t *testing.T) { }, { desc: "list channel with offset and limit", - token: adminToken, + token: validToken, clientID: testsutil.GenerateUUID(t), page: sdk.PageMetadata{ Offset: 6, @@ -616,7 +572,7 @@ func TestListChannelsByThing(t *testing.T) { }, { desc: "list channel with given name", - token: adminToken, + token: validToken, clientID: testsutil.GenerateUUID(t), page: sdk.PageMetadata{ Name: gName, @@ -630,7 +586,7 @@ func TestListChannelsByThing(t *testing.T) { }, { desc: "list channel with given level", - token: adminToken, + token: validToken, clientID: testsutil.GenerateUUID(t), page: sdk.PageMetadata{ Level: 1, @@ -644,7 +600,7 @@ func TestListChannelsByThing(t *testing.T) { }, { desc: "list channel with metadata", - token: adminToken, + token: validToken, clientID: testsutil.GenerateUUID(t), page: sdk.PageMetadata{ Metadata: validMetadata, @@ -662,27 +618,29 @@ func TestListChannelsByThing(t *testing.T) { clientID: testsutil.GenerateUUID(t), page: sdk.PageMetadata{}, response: []sdk.Channel(nil), - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthorization, errors.ErrAuthentication), http.StatusUnauthorized), - }, - { - desc: "list channel with an invalid id", - token: adminToken, - clientID: gmocks.WrongID, - page: sdk.PageMetadata{}, - response: []sdk.Channel(nil), - err: errors.NewSDKErrorWithStatus(errors.ErrNotFound, http.StatusNotFound), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), }, } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := auth.On("ListAllSubjects", mock.Anything, mock.Anything).Return(&magistrala.ListSubjectsRes{Policies: toIDs(tc.response)}, nil) + repoCall3 := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{Policies: toIDs(tc.response)}, nil) + repoCall4 := grepo.On("RetrieveByIDs", mock.Anything, mock.Anything).Return(mggroups.Page{Groups: convertChannels(tc.response)}, tc.err) page, err := mgsdk.ChannelsByThing(tc.clientID, tc.page, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, page.Channels, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page.Channels)) + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() + repoCall3.Unset() + repoCall4.Unset() } } func TestEnableChannel(t *testing.T) { - ts, _, gRepo, _ := newClientServer() + ts, grepo, auth := newChannelsServer() defer ts.Close() conf := sdk.Config{ @@ -700,12 +658,14 @@ func TestEnableChannel(t *testing.T) { Status: mgclients.Disabled, } - repoCall1 := gRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(nil) - repoCall2 := gRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(nil) - _, err := mgsdk.EnableChannel("wrongID", adminToken) - assert.Equal(t, err, errors.NewSDKErrorWithStatus(errors.Wrap(mggroups.ErrEnableGroup, errors.ErrNotFound), http.StatusNotFound), fmt.Sprintf("Enable channel with wrong id: expected %v got %v", errors.ErrNotFound, err)) + repoCall := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall1 := grepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(nil) + repoCall2 := grepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(nil) + _, err := mgsdk.EnableChannel("wrongID", validToken) + assert.Equal(t, err, errors.NewSDKErrorWithStatus(errors.ErrNotFound, http.StatusNotFound), fmt.Sprintf("Enable channel with wrong id: expected %v got %v", errors.ErrNotFound, err)) ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, "wrongID") assert.True(t, ok, "RetrieveByID was not called on enabling channel") + repoCall.Unset() repoCall1.Unset() repoCall2.Unset() @@ -717,22 +677,23 @@ func TestEnableChannel(t *testing.T) { UpdatedAt: creationTime, Status: mgclients.DisabledStatus, } - - repoCall1 = gRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(ch, nil) - repoCall2 = gRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(ch, nil) - res, err := mgsdk.EnableChannel(channel.ID, adminToken) + repoCall = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall1 = grepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(ch, nil) + repoCall2 = grepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(ch, nil) + res, err := mgsdk.EnableChannel(channel.ID, validToken) assert.Nil(t, err, fmt.Sprintf("Enable channel with correct id: expected %v got %v", nil, err)) assert.Equal(t, channel, res, fmt.Sprintf("Enable channel with correct id: expected %v got %v", channel, res)) ok = repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, channel.ID) assert.True(t, ok, "RetrieveByID was not called on enabling channel") ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", mock.Anything, mock.Anything) assert.True(t, ok, "ChangeStatus was not called on enabling channel") + repoCall.Unset() repoCall1.Unset() repoCall2.Unset() } func TestDisableChannel(t *testing.T) { - ts, _, gRepo, _ := newClientServer() + ts, grepo, auth := newChannelsServer() defer ts.Close() conf := sdk.Config{ @@ -750,12 +711,14 @@ func TestDisableChannel(t *testing.T) { Status: mgclients.Enabled, } - repoCall1 := gRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(sdk.ErrFailedRemoval) - repoCall2 := gRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(nil) - _, err := mgsdk.DisableChannel("wrongID", adminToken) - assert.Equal(t, err, errors.NewSDKErrorWithStatus(errors.Wrap(mggroups.ErrDisableGroup, errors.ErrNotFound), http.StatusNotFound), fmt.Sprintf("Disable channel with wrong id: expected %v got %v", errors.ErrNotFound, err)) + repoCall := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall1 := grepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(sdk.ErrFailedRemoval) + repoCall2 := grepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(nil) + _, err := mgsdk.DisableChannel("wrongID", validToken) + assert.Equal(t, err, errors.NewSDKErrorWithStatus(errors.ErrNotFound, http.StatusNotFound), fmt.Sprintf("Disable channel with wrong id: expected %v got %v", errors.ErrNotFound, err)) ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, "wrongID") assert.True(t, ok, "Memberships was not called on disabling channel with wrong id") + repoCall.Unset() repoCall1.Unset() repoCall2.Unset() @@ -768,15 +731,31 @@ func TestDisableChannel(t *testing.T) { Status: mgclients.EnabledStatus, } - repoCall1 = gRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(ch, nil) - repoCall2 = gRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(ch, nil) - res, err := mgsdk.DisableChannel(channel.ID, adminToken) + repoCall = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall1 = grepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(ch, nil) + repoCall2 = grepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(ch, nil) + res, err := mgsdk.DisableChannel(channel.ID, validToken) assert.Nil(t, err, fmt.Sprintf("Disable channel with correct id: expected %v got %v", nil, err)) assert.Equal(t, channel, res, fmt.Sprintf("Disable channel with correct id: expected %v got %v", channel, res)) ok = repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, channel.ID) assert.True(t, ok, "RetrieveByID was not called on disabling channel with correct id") ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", mock.Anything, mock.Anything) assert.True(t, ok, "ChangeStatus was not called on disabling channel with correct id") + repoCall.Unset() repoCall1.Unset() repoCall2.Unset() } + +func toIDs(objects interface{}) []string { + v := reflect.ValueOf(objects) + if v.Kind() != reflect.Slice { + panic("objects argument must be a slice") + } + ids := make([]string, v.Len()) + for i := 0; i < v.Len(); i++ { + id := v.Index(i).FieldByName("ID").String() + ids[i] = id + } + + return ids +} diff --git a/pkg/sdk/go/consumers_test.go b/pkg/sdk/go/consumers_test.go index 6a8f09f6b..b6d2b8881 100644 --- a/pkg/sdk/go/consumers_test.go +++ b/pkg/sdk/go/consumers_test.go @@ -9,6 +9,7 @@ import ( "net/http/httptest" "testing" + "github.com/absmach/magistrala" authmocks "github.com/absmach/magistrala/auth/mocks" "github.com/absmach/magistrala/consumers/notifiers" httpapi "github.com/absmach/magistrala/consumers/notifiers/api" @@ -19,11 +20,10 @@ import ( sdk "github.com/absmach/magistrala/pkg/sdk/go" "github.com/absmach/magistrala/pkg/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) -const wrongValue = "wrong_value" - var ( sub1 = sdk.Subscription{ Topic: "topic", @@ -33,14 +33,14 @@ var ( exampleUser1 = "email1@example.com" ) -func newSubscriptionService() notifiers.Service { +func newSubscriptionService() (notifiers.Service, *authmocks.Service) { repo := mocks.NewRepo(make(map[string]notifiers.Subscription)) auth := new(authmocks.Service) notifier := mocks.NewNotifier() idp := uuid.NewMock() from := "exampleFrom" - return notifiers.New(auth, repo, idp, notifier, from) + return notifiers.New(auth, repo, idp, notifier, from), auth } func newSubscriptionServer(svc notifiers.Service) *httptest.Server { @@ -51,7 +51,7 @@ func newSubscriptionServer(svc notifiers.Service) *httptest.Server { } func TestCreateSubscription(t *testing.T) { - svc := newSubscriptionService() + svc, auth := newSubscriptionService() ts := newSubscriptionServer(svc) defer ts.Close() @@ -87,7 +87,7 @@ func TestCreateSubscription(t *testing.T) { { desc: "create new subscription with invalid token", subscription: sub1, - token: wrongValue, + token: authmocks.InvalidValue, err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), empty: true, }, @@ -101,14 +101,16 @@ func TestCreateSubscription(t *testing.T) { } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) loc, err := mgsdk.CreateSubscription(tc.subscription.Topic, tc.subscription.Contact, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.empty, loc == "", fmt.Sprintf("%s: expected empty result location, got: %s", tc.desc, loc)) + repoCall.Unset() } } func TestViewSubscription(t *testing.T) { - svc := newSubscriptionService() + svc, auth := newSubscriptionService() ts := newSubscriptionServer(svc) defer ts.Close() sdkConf := sdk.Config{ @@ -118,8 +120,10 @@ func TestViewSubscription(t *testing.T) { } mgsdk := sdk.NewSDK(sdkConf) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: exampleUser1}).Return(&magistrala.IdentityRes{Id: validID}, nil) id, err := mgsdk.CreateSubscription("topic", "contact", exampleUser1) require.Nil(t, err, fmt.Sprintf("unexpected error during creating subscription: %s", err)) + repoCall.Unset() cases := []struct { desc string @@ -152,17 +156,18 @@ func TestViewSubscription(t *testing.T) { } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) respSub, err := mgsdk.ViewSubscription(tc.subID, tc.token) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) tc.response.ID = respSub.ID tc.response.OwnerID = respSub.OwnerID assert.Equal(t, tc.response, respSub, fmt.Sprintf("%s: expected response %s, got %s", tc.desc, tc.response, respSub)) + repoCall.Unset() } } func TestListSubscription(t *testing.T) { - svc := newSubscriptionService() + svc, auth := newSubscriptionService() ts := newSubscriptionServer(svc) defer ts.Close() sdkConf := sdk.Config{ @@ -175,10 +180,14 @@ func TestListSubscription(t *testing.T) { nSubs := 10 subs := make([]sdk.Subscription, nSubs) for i := 0; i < nSubs; i++ { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: exampleUser1}).Return(&magistrala.IdentityRes{Id: validID}, nil) id, err := mgsdk.CreateSubscription(fmt.Sprintf("topic_%d", i), fmt.Sprintf("contact_%d", i), exampleUser1) require.Nil(t, err, fmt.Sprintf("unexpected error during creating subscription: %s", err)) + repoCall.Unset() + repoCall = auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: exampleUser1}).Return(&magistrala.IdentityRes{Id: validID}, nil) sub, err := mgsdk.ViewSubscription(id, exampleUser1) require.Nil(t, err, fmt.Sprintf("unexpected error during getting subscription: %s", err)) + repoCall.Unset() subs[i] = sub } @@ -213,14 +222,16 @@ func TestListSubscription(t *testing.T) { } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) subs, err := mgsdk.ListSubscriptions(tc.page, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, subs.Subscriptions, fmt.Sprintf("%s: expected response %v, got %v", tc.desc, tc.response, subs.Subscriptions)) + repoCall.Unset() } } func TestDeleteSubscription(t *testing.T) { - svc := newSubscriptionService() + svc, auth := newSubscriptionService() ts := newSubscriptionServer(svc) defer ts.Close() sdkConf := sdk.Config{ @@ -230,8 +241,10 @@ func TestDeleteSubscription(t *testing.T) { } mgsdk := sdk.NewSDK(sdkConf) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: exampleUser1}).Return(&magistrala.IdentityRes{Id: validID}, nil) id, err := mgsdk.CreateSubscription("topic", "contact", exampleUser1) require.Nil(t, err, fmt.Sprintf("unexpected error during creating subscription: %s", err)) + repoCall.Unset() cases := []struct { desc string @@ -264,7 +277,9 @@ func TestDeleteSubscription(t *testing.T) { } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) err := mgsdk.DeleteSubscription(tc.subID, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) + repoCall.Unset() } } diff --git a/pkg/sdk/go/groups_test.go b/pkg/sdk/go/groups_test.go index fe6f2e912..23734c93e 100644 --- a/pkg/sdk/go/groups_test.go +++ b/pkg/sdk/go/groups_test.go @@ -10,10 +10,11 @@ import ( "testing" "time" + "github.com/absmach/magistrala" authmocks "github.com/absmach/magistrala/auth/mocks" "github.com/absmach/magistrala/internal/apiutil" "github.com/absmach/magistrala/internal/groups" - gmocks "github.com/absmach/magistrala/internal/groups/mocks" + "github.com/absmach/magistrala/internal/groups/mocks" "github.com/absmach/magistrala/internal/testsutil" mglog "github.com/absmach/magistrala/logger" "github.com/absmach/magistrala/pkg/clients" @@ -22,29 +23,29 @@ import ( sdk "github.com/absmach/magistrala/pkg/sdk/go" "github.com/absmach/magistrala/users" "github.com/absmach/magistrala/users/api" - "github.com/absmach/magistrala/users/mocks" + umocks "github.com/absmach/magistrala/users/mocks" "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" ) -func newGroupsServer() (*httptest.Server, *mocks.Repository, *gmocks.Repository, *authmocks.Service) { - cRepo := new(mocks.Repository) - gRepo := new(gmocks.Repository) +func newGroupsServer() (*httptest.Server, *mocks.Repository, *authmocks.Service) { + crepo := new(umocks.Repository) + grepo := new(mocks.Repository) auth := new(authmocks.Service) - csvc := users.NewService(cRepo, auth, emailer, phasher, idProvider, passRegex, true) - gsvc := groups.NewService(gRepo, idProvider, auth) + csvc := users.NewService(crepo, auth, emailer, phasher, idProvider, passRegex, true) + gsvc := groups.NewService(grepo, idProvider, auth) logger := mglog.NewMock() mux := chi.NewRouter() api.MakeHandler(csvc, gsvc, mux, logger, "") - return httptest.NewServer(mux), cRepo, gRepo, auth + return httptest.NewServer(mux), grepo, auth } func TestCreateGroup(t *testing.T) { - ts, _, gRepo, _ := newGroupsServer() + ts, grepo, auth := newGroupsServer() defer ts.Close() group := sdk.Group{ Name: "groupName", @@ -71,10 +72,12 @@ func TestCreateGroup(t *testing.T) { { desc: "create group with existing name", group: group, + token: token, err: nil, }, { - desc: "create group with parent", + desc: "create group with parent", + token: token, group: sdk.Group{ Name: gName, ParentID: testsutil.GenerateUUID(t), @@ -83,32 +86,36 @@ func TestCreateGroup(t *testing.T) { err: nil, }, { - desc: "create group with invalid parent", + desc: "create group with invalid parent", + token: token, group: sdk.Group{ Name: gName, - ParentID: gmocks.WrongID, + ParentID: mocks.WrongID, Status: clients.EnabledStatus.String(), }, err: errors.NewSDKErrorWithStatus(errors.ErrCreateEntity, http.StatusInternalServerError), }, { - desc: "create group with invalid owner", + desc: "create group with invalid owner", + token: token, group: sdk.Group{ Name: gName, - OwnerID: gmocks.WrongID, + OwnerID: mocks.WrongID, Status: clients.EnabledStatus.String(), }, err: errors.NewSDKErrorWithStatus(sdk.ErrFailedCreation, http.StatusInternalServerError), }, { - desc: "create group with missing name", + desc: "create group with missing name", + token: token, group: sdk.Group{ Status: clients.EnabledStatus.String(), }, err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrNameSize), http.StatusBadRequest), }, { - desc: "create a group with every field defined", + desc: "create a group with every field defined", + token: token, group: sdk.Group{ ID: generateUUID(t), OwnerID: "owner", @@ -122,36 +129,41 @@ func TestCreateGroup(t *testing.T) { UpdatedAt: time.Now(), Status: clients.EnabledStatus.String(), }, - token: token, - err: nil, + err: nil, }, { - desc: "create a group that can't be marshalled", + desc: "create a group that can't be marshalled", + token: token, group: sdk.Group{ Name: "test", Metadata: map[string]interface{}{ "test": make(chan int), }, }, - token: token, - err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")), + err: errors.NewSDKError(fmt.Errorf("json: unsupported type: chan int")), }, } for _, tc := range cases { - repoCall := gRepo.On("Save", mock.Anything, mock.Anything).Return(convertGroup(sdk.Group{}), tc.err) - rGroup, err := mgsdk.CreateGroup(tc.group, validToken) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("AddPolicy", mock.Anything, mock.Anything).Return(&magistrala.AddPolicyRes{Authorized: true}, nil) + repoCall2 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall3 := grepo.On("Save", mock.Anything, mock.Anything).Return(convertGroup(sdk.Group{}), tc.err) + rGroup, err := mgsdk.CreateGroup(tc.group, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: unexpected error %s", tc.desc, err)) if err == nil { assert.NotEmpty(t, rGroup, fmt.Sprintf("%s: expected not nil on client ID", tc.desc)) - ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) + ok := repoCall3.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) } repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() + repoCall3.Unset() } } func TestListGroups(t *testing.T) { - ts, _, gRepo, _ := newGroupsServer() + ts, grepo, auth := newGroupsServer() defer ts.Close() var grps []sdk.Group @@ -198,7 +210,7 @@ func TestListGroups(t *testing.T) { token: invalidToken, offset: offset, limit: limit, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), response: nil, }, { @@ -206,7 +218,7 @@ func TestListGroups(t *testing.T) { token: "", offset: offset, limit: limit, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized), response: nil, }, { @@ -214,7 +226,7 @@ func TestListGroups(t *testing.T) { token: token, offset: offset, limit: 0, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: nil, response: nil, }, { @@ -222,7 +234,7 @@ func TestListGroups(t *testing.T) { token: token, offset: offset, limit: 110, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest), response: []sdk.Group(nil), }, { @@ -255,21 +267,29 @@ func TestListGroups(t *testing.T) { } for _, tc := range cases { - repoCall1 := gRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(mggroups.Page{Groups: convertGroups(tc.response)}, tc.err) - pm := sdk.PageMetadata{} - page, err := mgsdk.Groups(pm, validToken) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{Policies: toIDs(tc.response)}, nil) + repoCall2 := grepo.On("RetrieveByIDs", mock.Anything, mock.Anything).Return(mggroups.Page{Groups: convertGroups(tc.response)}, tc.err) + pm := sdk.PageMetadata{ + Offset: tc.offset, + Limit: tc.limit, + Level: uint64(tc.level), + } + page, err := mgsdk.Groups(pm, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, len(tc.response), len(page.Groups), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page)) if tc.err == nil { - ok := repoCall1.Parent.AssertCalled(t, "RetrieveAll", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc)) + ok := repoCall2.Parent.AssertCalled(t, "RetrieveByIDs", mock.Anything, mock.Anything) + assert.True(t, ok, fmt.Sprintf("RetrieveByIDs was not called on %s", tc.desc)) } + repoCall.Unset() repoCall1.Unset() + repoCall2.Unset() } } func TestListParentGroups(t *testing.T) { - ts, _, gRepo, _ := newGroupsServer() + ts, grepo, auth := newGroupsServer() defer ts.Close() var grps []sdk.Group @@ -319,7 +339,7 @@ func TestListParentGroups(t *testing.T) { token: invalidToken, offset: offset, limit: limit, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), response: nil, }, { @@ -327,7 +347,7 @@ func TestListParentGroups(t *testing.T) { token: "", offset: offset, limit: limit, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized), response: nil, }, { @@ -335,7 +355,7 @@ func TestListParentGroups(t *testing.T) { token: token, offset: offset, limit: 0, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: nil, response: nil, }, { @@ -343,7 +363,7 @@ func TestListParentGroups(t *testing.T) { token: token, offset: offset, limit: 110, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest), response: []sdk.Group(nil), }, { @@ -376,21 +396,29 @@ func TestListParentGroups(t *testing.T) { } for _, tc := range cases { - repoCall1 := gRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(mggroups.Page{Groups: convertGroups(tc.response)}, tc.err) - pm := sdk.PageMetadata{} - page, err := mgsdk.Parents(parentID, pm, validToken) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{Policies: toIDs(tc.response)}, nil) + repoCall2 := grepo.On("RetrieveByIDs", mock.Anything, mock.Anything).Return(mggroups.Page{Groups: convertGroups(tc.response)}, tc.err) + pm := sdk.PageMetadata{ + Offset: tc.offset, + Limit: tc.limit, + Level: uint64(tc.level), + } + page, err := mgsdk.Parents(parentID, pm, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, len(tc.response), len(page.Groups), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page)) if tc.err == nil { - ok := repoCall1.Parent.AssertCalled(t, "RetrieveAll", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc)) + ok := repoCall2.Parent.AssertCalled(t, "RetrieveByIDs", mock.Anything, mock.Anything) + assert.True(t, ok, fmt.Sprintf("RetrieveByIDs was not called on %s", tc.desc)) } + repoCall.Unset() repoCall1.Unset() + repoCall2.Unset() } } func TestListChildrenGroups(t *testing.T) { - ts, _, gRepo, _ := newGroupsServer() + ts, grepo, auth := newGroupsServer() defer ts.Close() var grps []sdk.Group @@ -441,7 +469,7 @@ func TestListChildrenGroups(t *testing.T) { token: invalidToken, offset: offset, limit: limit, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), response: nil, }, { @@ -449,7 +477,7 @@ func TestListChildrenGroups(t *testing.T) { token: "", offset: offset, limit: limit, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized), response: nil, }, { @@ -457,7 +485,7 @@ func TestListChildrenGroups(t *testing.T) { token: token, offset: offset, limit: 0, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: nil, response: nil, }, { @@ -465,7 +493,7 @@ func TestListChildrenGroups(t *testing.T) { token: token, offset: offset, limit: 110, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest), response: []sdk.Group(nil), }, { @@ -498,21 +526,29 @@ func TestListChildrenGroups(t *testing.T) { } for _, tc := range cases { - repoCall1 := gRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(mggroups.Page{Groups: convertGroups(tc.response)}, tc.err) - pm := sdk.PageMetadata{} - page, err := mgsdk.Children(childID, pm, validToken) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{Policies: toIDs(tc.response)}, nil) + repoCall2 := grepo.On("RetrieveByIDs", mock.Anything, mock.Anything).Return(mggroups.Page{Groups: convertGroups(tc.response)}, tc.err) + pm := sdk.PageMetadata{ + Offset: tc.offset, + Limit: tc.limit, + Level: uint64(tc.level), + } + page, err := mgsdk.Children(childID, pm, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, len(tc.response), len(page.Groups), fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page)) if tc.err == nil { - ok := repoCall1.Parent.AssertCalled(t, "RetrieveAll", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("RetrieveAll was not called on %s", tc.desc)) + ok := repoCall2.Parent.AssertCalled(t, "RetrieveByIDs", mock.Anything, mock.Anything) + assert.True(t, ok, fmt.Sprintf("RetrieveByIDs was not called on %s", tc.desc)) } + repoCall.Unset() repoCall1.Unset() + repoCall2.Unset() } } func TestViewGroup(t *testing.T) { - ts, _, gRepo, _ := newGroupsServer() + ts, grepo, auth := newGroupsServer() defer ts.Close() group := sdk.Group{ @@ -548,19 +584,20 @@ func TestViewGroup(t *testing.T) { token: "wrongtoken", groupID: group.ID, response: sdk.Group{Children: []*sdk.Group{}}, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthentication, sdk.ErrInvalidJWT), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "view group for wrong id", token: validToken, - groupID: gmocks.WrongID, + groupID: mocks.WrongID, response: sdk.Group{Children: []*sdk.Group{}}, err: errors.NewSDKErrorWithStatus(errors.ErrNotFound, http.StatusNotFound), }, } for _, tc := range cases { - repoCall1 := gRepo.On("RetrieveByID", mock.Anything, tc.groupID).Return(convertGroup(tc.response), tc.err) + repoCall := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall1 := grepo.On("RetrieveByID", mock.Anything, tc.groupID).Return(convertGroup(tc.response), tc.err) grp, err := mgsdk.Group(tc.groupID, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) if len(tc.response.Children) == 0 { @@ -574,12 +611,13 @@ func TestViewGroup(t *testing.T) { ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, tc.groupID) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) } + repoCall.Unset() repoCall1.Unset() } } func TestUpdateGroup(t *testing.T) { - ts, _, gRepo, _ := newGroupsServer() + ts, grepo, auth := newGroupsServer() defer ts.Close() group := sdk.Group{ @@ -649,7 +687,7 @@ func TestUpdateGroup(t *testing.T) { { desc: "update group name with invalid group id", group: sdk.Group{ - ID: gmocks.WrongID, + ID: mocks.WrongID, Name: "NewName", }, response: sdk.Group{}, @@ -659,7 +697,7 @@ func TestUpdateGroup(t *testing.T) { { desc: "update group description with invalid group id", group: sdk.Group{ - ID: gmocks.WrongID, + ID: mocks.WrongID, Description: "NewDescription", }, response: sdk.Group{}, @@ -669,7 +707,7 @@ func TestUpdateGroup(t *testing.T) { { desc: "update group metadata with invalid group id", group: sdk.Group{ - ID: gmocks.WrongID, + ID: mocks.WrongID, Metadata: sdk.Metadata{ "field": "value2", }, @@ -686,7 +724,7 @@ func TestUpdateGroup(t *testing.T) { }, response: sdk.Group{}, token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthentication, sdk.ErrInvalidJWT), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "update group description with invalid token", @@ -696,7 +734,7 @@ func TestUpdateGroup(t *testing.T) { }, response: sdk.Group{}, token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthentication, sdk.ErrInvalidJWT), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "update group metadata with invalid token", @@ -708,7 +746,7 @@ func TestUpdateGroup(t *testing.T) { }, response: sdk.Group{}, token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthentication, sdk.ErrInvalidJWT), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "update a group that can't be marshalled", @@ -725,19 +763,21 @@ func TestUpdateGroup(t *testing.T) { } for _, tc := range cases { - repoCall1 := gRepo.On("Update", mock.Anything, mock.Anything).Return(convertGroup(tc.response), tc.err) + repoCall := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall1 := grepo.On("Update", mock.Anything, mock.Anything).Return(convertGroup(tc.response), tc.err) _, err := mgsdk.UpdateGroup(tc.group, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) if tc.err == nil { ok := repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) } + repoCall.Unset() repoCall1.Unset() } } func TestEnableGroup(t *testing.T) { - ts, _, gRepo, _ := newGroupsServer() + ts, grepo, auth := newGroupsServer() defer ts.Close() conf := sdk.Config{ @@ -755,12 +795,14 @@ func TestEnableGroup(t *testing.T) { Status: clients.Disabled, } - repoCall1 := gRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(nil) - repoCall2 := gRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(sdk.ErrFailedRemoval) + repoCall := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall1 := grepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(nil) + repoCall2 := grepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(nil) _, err := mgsdk.EnableGroup("wrongID", validToken) assert.Equal(t, err, errors.NewSDKErrorWithStatus(errors.ErrNotFound, http.StatusNotFound), fmt.Sprintf("Enable group with wrong id: expected %v got %v", errors.ErrNotFound, err)) ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, "wrongID") assert.True(t, ok, "RetrieveByID was not called on enabling group") + repoCall.Unset() repoCall1.Unset() repoCall2.Unset() @@ -772,9 +814,9 @@ func TestEnableGroup(t *testing.T) { UpdatedAt: creationTime, Status: clients.DisabledStatus, } - - repoCall1 = gRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(g, nil) - repoCall2 = gRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(g, nil) + repoCall = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall1 = grepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(g, nil) + repoCall2 = grepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(g, nil) res, err := mgsdk.EnableGroup(group.ID, validToken) assert.Nil(t, err, fmt.Sprintf("Enable group with correct id: expected %v got %v", nil, err)) assert.Equal(t, group, res, fmt.Sprintf("Enable group with correct id: expected %v got %v", group, res)) @@ -782,12 +824,13 @@ func TestEnableGroup(t *testing.T) { assert.True(t, ok, "RetrieveByID was not called on enabling group") ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", mock.Anything, mock.Anything) assert.True(t, ok, "ChangeStatus was not called on enabling group") + repoCall.Unset() repoCall1.Unset() repoCall2.Unset() } func TestDisableGroup(t *testing.T) { - ts, _, gRepo, _ := newGroupsServer() + ts, grepo, auth := newGroupsServer() defer ts.Close() conf := sdk.Config{ @@ -805,12 +848,14 @@ func TestDisableGroup(t *testing.T) { Status: clients.Enabled, } - repoCall1 := gRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(sdk.ErrFailedRemoval) - repoCall2 := gRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(nil) + repoCall := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall1 := grepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(sdk.ErrFailedRemoval) + repoCall2 := grepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(nil) _, err := mgsdk.DisableGroup("wrongID", validToken) assert.Equal(t, err, errors.NewSDKErrorWithStatus(errors.ErrNotFound, http.StatusNotFound), fmt.Sprintf("Disable group with wrong id: expected %v got %v", errors.ErrNotFound, err)) ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, "wrongID") assert.True(t, ok, "Memberships was not called on disabling group with wrong id") + repoCall.Unset() repoCall1.Unset() repoCall2.Unset() @@ -823,8 +868,9 @@ func TestDisableGroup(t *testing.T) { Status: clients.EnabledStatus, } - repoCall1 = gRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(g, nil) - repoCall2 = gRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(g, nil) + repoCall = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall1 = grepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(g, nil) + repoCall2 = grepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(g, nil) res, err := mgsdk.DisableGroup(group.ID, validToken) assert.Nil(t, err, fmt.Sprintf("Disable group with correct id: expected %v got %v", nil, err)) assert.Equal(t, group, res, fmt.Sprintf("Disable group with correct id: expected %v got %v", group, res)) @@ -832,6 +878,7 @@ func TestDisableGroup(t *testing.T) { assert.True(t, ok, "RetrieveByID was not called on disabling group with correct id") ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", mock.Anything, mock.Anything) assert.True(t, ok, "ChangeStatus was not called on disabling group with correct id") + repoCall.Unset() repoCall1.Unset() repoCall2.Unset() } diff --git a/pkg/sdk/go/health_test.go b/pkg/sdk/go/health_test.go index 2d5971112..67a2b8a2a 100644 --- a/pkg/sdk/go/health_test.go +++ b/pkg/sdk/go/health_test.go @@ -23,7 +23,7 @@ func TestHealth(t *testing.T) { auth.Test(t) defer usclsv.Close() - certSvc, err := newCertService() + certSvc, _, _, err := newCertService() require.Nil(t, err, fmt.Sprintf("unexpected error during creating service: %s", err)) CertTs := newCertServer(certSvc) defer CertTs.Close() diff --git a/pkg/sdk/go/setup_test.go b/pkg/sdk/go/setup_test.go index 5f3e22f3a..535436477 100644 --- a/pkg/sdk/go/setup_test.go +++ b/pkg/sdk/go/setup_test.go @@ -23,7 +23,7 @@ const ( Identity = "identity" secret = "strongsecret" token = "token" - invalidToken = "invalidtoken" + invalidToken = "invalid" contentType = "application/senml+json" ) diff --git a/pkg/sdk/go/things_test.go b/pkg/sdk/go/things_test.go index 13fb4ea1a..5e9dd1ef3 100644 --- a/pkg/sdk/go/things_test.go +++ b/pkg/sdk/go/things_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/absmach/magistrala" authmocks "github.com/absmach/magistrala/auth/mocks" "github.com/absmach/magistrala/internal/apiutil" "github.com/absmach/magistrala/internal/groups" @@ -27,11 +28,6 @@ import ( "github.com/stretchr/testify/mock" ) -var ( - adminToken = "token" - userToken = "userToken" -) - func newThingsServer() (*httptest.Server, *mocks.Repository, *gmocks.Repository, *authmocks.Service) { cRepo := new(mocks.Repository) gRepo := new(gmocks.Repository) @@ -49,7 +45,7 @@ func newThingsServer() (*httptest.Server, *mocks.Repository, *gmocks.Repository, } func TestCreateThing(t *testing.T) { - ts, cRepo, _, _ := newClientServer() + ts, cRepo, _, auth := newThingsServer() defer ts.Close() thing := sdk.Thing{ @@ -172,7 +168,9 @@ func TestCreateThing(t *testing.T) { }, } for _, tc := range cases { - repoCall := cRepo.On("Save", mock.Anything, mock.Anything).Return(tc.response, tc.repoErr) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("AddPolicy", mock.Anything, mock.Anything).Return(&magistrala.AddPolicyRes{Authorized: true}, nil) + repoCall2 := cRepo.On("Save", mock.Anything, mock.Anything).Return(tc.response, tc.repoErr) rThing, err := mgsdk.CreateThing(tc.client, tc.token) tc.response.ID = rThing.ID @@ -184,15 +182,17 @@ func TestCreateThing(t *testing.T) { assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, rThing, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, rThing)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) + ok := repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) } repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() } } func TestCreateThings(t *testing.T) { - ts, cRepo, _, _ := newClientServer() + ts, cRepo, _, auth := newThingsServer() defer ts.Close() things := []sdk.Thing{ @@ -254,7 +254,9 @@ func TestCreateThings(t *testing.T) { }, } for _, tc := range cases { - repoCall := cRepo.On("Save", mock.Anything, mock.Anything).Return(tc.response, tc.err) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("AddPolicy", mock.Anything, mock.Anything).Return(&magistrala.AddPolicyRes{Authorized: true}, nil) + repoCall2 := cRepo.On("Save", mock.Anything, mock.Anything).Return(tc.response, tc.err) rThing, err := mgsdk.CreateThings(tc.things, tc.token) for i, t := range rThing { tc.response[i].ID = t.ID @@ -267,15 +269,17 @@ func TestCreateThings(t *testing.T) { assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, rThing, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, rThing)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) + ok := repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) } repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() } } func TestListThings(t *testing.T) { - ts, cRepo, _, _ := newClientServer() + ts, cRepo, _, auth := newThingsServer() defer ts.Close() var ths []sdk.Thing @@ -330,10 +334,10 @@ func TestListThings(t *testing.T) { }, { desc: "get a list of things with invalid token", - token: invalidToken, + token: authmocks.InvalidValue, offset: offset, limit: limit, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), response: nil, }, { @@ -341,7 +345,7 @@ func TestListThings(t *testing.T) { token: "", offset: offset, limit: limit, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedList), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), response: nil, }, { @@ -349,15 +353,15 @@ func TestListThings(t *testing.T) { token: token, offset: offset, limit: 0, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusInternalServerError), - response: nil, + err: nil, + response: []sdk.Thing{}, }, { desc: "get a list of things with limit greater than max", token: token, offset: offset, limit: 110, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest), response: []sdk.Thing(nil), }, { @@ -384,7 +388,7 @@ func TestListThings(t *testing.T) { }, { desc: "list things with given metadata", - token: adminToken, + token: validToken, offset: 0, limit: 1, metadata: sdk.Metadata{ @@ -395,7 +399,7 @@ func TestListThings(t *testing.T) { }, { desc: "list things with given name", - token: adminToken, + token: validToken, offset: 0, limit: 1, name: "client10", @@ -404,7 +408,7 @@ func TestListThings(t *testing.T) { }, { desc: "list things with given owner", - token: adminToken, + token: validToken, offset: 0, limit: 1, ownerID: owner, @@ -413,7 +417,7 @@ func TestListThings(t *testing.T) { }, { desc: "list things with given status", - token: adminToken, + token: validToken, offset: 0, limit: 1, status: mgclients.DisabledStatus.String(), @@ -422,7 +426,7 @@ func TestListThings(t *testing.T) { }, { desc: "list things with given tag", - token: adminToken, + token: validToken, offset: 0, limit: 1, tag: "tag1", @@ -442,17 +446,24 @@ func TestListThings(t *testing.T) { Metadata: tc.metadata, Tag: tc.tag, } - - repoCall := cRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(mgclients.ClientsPage{Page: convertClientPage(pm), Clients: convertThings(tc.response)}, tc.err) - page, err := mgsdk.Things(pm, adminToken) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{Policies: toIDs(tc.response)}, nil) + if tc.token != validToken { + repoCall = auth.On("Identify", mock.Anything, mock.Anything).Return(&magistrala.IdentityRes{}, errors.ErrAuthentication) + repoCall1 = auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{}, errors.ErrAuthorization) + } + repoCall2 := cRepo.On("RetrieveAllByIDs", mock.Anything, mock.Anything).Return(mgclients.ClientsPage{Page: convertClientPage(pm), Clients: convertThings(tc.response)}, tc.err) + page, err := mgsdk.Things(pm, validToken) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, page.Things, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page)) + repoCall2.Unset() repoCall.Unset() + repoCall1.Unset() } } func TestListThingsByChannel(t *testing.T) { - ts, cRepo, _, _ := newClientServer() + ts, cRepo, _, auth := newThingsServer() defer ts.Close() conf := sdk.Config{ @@ -486,7 +497,7 @@ func TestListThingsByChannel(t *testing.T) { }{ { desc: "list things with authorized token", - token: adminToken, + token: validToken, channelID: testsutil.GenerateUUID(t), page: sdk.PageMetadata{}, response: aThings, @@ -494,7 +505,7 @@ func TestListThingsByChannel(t *testing.T) { }, { desc: "list things with offset and limit", - token: adminToken, + token: validToken, channelID: testsutil.GenerateUUID(t), page: sdk.PageMetadata{ Offset: 4, @@ -505,7 +516,7 @@ func TestListThingsByChannel(t *testing.T) { }, { desc: "list things with given name", - token: adminToken, + token: validToken, channelID: testsutil.GenerateUUID(t), page: sdk.PageMetadata{ Name: Identity, @@ -518,7 +529,7 @@ func TestListThingsByChannel(t *testing.T) { { desc: "list things with given ownerID", - token: adminToken, + token: validToken, channelID: testsutil.GenerateUUID(t), page: sdk.PageMetadata{ OwnerID: user.Owner, @@ -530,7 +541,7 @@ func TestListThingsByChannel(t *testing.T) { }, { desc: "list things with given subject", - token: adminToken, + token: validToken, channelID: testsutil.GenerateUUID(t), page: sdk.PageMetadata{ Subject: subject, @@ -542,7 +553,7 @@ func TestListThingsByChannel(t *testing.T) { }, { desc: "list things with given object", - token: adminToken, + token: validToken, channelID: testsutil.GenerateUUID(t), page: sdk.PageMetadata{ Object: object, @@ -558,11 +569,11 @@ func TestListThingsByChannel(t *testing.T) { channelID: testsutil.GenerateUUID(t), page: sdk.PageMetadata{}, response: []sdk.Thing(nil), - err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "list things with an invalid id", - token: adminToken, + token: validToken, channelID: mocks.WrongID, page: sdk.PageMetadata{}, response: []sdk.Thing(nil), @@ -571,20 +582,26 @@ func TestListThingsByChannel(t *testing.T) { } for _, tc := range cases { - repoCall := cRepo.On("Members", mock.Anything, tc.channelID, mock.Anything).Return(mgclients.MembersPage{Members: convertThings(tc.response)}, tc.err) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{}, nil) + repoCall3 := cRepo.On("RetrieveAllByIDs", mock.Anything, mock.Anything).Return(mgclients.ClientsPage{Page: convertClientPage(tc.page), Clients: convertThings(tc.response)}, tc.err) membersPage, err := mgsdk.ThingsByChannel(tc.channelID, tc.page, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, membersPage.Things, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, membersPage.Things)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "Members", mock.Anything, tc.channelID, mock.Anything) + ok := repoCall3.Parent.AssertCalled(t, "RetrieveAllByIDs", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Members was not called on %s", tc.desc)) } repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() + repoCall3.Unset() } } func TestThing(t *testing.T) { - ts, cRepo, _, _ := newClientServer() + ts, cRepo, _, auth := newThingsServer() defer ts.Close() thing := sdk.Thing{ @@ -609,7 +626,7 @@ func TestThing(t *testing.T) { { desc: "view thing successfully", response: thing, - token: adminToken, + token: validToken, thingID: generateUUID(t), err: nil, }, @@ -618,12 +635,12 @@ func TestThing(t *testing.T) { response: sdk.Thing{}, token: invalidToken, thingID: generateUUID(t), - err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "view thing with valid token and invalid thing id", response: sdk.Thing{}, - token: adminToken, + token: validToken, thingID: mocks.WrongID, err: errors.NewSDKErrorWithStatus(errors.ErrNotFound, http.StatusNotFound), }, @@ -632,11 +649,15 @@ func TestThing(t *testing.T) { response: sdk.Thing{}, token: invalidToken, thingID: mocks.WrongID, - err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, } for _, tc := range cases { + repoCall := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + if tc.token != validToken { + repoCall = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: false}, errors.ErrAuthorization) + } repoCall1 := cRepo.On("RetrieveByID", mock.Anything, tc.thingID).Return(convertThing(tc.response), tc.err) rClient, err := mgsdk.Thing(tc.thingID, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) @@ -646,11 +667,12 @@ func TestThing(t *testing.T) { assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) } repoCall1.Unset() + repoCall.Unset() } } func TestUpdateThing(t *testing.T) { - ts, cRepo, _, _ := newClientServer() + ts, cRepo, _, auth := newThingsServer() defer ts.Close() conf := sdk.Config{ @@ -684,7 +706,7 @@ func TestUpdateThing(t *testing.T) { desc: "update thing name with valid token", thing: thing1, response: thing1, - token: adminToken, + token: validToken, err: nil, }, { @@ -692,13 +714,13 @@ func TestUpdateThing(t *testing.T) { thing: thing1, response: sdk.Thing{}, token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "update thing name with invalid id", thing: thing2, response: sdk.Thing{}, - token: adminToken, + token: validToken, err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedUpdate), http.StatusInternalServerError), }, { @@ -716,20 +738,27 @@ func TestUpdateThing(t *testing.T) { } for _, tc := range cases { - repoCall1 := cRepo.On("Update", mock.Anything, mock.Anything).Return(convertThing(tc.response), tc.err) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + if tc.token != validToken { + repoCall1 = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: false}, errors.ErrAuthorization) + } + repoCall2 := cRepo.On("Update", mock.Anything, mock.Anything).Return(convertThing(tc.response), tc.err) uClient, err := mgsdk.UpdateThing(tc.thing, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, uClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, uClient)) if tc.err == nil { - ok := repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything) + ok := repoCall2.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) } + repoCall2.Unset() + repoCall.Unset() repoCall1.Unset() } } func TestUpdateThingTags(t *testing.T) { - ts, cRepo, _, _ := newClientServer() + ts, cRepo, _, auth := newThingsServer() defer ts.Close() conf := sdk.Config{ @@ -762,7 +791,7 @@ func TestUpdateThingTags(t *testing.T) { desc: "update thing name with valid token", thing: thing, response: thing1, - token: adminToken, + token: validToken, err: nil, }, { @@ -770,13 +799,13 @@ func TestUpdateThingTags(t *testing.T) { thing: thing1, response: sdk.Thing{}, token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "update thing name with invalid id", thing: thing2, response: sdk.Thing{}, - token: adminToken, + token: validToken, err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedUpdate), http.StatusInternalServerError), }, { @@ -794,20 +823,27 @@ func TestUpdateThingTags(t *testing.T) { } for _, tc := range cases { - repoCall1 := cRepo.On("UpdateTags", mock.Anything, mock.Anything).Return(convertThing(tc.response), tc.err) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + if tc.token != validToken { + repoCall1 = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: false}, errors.ErrAuthorization) + } + repoCall2 := cRepo.On("UpdateTags", mock.Anything, mock.Anything).Return(convertThing(tc.response), tc.err) uClient, err := mgsdk.UpdateThingTags(tc.thing, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, uClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, uClient)) if tc.err == nil { - ok := repoCall1.Parent.AssertCalled(t, "UpdateTags", mock.Anything, mock.Anything) + ok := repoCall2.Parent.AssertCalled(t, "UpdateTags", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("UpdateTags was not called on %s", tc.desc)) } + repoCall2.Unset() + repoCall.Unset() repoCall1.Unset() } } func TestUpdateThingSecret(t *testing.T) { - ts, cRepo, _, _ := newClientServer() + ts, cRepo, _, auth := newThingsServer() defer ts.Close() conf := sdk.Config{ @@ -832,7 +868,7 @@ func TestUpdateThingSecret(t *testing.T) { desc: "update thing secret with valid token", oldSecret: thing.Credentials.Secret, newSecret: "newSecret", - token: adminToken, + token: validToken, response: rthing, repoErr: nil, err: nil, @@ -844,33 +880,40 @@ func TestUpdateThingSecret(t *testing.T) { token: "non-existent", response: sdk.Thing{}, repoErr: errors.ErrAuthorization, - err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "update thing secret with wrong old secret", oldSecret: "oldSecret", newSecret: "newSecret", - token: adminToken, + token: validToken, response: sdk.Thing{}, repoErr: apiutil.ErrInvalidSecret, err: errors.NewSDKErrorWithStatus(apiutil.ErrInvalidSecret, http.StatusBadRequest), }, } for _, tc := range cases { - repoCall1 := cRepo.On("UpdateSecret", mock.Anything, mock.Anything).Return(convertThing(tc.response), tc.repoErr) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + if tc.token != validToken { + repoCall1 = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: false}, errors.ErrAuthorization) + } + repoCall2 := cRepo.On("UpdateSecret", mock.Anything, mock.Anything).Return(convertThing(tc.response), tc.repoErr) uClient, err := mgsdk.UpdateThingSecret(tc.oldSecret, tc.newSecret, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, uClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, uClient)) if tc.err == nil { - ok := repoCall1.Parent.AssertCalled(t, "UpdateSecret", mock.Anything, mock.Anything) + ok := repoCall2.Parent.AssertCalled(t, "UpdateSecret", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("UpdateSecret was not called on %s", tc.desc)) } + repoCall2.Unset() + repoCall.Unset() repoCall1.Unset() } } func TestUpdateThingOwner(t *testing.T) { - ts, cRepo, _, _ := newClientServer() + ts, cRepo, _, auth := newThingsServer() defer ts.Close() conf := sdk.Config{ @@ -902,7 +945,7 @@ func TestUpdateThingOwner(t *testing.T) { desc: "update thing name with valid token", thing: thing, response: thing, - token: adminToken, + token: validToken, err: nil, }, { @@ -910,13 +953,13 @@ func TestUpdateThingOwner(t *testing.T) { thing: thing2, response: sdk.Thing{}, token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "update thing name with invalid id", thing: thing2, response: sdk.Thing{}, - token: adminToken, + token: validToken, err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, sdk.ErrFailedUpdate), http.StatusInternalServerError), }, { @@ -934,20 +977,27 @@ func TestUpdateThingOwner(t *testing.T) { } for _, tc := range cases { - repoCall1 := cRepo.On("UpdateOwner", mock.Anything, mock.Anything).Return(convertThing(tc.response), tc.err) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + if tc.token != validToken { + repoCall1 = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: false}, errors.ErrAuthorization) + } + repoCall2 := cRepo.On("UpdateOwner", mock.Anything, mock.Anything).Return(convertThing(tc.response), tc.err) uClient, err := mgsdk.UpdateThingOwner(tc.thing, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, uClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, uClient)) if tc.err == nil { - ok := repoCall1.Parent.AssertCalled(t, "UpdateOwner", mock.Anything, mock.Anything) + ok := repoCall2.Parent.AssertCalled(t, "UpdateOwner", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("UpdateOwner was not called on %s", tc.desc)) } + repoCall2.Unset() + repoCall.Unset() repoCall1.Unset() } } func TestEnableThing(t *testing.T) { - ts, cRepo, _, auth := newClientServer() + ts, cRepo, _, auth := newThingsServer() defer ts.Close() conf := sdk.Config{ @@ -973,7 +1023,7 @@ func TestEnableThing(t *testing.T) { { desc: "enable disabled thing", id: disabledThing1.ID, - token: adminToken, + token: validToken, thing: disabledThing1, response: endisabledThing1, repoErr: nil, @@ -982,7 +1032,7 @@ func TestEnableThing(t *testing.T) { { desc: "enable enabled thing", id: enabledThing1.ID, - token: adminToken, + token: validToken, thing: enabledThing1, response: sdk.Thing{}, repoErr: sdk.ErrFailedEnable, @@ -991,7 +1041,7 @@ func TestEnableThing(t *testing.T) { { desc: "enable non-existing thing", id: mocks.WrongID, - token: adminToken, + token: validToken, thing: sdk.Thing{}, response: sdk.Thing{}, repoErr: sdk.ErrFailedEnable, @@ -1000,19 +1050,26 @@ func TestEnableThing(t *testing.T) { } for _, tc := range cases { - repoCall1 := cRepo.On("RetrieveByID", mock.Anything, tc.id).Return(convertThing(tc.thing), tc.repoErr) - repoCall2 := cRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(convertThing(tc.response), tc.repoErr) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + if tc.token != validToken { + repoCall1 = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: false}, errors.ErrAuthorization) + } + repoCall2 := cRepo.On("RetrieveByID", mock.Anything, tc.id).Return(convertThing(tc.thing), tc.repoErr) + repoCall3 := cRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(convertThing(tc.response), tc.repoErr) eClient, err := mgsdk.EnableThing(tc.id, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, eClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, eClient)) if tc.err == nil { - ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, tc.id) + ok := repoCall2.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, tc.id) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) - ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", mock.Anything, mock.Anything) + ok = repoCall3.Parent.AssertCalled(t, "ChangeStatus", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc)) } + repoCall.Unset() repoCall1.Unset() repoCall2.Unset() + repoCall3.Unset() } cases2 := []struct { @@ -1056,19 +1113,21 @@ func TestEnableThing(t *testing.T) { Limit: 100, Status: tc.status, } - repoCall := auth.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil) - repoCall1 := cRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(convertThingsPage(tc.response), nil) - clientsPage, err := mgsdk.Things(pm, adminToken) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{Policies: toIDs(tc.response.Things)}, nil) + repoCall2 := cRepo.On("RetrieveAllByIDs", mock.Anything, mock.Anything).Return(convertThingsPage(tc.response), nil) + clientsPage, err := mgsdk.Things(pm, validToken) assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) size := uint64(len(clientsPage.Things)) assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected size %d got %d\n", tc.desc, tc.size, size)) repoCall.Unset() repoCall1.Unset() + repoCall2.Unset() } } func TestDisableThing(t *testing.T) { - ts, cRepo, _, auth := newClientServer() + ts, cRepo, _, auth := newThingsServer() defer ts.Close() conf := sdk.Config{ @@ -1094,7 +1153,7 @@ func TestDisableThing(t *testing.T) { { desc: "disable enabled thing", id: enabledThing1.ID, - token: adminToken, + token: validToken, thing: enabledThing1, response: disenabledThing1, repoErr: nil, @@ -1103,7 +1162,7 @@ func TestDisableThing(t *testing.T) { { desc: "disable disabled thing", id: disabledThing1.ID, - token: adminToken, + token: validToken, thing: disabledThing1, response: sdk.Thing{}, repoErr: sdk.ErrFailedDisable, @@ -1113,7 +1172,7 @@ func TestDisableThing(t *testing.T) { desc: "disable non-existing thing", id: mocks.WrongID, thing: sdk.Thing{}, - token: adminToken, + token: validToken, response: sdk.Thing{}, repoErr: sdk.ErrFailedDisable, err: errors.NewSDKErrorWithStatus(errors.Wrap(sdk.ErrFailedDisable, errors.ErrNotFound), http.StatusNotFound), @@ -1121,19 +1180,26 @@ func TestDisableThing(t *testing.T) { } for _, tc := range cases { - repoCall1 := cRepo.On("RetrieveByID", mock.Anything, tc.id).Return(convertThing(tc.thing), tc.repoErr) - repoCall2 := cRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(convertThing(tc.response), tc.repoErr) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + if tc.token != validToken { + repoCall1 = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: false}, errors.ErrAuthorization) + } + repoCall2 := cRepo.On("RetrieveByID", mock.Anything, tc.id).Return(convertThing(tc.thing), tc.repoErr) + repoCall3 := cRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(convertThing(tc.response), tc.repoErr) dThing, err := mgsdk.DisableThing(tc.id, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, dThing, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, dThing)) if tc.err == nil { - ok := repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, tc.id) + ok := repoCall2.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, tc.id) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) - ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", mock.Anything, mock.Anything) + ok = repoCall3.Parent.AssertCalled(t, "ChangeStatus", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc)) } + repoCall.Unset() repoCall1.Unset() repoCall2.Unset() + repoCall3.Unset() } cases2 := []struct { @@ -1177,71 +1243,21 @@ func TestDisableThing(t *testing.T) { Limit: 100, Status: tc.status, } - repoCall := auth.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil) - repoCall1 := cRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(convertThingsPage(tc.response), nil) - page, err := mgsdk.Things(pm, adminToken) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("ListAllObjects", mock.Anything, mock.Anything).Return(&magistrala.ListObjectsRes{Policies: toIDs(tc.response.Things)}, nil) + repoCall2 := cRepo.On("RetrieveAllByIDs", mock.Anything, mock.Anything).Return(convertThingsPage(tc.response), nil) + page, err := mgsdk.Things(pm, validToken) assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) size := uint64(len(page.Things)) assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected size %d got %d\n", tc.desc, tc.size, size)) repoCall.Unset() repoCall1.Unset() - } -} - -func TestIdentify(t *testing.T) { - ts, cRepo, _, _ := newClientServer() - defer ts.Close() - - conf := sdk.Config{ - ThingsURL: ts.URL, - } - mgsdk := sdk.NewSDK(conf) - - thing = sdk.Thing{ - ID: generateUUID(t), - Name: "clientname", - Credentials: sdk.Credentials{Identity: "clientidentity", Secret: generateUUID(t)}, - Status: mgclients.EnabledStatus.String(), - } - - cases := []struct { - desc string - secret string - response string - repoErr error - err errors.SDKError - }{ - { - desc: "identify thing successfully", - response: thing.ID, - secret: thing.Credentials.Secret, - repoErr: nil, - err: nil, - }, - { - desc: "identify thing with an invalid token", - response: "", - secret: invalidToken, - repoErr: errors.ErrAuthentication, - err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), - }, - } - - for _, tc := range cases { - repoCall := cRepo.On("RetrieveBySecret", mock.Anything, mock.Anything).Return(convertThing(thing), tc.repoErr) - id, err := mgsdk.IdentifyThing(tc.secret) - assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) - assert.Equal(t, tc.response, id, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, id)) - if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "RetrieveBySecret", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("RetrieveBySecret was not called on %s", tc.desc)) - } - repoCall.Unset() + repoCall2.Unset() } } func TestShareThing(t *testing.T) { - ts, _, _, auth := newClientServer() + ts, _, _, auth := newThingsServer() auth.Test(t) defer ts.Close() @@ -1262,7 +1278,7 @@ func TestShareThing(t *testing.T) { desc: "share thing with valid token", channelID: generateUUID(t), thingID: "thingID", - token: adminToken, + token: validToken, err: nil, }, { @@ -1270,24 +1286,33 @@ func TestShareThing(t *testing.T) { channelID: generateUUID(t), thingID: "thingID", token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthorization, errors.ErrAuthentication), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), }, { desc: "share thing with valid token for unauthorized user", channelID: generateUUID(t), thingID: "thingID", - token: userToken, + token: validToken, err: errors.NewSDKErrorWithStatus(errors.ErrAuthorization, http.StatusForbidden), repoErr: errors.ErrAuthorization, }, } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, tc.repoErr) + if tc.token != validToken { + repoCall1 = auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: false}, errors.ErrAuthorization) + } + repoCall2 := auth.On("AddPolicy", mock.Anything, mock.Anything).Return(&magistrala.AddPolicyRes{Authorized: true}, nil) req := sdk.UsersRelationRequest{ Relation: "viewer", UserIDs: []string{tc.channelID}, } err := mgsdk.ShareThing(tc.thingID, req, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() } } diff --git a/pkg/sdk/go/tokens_test.go b/pkg/sdk/go/tokens_test.go index bdf537448..8a1464603 100644 --- a/pkg/sdk/go/tokens_test.go +++ b/pkg/sdk/go/tokens_test.go @@ -8,6 +8,7 @@ import ( "net/http" "testing" + "github.com/absmach/magistrala" "github.com/absmach/magistrala/internal/apiutil" "github.com/absmach/magistrala/pkg/errors" sdk "github.com/absmach/magistrala/pkg/sdk/go" @@ -16,7 +17,7 @@ import ( ) func TestIssueToken(t *testing.T) { - ts, cRepo, _, _ := newClientServer() + ts, cRepo, _, auth := newClientServer() defer ts.Close() conf := sdk.Config{ @@ -41,6 +42,7 @@ func TestIssueToken(t *testing.T) { cases := []struct { desc string login sdk.Login + token *magistrala.Token dbClient sdk.User err errors.SDKError }{ @@ -48,37 +50,46 @@ func TestIssueToken(t *testing.T) { desc: "issue token for a new user", login: sdk.Login{Identity: client.Credentials.Identity, Secret: client.Credentials.Secret}, dbClient: rClient, - err: nil, + token: &magistrala.Token{ + AccessToken: validToken, + RefreshToken: &validToken, + AccessType: "Bearer", + }, + err: nil, }, { desc: "issue token for an empty user", login: sdk.Login{}, + token: &magistrala.Token{}, err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrMissingIdentity), http.StatusInternalServerError), }, { desc: "issue token for invalid identity", login: sdk.Login{Identity: "invalid", Secret: "secret"}, + token: &magistrala.Token{}, dbClient: wrongClient, err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), }, } for _, tc := range cases { - repoCall := cRepo.On("RetrieveByIdentity", mock.Anything, mock.Anything).Return(convertClient(tc.dbClient), tc.err) + repoCall := auth.On("Issue", mock.Anything, mock.Anything).Return(tc.token, nil) + repoCall1 := cRepo.On("RetrieveByIdentity", mock.Anything, mock.Anything).Return(convertClient(tc.dbClient), tc.err) token, err := mgsdk.CreateToken(tc.login) switch tc.err { case nil: assert.NotEmpty(t, token, fmt.Sprintf("%s: expected token, got empty", tc.desc)) - ok := repoCall.Parent.AssertCalled(t, "RetrieveByIdentity", mock.Anything, mock.Anything) + ok := repoCall1.Parent.AssertCalled(t, "RetrieveByIdentity", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("RetrieveByIdentity was not called on %s", tc.desc)) default: assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) } repoCall.Unset() + repoCall1.Unset() } } func TestRefreshToken(t *testing.T) { - ts, cRepo, _, _ := newClientServer() + ts, cRepo, _, auth := newClientServer() defer ts.Close() conf := sdk.Config{ @@ -99,35 +110,34 @@ func TestRefreshToken(t *testing.T) { rUser.Credentials.Secret, _ = phasher.Hash(user.Credentials.Secret) cases := []struct { - desc string - token string - err errors.SDKError + desc string + token string + rtoken *magistrala.Token + err errors.SDKError }{ { desc: "refresh token for a valid refresh token", token: token, - err: nil, + rtoken: &magistrala.Token{ + AccessToken: validToken, + RefreshToken: &validToken, + AccessType: "Bearer", + }, + err: nil, }, { - desc: "refresh token for a valid access token", - token: token, - err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), - }, - { - desc: "refresh token for an empty token", - token: "", - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusInternalServerError), + desc: "refresh token for an empty token", + token: "", + rtoken: &magistrala.Token{}, + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrBearerToken), http.StatusUnauthorized), }, } for _, tc := range cases { repoCall := cRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(convertClient(user), tc.err) + repoCall1 := auth.On("Refresh", mock.Anything, mock.Anything).Return(tc.rtoken, nil) _, err := mgsdk.RefreshToken(sdk.Login{}, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) - if tc.err == nil { - assert.NotEmpty(t, token, fmt.Sprintf("%s: expected token, got empty", tc.desc)) - ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) - } repoCall.Unset() + repoCall1.Unset() } } diff --git a/pkg/sdk/go/users_test.go b/pkg/sdk/go/users_test.go index 5c9655f4f..65d8b465b 100644 --- a/pkg/sdk/go/users_test.go +++ b/pkg/sdk/go/users_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/absmach/magistrala" authmocks "github.com/absmach/magistrala/auth/mocks" "github.com/absmach/magistrala/internal/apiutil" "github.com/absmach/magistrala/internal/groups" @@ -28,8 +29,12 @@ import ( ) var ( - id = generateUUID(&testing.T{}) - validToken = "token" + id = generateUUID(&testing.T{}) + validToken = "token" + validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22" + ownerRelation = "owner" + userKind = "users" + userType = "user" ) func newClientServer() (*httptest.Server, *mocks.Repository, *gmocks.Repository, *authmocks.Service) { @@ -48,7 +53,7 @@ func newClientServer() (*httptest.Server, *mocks.Repository, *gmocks.Repository, } func TestCreateClient(t *testing.T) { - ts, cRepo, _, _ := newClientServer() + ts, cRepo, _, auth := newClientServer() defer ts.Close() user := sdk.User{ @@ -175,7 +180,12 @@ func TestCreateClient(t *testing.T) { }, } for _, tc := range cases { - repoCall := cRepo.On("Save", mock.Anything, mock.Anything).Return(tc.response, tc.err) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil) + if tc.token != validToken { + repoCall = auth.On("Identify", mock.Anything, mock.Anything).Return(&magistrala.IdentityRes{}, errors.ErrAuthentication) + } + repoCall1 := auth.On("AddPolicy", mock.Anything, mock.Anything).Return(&magistrala.AddPolicyRes{Authorized: true}, nil) + repoCall2 := cRepo.On("Save", mock.Anything, mock.Anything).Return(tc.response, tc.err) rClient, err := mgsdk.CreateUser(tc.client, tc.token) tc.response.ID = rClient.ID tc.response.Owner = rClient.Owner @@ -185,9 +195,11 @@ func TestCreateClient(t *testing.T) { assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, rClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, rClient)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) + ok := repoCall2.Parent.AssertCalled(t, "Save", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Save was not called on %s", tc.desc)) } + repoCall2.Unset() + repoCall1.Unset() repoCall.Unset() } } @@ -266,7 +278,7 @@ func TestListClients(t *testing.T) { token: token, offset: offset, limit: 0, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest), response: nil, }, { @@ -274,7 +286,7 @@ func TestListClients(t *testing.T) { token: token, offset: offset, limit: 110, - err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusInternalServerError), + err: errors.NewSDKErrorWithStatus(errors.Wrap(apiutil.ErrValidation, apiutil.ErrLimitSize), http.StatusBadRequest), response: []sdk.User(nil), }, { @@ -360,13 +372,13 @@ func TestListClients(t *testing.T) { Tag: tc.tag, } - repoCall1 := auth.On("CheckAdmin", mock.Anything, mock.Anything).Return(errors.ErrAuthorization) - repoCall2 := cRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(mgclients.ClientsPage{Page: convertClientPage(pm), Clients: convertClients(tc.response)}, tc.err) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := cRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(mgclients.ClientsPage{Page: convertClientPage(pm), Clients: convertClients(tc.response)}, tc.err) page, err := mgsdk.Users(pm, validToken) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, page.Users, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, page)) + repoCall.Unset() repoCall1.Unset() - repoCall2.Unset() } } @@ -405,7 +417,7 @@ func TestClient(t *testing.T) { response: sdk.User{}, token: invalidToken, clientID: generateUUID(t), - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthentication, sdk.ErrInvalidJWT), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), }, { desc: "view client with valid token and invalid client id", @@ -419,21 +431,32 @@ func TestClient(t *testing.T) { response: sdk.User{}, token: invalidToken, clientID: mocks.WrongID, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthentication, sdk.ErrInvalidJWT), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), }, } for _, tc := range cases { - repoCall := auth.On("Evaluate", mock.Anything, mock.Anything, mock.Anything).Return(nil) - repoCall1 := auth.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil) + if tc.token != validToken { + repoCall = auth.On("Identify", mock.Anything, mock.Anything).Return(&magistrala.IdentityRes{}, errors.ErrAuthentication) + } + authReq := &magistrala.AuthorizeReq{ + SubjectType: userType, + SubjectKind: userKind, + Subject: validID, + Permission: ownerRelation, + ObjectType: userType, + Object: tc.clientID, + } + repoCall1 := auth.On("Authorize", mock.Anything, authReq).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) repoCall2 := cRepo.On("RetrieveByID", mock.Anything, tc.clientID).Return(convertClient(tc.response), tc.err) rClient, err := mgsdk.User(tc.clientID, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) tc.response.Credentials.Secret = "" assert.Equal(t, tc.response, rClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, rClient)) if tc.err == nil { - ok := repoCall1.Parent.AssertCalled(t, "CheckAdmin", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("CheckAdmin was not called on %s", tc.desc)) + ok := repoCall.Parent.AssertCalled(t, "Identify", mock.Anything, mock.Anything) + assert.True(t, ok, fmt.Sprintf("Identify was not called on %s", tc.desc)) ok = repoCall2.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, tc.clientID) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) } @@ -444,7 +467,7 @@ func TestClient(t *testing.T) { } func TestProfile(t *testing.T) { - ts, cRepo, _, _ := newClientServer() + ts, cRepo, _, auth := newClientServer() defer ts.Close() user = sdk.User{ @@ -475,20 +498,25 @@ func TestProfile(t *testing.T) { desc: "view client with an invalid token", response: sdk.User{}, token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthentication, sdk.ErrInvalidJWT), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), }, } for _, tc := range cases { - repoCall := cRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.err) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil) + if tc.token != validToken { + repoCall = auth.On("Identify", mock.Anything, mock.Anything).Return(&magistrala.IdentityRes{}, errors.ErrAuthentication) + } + repoCal1 := cRepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.err) rClient, err := mgsdk.UserProfile(tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) tc.response.Credentials.Secret = "" assert.Equal(t, tc.response, rClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, rClient)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, mock.Anything) + ok := repoCal1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) } + repoCal1.Unset() repoCall.Unset() } } @@ -537,7 +565,7 @@ func TestUpdateClient(t *testing.T) { client: client1, response: sdk.User{}, token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthentication, sdk.ErrInvalidJWT), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), }, { desc: "update client name with invalid id", @@ -564,19 +592,32 @@ func TestUpdateClient(t *testing.T) { } for _, tc := range cases { - repoCall := auth.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil) - repoCall1 := cRepo.On("Update", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.err) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil) + if tc.token != validToken { + repoCall = auth.On("Identify", mock.Anything, mock.Anything).Return(&magistrala.IdentityRes{}, errors.ErrAuthentication) + } + authReq := &magistrala.AuthorizeReq{ + SubjectType: userType, + SubjectKind: userKind, + Subject: validID, + Permission: ownerRelation, + ObjectType: userType, + Object: tc.client.ID, + } + repoCall1 := auth.On("Authorize", mock.Anything, authReq).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := cRepo.On("Update", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.err) uClient, err := mgsdk.UpdateUser(tc.client, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, uClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, uClient)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "CheckAdmin", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("CheckAdmin was not called on %s", tc.desc)) - ok = repoCall1.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything) + ok := repoCall.Parent.AssertCalled(t, "Identify", mock.Anything, mock.Anything) + assert.True(t, ok, fmt.Sprintf("Identify was not called on %s", tc.desc)) + ok = repoCall2.Parent.AssertCalled(t, "Update", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("Update was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() + repoCall2.Unset() } } @@ -623,7 +664,7 @@ func TestUpdateClientTags(t *testing.T) { client: client1, response: sdk.User{}, token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthentication, sdk.ErrInvalidJWT), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), }, { desc: "update client name with invalid id", @@ -651,19 +692,32 @@ func TestUpdateClientTags(t *testing.T) { } for _, tc := range cases { - repoCall := auth.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil) - repoCall1 := cRepo.On("UpdateTags", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.err) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil) + if tc.token != validToken { + repoCall = auth.On("Identify", mock.Anything, mock.Anything).Return(&magistrala.IdentityRes{}, errors.ErrAuthentication) + } + authReq := &magistrala.AuthorizeReq{ + SubjectType: userType, + SubjectKind: userKind, + Subject: validID, + Permission: ownerRelation, + ObjectType: userType, + Object: tc.client.ID, + } + repoCall1 := auth.On("Authorize", mock.Anything, authReq).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := cRepo.On("UpdateTags", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.err) uClient, err := mgsdk.UpdateUserTags(tc.client, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, uClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, uClient)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "CheckAdmin", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("CheckAdmin was not called on %s", tc.desc)) - ok = repoCall1.Parent.AssertCalled(t, "UpdateTags", mock.Anything, mock.Anything) + ok := repoCall.Parent.AssertCalled(t, "Identify", mock.Anything, mock.Anything) + assert.True(t, ok, fmt.Sprintf("Identify was not called on %s", tc.desc)) + ok = repoCall2.Parent.AssertCalled(t, "UpdateTags", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("UpdateTags was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() + repoCall2.Unset() } } @@ -708,7 +762,7 @@ func TestUpdateClientIdentity(t *testing.T) { client: user, response: sdk.User{}, token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthentication, sdk.ErrInvalidJWT), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), }, { desc: "update client name with invalid id", @@ -736,24 +790,37 @@ func TestUpdateClientIdentity(t *testing.T) { } for _, tc := range cases { - repoCall := auth.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil) - repoCall1 := cRepo.On("UpdateIdentity", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.err) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil) + if tc.token != validToken { + repoCall = auth.On("Identify", mock.Anything, mock.Anything).Return(&magistrala.IdentityRes{}, errors.ErrAuthentication) + } + authReq := &magistrala.AuthorizeReq{ + SubjectType: userType, + SubjectKind: userKind, + Subject: validID, + Permission: ownerRelation, + ObjectType: userType, + Object: tc.client.ID, + } + repoCall1 := auth.On("Authorize", mock.Anything, authReq).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := cRepo.On("UpdateIdentity", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.err) uClient, err := mgsdk.UpdateUserIdentity(tc.client, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, uClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, uClient)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "CheckAdmin", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("CheckAdmin was not called on %s", tc.desc)) - ok = repoCall1.Parent.AssertCalled(t, "UpdateIdentity", mock.Anything, mock.Anything) + ok := repoCall.Parent.AssertCalled(t, "Identify", mock.Anything, mock.Anything) + assert.True(t, ok, fmt.Sprintf("Identify was not called on %s", tc.desc)) + ok = repoCall2.Parent.AssertCalled(t, "UpdateIdentity", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("UpdateIdentity was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() + repoCall2.Unset() } } func TestUpdateClientSecret(t *testing.T) { - ts, cRepo, _, _ := newClientServer() + ts, cRepo, _, auth := newClientServer() defer ts.Close() conf := sdk.Config{ @@ -790,7 +857,7 @@ func TestUpdateClientSecret(t *testing.T) { token: "non-existent", response: sdk.User{}, repoErr: errors.ErrAuthentication, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthentication, sdk.ErrInvalidJWT), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), }, { desc: "update client secret with wrong old secret", @@ -804,23 +871,30 @@ func TestUpdateClientSecret(t *testing.T) { } for _, tc := range cases { - repoCall := cRepo.On("RetrieveByID", mock.Anything, user.ID).Return(convertClient(tc.response), tc.repoErr) - repoCall1 := cRepo.On("RetrieveByIdentity", mock.Anything, user.Credentials.Identity).Return(convertClient(tc.response), tc.repoErr) - repoCall2 := cRepo.On("UpdateSecret", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.repoErr) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: user.ID}, nil) + if tc.token != validToken { + repoCall = auth.On("Identify", mock.Anything, mock.Anything).Return(&magistrala.IdentityRes{}, errors.ErrAuthentication) + } + repoCall1 := auth.On("Issue", mock.Anything, mock.Anything).Return(&magistrala.Token{AccessToken: validToken}, nil) + repoCall2 := cRepo.On("RetrieveByID", mock.Anything, user.ID).Return(convertClient(tc.response), tc.repoErr) + repoCall3 := cRepo.On("RetrieveByIdentity", mock.Anything, user.Credentials.Identity).Return(convertClient(tc.response), tc.repoErr) + repoCall4 := cRepo.On("UpdateSecret", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.repoErr) uClient, err := mgsdk.UpdatePassword(tc.oldSecret, tc.newSecret, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, uClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, uClient)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, user.ID) + ok := repoCall2.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, user.ID) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) - ok = repoCall1.Parent.AssertCalled(t, "RetrieveByIdentity", mock.Anything, user.Credentials.Identity) + ok = repoCall3.Parent.AssertCalled(t, "RetrieveByIdentity", mock.Anything, user.Credentials.Identity) assert.True(t, ok, fmt.Sprintf("RetrieveByIdentity was not called on %s", tc.desc)) - ok = repoCall2.Parent.AssertCalled(t, "UpdateSecret", mock.Anything, mock.Anything) + ok = repoCall4.Parent.AssertCalled(t, "UpdateSecret", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("UpdateSecret was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() repoCall2.Unset() + repoCall3.Unset() + repoCall4.Unset() } } @@ -865,7 +939,7 @@ func TestUpdateClientOwner(t *testing.T) { client: client2, response: sdk.User{}, token: invalidToken, - err: errors.NewSDKErrorWithStatus(errors.Wrap(errors.ErrAuthentication, sdk.ErrInvalidJWT), http.StatusUnauthorized), + err: errors.NewSDKErrorWithStatus(errors.ErrAuthentication, http.StatusUnauthorized), }, { desc: "update client name with invalid id", @@ -892,19 +966,43 @@ func TestUpdateClientOwner(t *testing.T) { } for _, tc := range cases { - repoCall := auth.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil) - repoCall1 := cRepo.On("UpdateOwner", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.err) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil) + if tc.token != validToken { + repoCall = auth.On("Identify", mock.Anything, mock.Anything).Return(&magistrala.IdentityRes{}, errors.ErrAuthentication) + } + authReq := &magistrala.AuthorizeReq{ + SubjectType: userType, + SubjectKind: userKind, + Subject: validID, + Permission: ownerRelation, + ObjectType: userType, + Object: tc.client.ID, + } + repoCall1 := auth.On("Authorize", mock.Anything, authReq).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + deleteReq := &magistrala.DeletePolicyReq{ + SubjectType: userType, + Subject: validID, + Relation: ownerRelation, + ObjectType: userType, + Object: tc.client.ID, + } + repoCall2 := auth.On("DeletePolicy", mock.Anything, deleteReq).Return(&magistrala.DeletePolicyRes{Deleted: true}, nil) + repoCall3 := auth.On("AddPolicy", mock.Anything, mock.Anything).Return(&magistrala.AddPolicyRes{Authorized: true}, nil) + repoCall4 := cRepo.On("UpdateOwner", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.err) uClient, err := mgsdk.UpdateUserOwner(tc.client, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, uClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, uClient)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "CheckAdmin", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("CheckAdmin was not called on %s", tc.desc)) - ok = repoCall1.Parent.AssertCalled(t, "UpdateOwner", mock.Anything, mock.Anything) + ok := repoCall.Parent.AssertCalled(t, "Identify", mock.Anything, mock.Anything) + assert.True(t, ok, fmt.Sprintf("Identify was not called on %s", tc.desc)) + ok = repoCall4.Parent.AssertCalled(t, "UpdateOwner", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("UpdateOwner was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() + repoCall2.Unset() + repoCall3.Unset() + repoCall4.Unset() } } @@ -962,23 +1060,33 @@ func TestEnableClient(t *testing.T) { } for _, tc := range cases { - repoCall := auth.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil) - repoCall1 := cRepo.On("RetrieveByID", mock.Anything, tc.id).Return(convertClient(tc.client), tc.repoErr) - repoCall2 := cRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.repoErr) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil) + authReq := &magistrala.AuthorizeReq{ + SubjectType: userType, + SubjectKind: userKind, + Subject: validID, + Permission: ownerRelation, + ObjectType: userType, + Object: tc.id, + } + repoCall1 := auth.On("Authorize", mock.Anything, authReq).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := cRepo.On("RetrieveByID", mock.Anything, tc.id).Return(convertClient(tc.client), tc.repoErr) + repoCall3 := cRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.repoErr) eClient, err := mgsdk.EnableUser(tc.id, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, eClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, eClient)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "CheckAdmin", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("CheckAdmin was not called on %s", tc.desc)) - ok = repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, tc.id) + ok := repoCall.Parent.AssertCalled(t, "Identify", mock.Anything, mock.Anything) + assert.True(t, ok, fmt.Sprintf("Identify was not called on %s", tc.desc)) + ok = repoCall2.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, tc.id) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) - ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", mock.Anything, mock.Anything) + ok = repoCall3.Parent.AssertCalled(t, "ChangeStatus", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() repoCall2.Unset() + repoCall3.Unset() } cases2 := []struct { @@ -1022,7 +1130,7 @@ func TestEnableClient(t *testing.T) { Limit: 100, Status: tc.status, } - repoCall := auth.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil) repoCall1 := cRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(convertClientsPage(tc.response), nil) clientsPage, err := mgsdk.Users(pm, validToken) assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) @@ -1087,23 +1195,33 @@ func TestDisableClient(t *testing.T) { } for _, tc := range cases { - repoCall := auth.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil) - repoCall1 := cRepo.On("RetrieveByID", mock.Anything, tc.id).Return(convertClient(tc.client), tc.repoErr) - repoCall2 := cRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.repoErr) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil) + authReq := &magistrala.AuthorizeReq{ + SubjectType: userType, + SubjectKind: userKind, + Subject: validID, + Permission: ownerRelation, + ObjectType: userType, + Object: tc.id, + } + repoCall1 := auth.On("Authorize", mock.Anything, authReq).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := cRepo.On("RetrieveByID", mock.Anything, tc.id).Return(convertClient(tc.client), tc.repoErr) + repoCall3 := cRepo.On("ChangeStatus", mock.Anything, mock.Anything).Return(convertClient(tc.response), tc.repoErr) dClient, err := mgsdk.DisableUser(tc.id, tc.token) assert.Equal(t, tc.err, err, fmt.Sprintf("%s: expected error %s, got %s", tc.desc, tc.err, err)) assert.Equal(t, tc.response, dClient, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.response, dClient)) if tc.err == nil { - ok := repoCall.Parent.AssertCalled(t, "CheckAdmin", mock.Anything, mock.Anything) - assert.True(t, ok, fmt.Sprintf("CheckAdmin was not called on %s", tc.desc)) - ok = repoCall1.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, tc.id) + ok := repoCall.Parent.AssertCalled(t, "Identify", mock.Anything, mock.Anything) + assert.True(t, ok, fmt.Sprintf("Identify was not called on %s", tc.desc)) + ok = repoCall2.Parent.AssertCalled(t, "RetrieveByID", mock.Anything, tc.id) assert.True(t, ok, fmt.Sprintf("RetrieveByID was not called on %s", tc.desc)) - ok = repoCall2.Parent.AssertCalled(t, "ChangeStatus", mock.Anything, mock.Anything) + ok = repoCall3.Parent.AssertCalled(t, "ChangeStatus", mock.Anything, mock.Anything) assert.True(t, ok, fmt.Sprintf("ChangeStatus was not called on %s", tc.desc)) } repoCall.Unset() repoCall1.Unset() repoCall2.Unset() + repoCall3.Unset() } cases2 := []struct { @@ -1147,7 +1265,7 @@ func TestDisableClient(t *testing.T) { Limit: 100, Status: tc.status, } - repoCall := auth.On("CheckAdmin", mock.Anything, mock.Anything).Return(nil) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: validToken}).Return(&magistrala.IdentityRes{Id: validID}, nil) repoCall1 := cRepo.On("RetrieveAll", mock.Anything, mock.Anything).Return(convertClientsPage(tc.response), nil) page, err := mgsdk.Users(pm, validToken) assert.Nil(t, err, fmt.Sprintf("unexpected error: %s", err)) diff --git a/things/service.go b/things/service.go index 80df4e3d8..30e56c174 100644 --- a/things/service.go +++ b/things/service.go @@ -441,7 +441,7 @@ func (svc *service) authorize(ctx context.Context, subjType, subjKind, subj, per } res, err := svc.auth.Authorize(ctx, req) if err != nil { - return "", errors.Wrap(errors.ErrAuthorization, err) + return "", err } if !res.GetAuthorized() { return "", errors.ErrAuthorization