diff --git a/certs/postgres/setup_test.go b/certs/postgres/setup_test.go index b110296bd..7d0416247 100644 --- a/certs/postgres/setup_test.go +++ b/certs/postgres/setup_test.go @@ -4,6 +4,7 @@ package postgres_test import ( + "database/sql" "fmt" "os" "testing" @@ -41,7 +42,7 @@ func TestMain(m *testing.M) { if err := pool.Retry(func() error { url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port) - db, err = sqlx.Open("postgres", url) + db, err := sql.Open("pgx", url) if err != nil { return err } diff --git a/certs/service_test.go b/certs/service_test.go index 6790af034..66124ac03 100644 --- a/certs/service_test.go +++ b/certs/service_test.go @@ -11,11 +11,13 @@ import ( "testing" "time" + "github.com/absmach/magistrala" authmocks "github.com/absmach/magistrala/auth/mocks" "github.com/absmach/magistrala/certs" "github.com/absmach/magistrala/certs/mocks" chmocks "github.com/absmach/magistrala/internal/groups/mocks" "github.com/absmach/magistrala/logger" + "github.com/absmach/magistrala/pkg/clients" "github.com/absmach/magistrala/pkg/errors" mgsdk "github.com/absmach/magistrala/pkg/sdk/go" "github.com/absmach/magistrala/pkg/uuid" @@ -24,18 +26,20 @@ import ( thmocks "github.com/absmach/magistrala/things/mocks" "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) const ( - wrongValue = "wrong-value" - email = "user@example.com" - token = "token" - thingsNum = 1 - thingKey = "thingKey" - thingID = "1" - ttl = "1h" - certNum = 10 + invalid = "invalid" + email = "user@example.com" + token = "token" + thingsNum = 1 + thingKey = "thingKey" + thingID = "1" + ttl = "1h" + certNum = 10 + validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22" cfgAuthTimeout = "1s" @@ -45,8 +49,16 @@ const ( instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002" ) -func newService() (certs.Service, error) { - tsvc, auth := newThingsService() +func newThingsServer(svc things.Service) *httptest.Server { + logger := logger.NewMock() + mux := chi.NewMux() + httpapi.MakeHandler(svc, nil, mux, logger, instanceID) + return httptest.NewServer(mux) +} + +func newService(t *testing.T) (certs.Service, *authmocks.Service, *thmocks.Repository) { + auth := new(authmocks.Service) + tsvc, trepo := newThingsService(auth) server := newThingsServer(tsvc) config := mgsdk.Config{ @@ -57,33 +69,27 @@ func newService() (certs.Service, error) { repo := mocks.NewCertsRepository() tlsCert, caCert, err := certs.LoadCertificates(caPath, caKeyPath) - if err != nil { - return nil, err - } + require.Nil(t, err, fmt.Sprintf("unexpected cert loading error: %s\n", err)) authTimeout, err := time.ParseDuration(cfgAuthTimeout) - if err != nil { - return nil, err - } + require.Nil(t, err, fmt.Sprintf("unexpected auth timeout parsing error: %s\n", err)) pki := mocks.NewPkiAgent(tlsCert, caCert, cfgSignHoursValid, authTimeout) - return certs.New(auth, repo, sdk, pki), nil + return certs.New(auth, repo, sdk, pki), auth, trepo } -func newThingsService() (things.Service, *authmocks.Service) { - auth := new(authmocks.Service) +func newThingsService(auth *authmocks.Service) (things.Service, *thmocks.Repository) { thingCache := thmocks.NewCache() idProvider := uuid.NewMock() cRepo := new(thmocks.Repository) gRepo := new(chmocks.Repository) - return things.NewService(auth, cRepo, gRepo, thingCache, idProvider), auth + return things.NewService(auth, cRepo, gRepo, thingCache, idProvider), cRepo } func TestIssueCert(t *testing.T) { - svc, err := newService() - require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err)) + svc, auth, trepo := newService(t) cases := []struct { token string @@ -109,7 +115,7 @@ func TestIssueCert(t *testing.T) { }, { desc: "issue new cert for non existing thing id", - token: wrongValue, + token: invalid, thingID: thingID, ttl: ttl, err: errors.ErrAuthentication, @@ -117,21 +123,32 @@ func TestIssueCert(t *testing.T) { } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, tc.err) + repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: tc.thingID, Credentials: clients.Credentials{Secret: thingKey}}, tc.err) c, err := svc.IssueCert(context.Background(), tc.token, tc.thingID, tc.ttl) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) cert, _ := certs.ReadCert([]byte(c.ClientCert)) if cert != nil { - assert.True(t, strings.Contains(cert.Subject.CommonName, thingKey), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + assert.True(t, strings.Contains(cert.Subject.CommonName, thingKey), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, thingKey, cert.Subject.CommonName)) } + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() } } func TestRevokeCert(t *testing.T) { - svc, err := newService() - require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err)) + svc, auth, trepo := newService(t) - _, err = svc.IssueCert(context.Background(), token, thingID, ttl) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID, Credentials: clients.Credentials{Secret: thingKey}}, nil) + _, err := svc.IssueCert(context.Background(), token, thingID, ttl) require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err)) + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() cases := []struct { token string @@ -147,7 +164,7 @@ func TestRevokeCert(t *testing.T) { }, { desc: "revoke cert for invalid token", - token: wrongValue, + token: invalid, thingID: thingID, err: errors.ErrAuthentication, }, @@ -160,18 +177,29 @@ func TestRevokeCert(t *testing.T) { } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, tc.err) + repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: tc.thingID, Credentials: clients.Credentials{Secret: thingKey}}, tc.err) _, err := svc.RevokeCert(context.Background(), tc.token, tc.thingID) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() } } func TestListCerts(t *testing.T) { - svc, err := newService() - require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err)) + svc, auth, trepo := newService(t) for i := 0; i < certNum; i++ { - _, err = svc.IssueCert(context.Background(), token, thingID, ttl) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID, Credentials: clients.Credentials{Secret: thingKey}}, nil) + _, err := svc.IssueCert(context.Background(), token, thingID, ttl) require.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err)) + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() } cases := []struct { @@ -194,7 +222,7 @@ func TestListCerts(t *testing.T) { }, { desc: "list all certs with invalid token", - token: wrongValue, + token: invalid, thingID: thingID, offset: 0, limit: certNum, @@ -222,21 +250,28 @@ func TestListCerts(t *testing.T) { } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) page, err := svc.ListCerts(context.Background(), tc.token, tc.thingID, tc.offset, tc.limit) size := uint64(len(page.Certs)) assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.size, size)) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() } } func TestListSerials(t *testing.T) { - svc, err := newService() - require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err)) + svc, auth, trepo := newService(t) var issuedCerts []certs.Cert for i := 0; i < certNum; i++ { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID, Credentials: clients.Credentials{Secret: thingKey}}, nil) cert, err := svc.IssueCert(context.Background(), token, thingID, ttl) assert.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err)) + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() crt := certs.Cert{ OwnerID: cert.OwnerID, @@ -267,7 +302,7 @@ func TestListSerials(t *testing.T) { }, { desc: "list all certs with invalid token", - token: wrongValue, + token: invalid, thingID: thingID, offset: 0, limit: certNum, @@ -295,18 +330,25 @@ func TestListSerials(t *testing.T) { } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) page, err := svc.ListSerials(context.Background(), tc.token, tc.thingID, tc.offset, tc.limit) assert.Equal(t, tc.certs, page.Certs, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.certs, page.Certs)) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() } } func TestViewCert(t *testing.T) { - svc, err := newService() - require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err)) + svc, auth, trepo := newService(t) + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil) + repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil) + repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID, Credentials: clients.Credentials{Secret: thingKey}}, nil) ic, err := svc.IssueCert(context.Background(), token, thingID, ttl) require.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err)) + repoCall.Unset() + repoCall1.Unset() + repoCall2.Unset() cert := certs.Cert{ ThingID: thingID, @@ -331,7 +373,7 @@ func TestViewCert(t *testing.T) { }, { desc: "list cert with invalid token", - token: wrongValue, + token: invalid, serialID: cert.Serial, cert: certs.Cert{}, err: errors.ErrAuthentication, @@ -339,22 +381,17 @@ func TestViewCert(t *testing.T) { { desc: "list cert with invalid serial", token: token, - serialID: wrongValue, + serialID: invalid, cert: certs.Cert{}, err: errors.ErrNotFound, }, } for _, tc := range cases { + repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil) cert, err := svc.ViewCert(context.Background(), tc.token, tc.serialID) assert.Equal(t, tc.cert, cert, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.cert, cert)) assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err)) + repoCall.Unset() } } - -func newThingsServer(svc things.Service) *httptest.Server { - logger := logger.NewMock() - mux := chi.NewMux() - httpapi.MakeHandler(svc, nil, mux, logger, instanceID) - return httptest.NewServer(mux) -}