Skip to content
This repository has been archived by the owner on Oct 14, 2024. It is now read-only.

Commit

Permalink
Fix Certs Tests
Browse files Browse the repository at this point in the history
Fix tests in certs package
  • Loading branch information
rodneyosodo committed Nov 17, 2023
1 parent 7acf21a commit 0bc57b4
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 47 deletions.
3 changes: 2 additions & 1 deletion certs/postgres/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package postgres_test

import (
"database/sql"
"fmt"
"os"
"testing"
Expand Down Expand Up @@ -41,7 +42,7 @@ func TestMain(m *testing.M) {

if err := pool.Retry(func() error {
url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port)
db, err = sqlx.Open("postgres", url)
db, err := sql.Open("pgx", url)
if err != nil {
return err
}
Expand Down
129 changes: 83 additions & 46 deletions certs/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ import (
"testing"
"time"

"github.com/absmach/magistrala"
authmocks "github.com/absmach/magistrala/auth/mocks"
"github.com/absmach/magistrala/certs"
"github.com/absmach/magistrala/certs/mocks"
chmocks "github.com/absmach/magistrala/internal/groups/mocks"
"github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
mgsdk "github.com/absmach/magistrala/pkg/sdk/go"
"github.com/absmach/magistrala/pkg/uuid"
Expand All @@ -24,18 +26,20 @@ import (
thmocks "github.com/absmach/magistrala/things/mocks"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

const (
wrongValue = "wrong-value"
email = "[email protected]"
token = "token"
thingsNum = 1
thingKey = "thingKey"
thingID = "1"
ttl = "1h"
certNum = 10
invalid = "invalid"
email = "[email protected]"
token = "token"
thingsNum = 1
thingKey = "thingKey"
thingID = "1"
ttl = "1h"
certNum = 10
validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22"

cfgAuthTimeout = "1s"

Expand All @@ -45,8 +49,16 @@ const (
instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002"
)

func newService() (certs.Service, error) {
tsvc, auth := newThingsService()
func newThingsServer(svc things.Service) *httptest.Server {
logger := logger.NewMock()
mux := chi.NewMux()
httpapi.MakeHandler(svc, nil, mux, logger, instanceID)
return httptest.NewServer(mux)
}

func newService(t *testing.T) (certs.Service, *authmocks.Service, *thmocks.Repository) {
auth := new(authmocks.Service)
tsvc, trepo := newThingsService(auth)
server := newThingsServer(tsvc)

config := mgsdk.Config{
Expand All @@ -57,33 +69,27 @@ func newService() (certs.Service, error) {
repo := mocks.NewCertsRepository()

tlsCert, caCert, err := certs.LoadCertificates(caPath, caKeyPath)
if err != nil {
return nil, err
}
require.Nil(t, err, fmt.Sprintf("unexpected cert loading error: %s\n", err))

authTimeout, err := time.ParseDuration(cfgAuthTimeout)
if err != nil {
return nil, err
}
require.Nil(t, err, fmt.Sprintf("unexpected auth timeout parsing error: %s\n", err))

pki := mocks.NewPkiAgent(tlsCert, caCert, cfgSignHoursValid, authTimeout)

return certs.New(auth, repo, sdk, pki), nil
return certs.New(auth, repo, sdk, pki), auth, trepo
}

func newThingsService() (things.Service, *authmocks.Service) {
auth := new(authmocks.Service)
func newThingsService(auth *authmocks.Service) (things.Service, *thmocks.Repository) {
thingCache := thmocks.NewCache()
idProvider := uuid.NewMock()
cRepo := new(thmocks.Repository)
gRepo := new(chmocks.Repository)

return things.NewService(auth, cRepo, gRepo, thingCache, idProvider), auth
return things.NewService(auth, cRepo, gRepo, thingCache, idProvider), cRepo
}

func TestIssueCert(t *testing.T) {
svc, err := newService()
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
svc, auth, trepo := newService(t)

cases := []struct {
token string
Expand All @@ -109,29 +115,40 @@ func TestIssueCert(t *testing.T) {
},
{
desc: "issue new cert for non existing thing id",
token: wrongValue,
token: invalid,
thingID: thingID,
ttl: ttl,
err: errors.ErrAuthentication,
},
}

for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, tc.err)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: tc.thingID, Credentials: clients.Credentials{Secret: thingKey}}, tc.err)
c, err := svc.IssueCert(context.Background(), tc.token, tc.thingID, tc.ttl)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
cert, _ := certs.ReadCert([]byte(c.ClientCert))
if cert != nil {
assert.True(t, strings.Contains(cert.Subject.CommonName, thingKey), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.True(t, strings.Contains(cert.Subject.CommonName, thingKey), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, thingKey, cert.Subject.CommonName))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
}
}

func TestRevokeCert(t *testing.T) {
svc, err := newService()
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
svc, auth, trepo := newService(t)

_, err = svc.IssueCert(context.Background(), token, thingID, ttl)
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID, Credentials: clients.Credentials{Secret: thingKey}}, nil)
_, err := svc.IssueCert(context.Background(), token, thingID, ttl)
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()

cases := []struct {
token string
Expand All @@ -147,7 +164,7 @@ func TestRevokeCert(t *testing.T) {
},
{
desc: "revoke cert for invalid token",
token: wrongValue,
token: invalid,
thingID: thingID,
err: errors.ErrAuthentication,
},
Expand All @@ -160,18 +177,29 @@ func TestRevokeCert(t *testing.T) {
}

for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, tc.err)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: tc.thingID, Credentials: clients.Credentials{Secret: thingKey}}, tc.err)
_, err := svc.RevokeCert(context.Background(), tc.token, tc.thingID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
}
}

func TestListCerts(t *testing.T) {
svc, err := newService()
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
svc, auth, trepo := newService(t)

for i := 0; i < certNum; i++ {
_, err = svc.IssueCert(context.Background(), token, thingID, ttl)
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID, Credentials: clients.Credentials{Secret: thingKey}}, nil)
_, err := svc.IssueCert(context.Background(), token, thingID, ttl)
require.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
}

cases := []struct {
Expand All @@ -194,7 +222,7 @@ func TestListCerts(t *testing.T) {
},
{
desc: "list all certs with invalid token",
token: wrongValue,
token: invalid,
thingID: thingID,
offset: 0,
limit: certNum,
Expand Down Expand Up @@ -222,21 +250,28 @@ func TestListCerts(t *testing.T) {
}

for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
page, err := svc.ListCerts(context.Background(), tc.token, tc.thingID, tc.offset, tc.limit)
size := uint64(len(page.Certs))
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.size, size))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
}
}

func TestListSerials(t *testing.T) {
svc, err := newService()
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
svc, auth, trepo := newService(t)

var issuedCerts []certs.Cert
for i := 0; i < certNum; i++ {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID, Credentials: clients.Credentials{Secret: thingKey}}, nil)
cert, err := svc.IssueCert(context.Background(), token, thingID, ttl)
assert.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()

crt := certs.Cert{
OwnerID: cert.OwnerID,
Expand Down Expand Up @@ -267,7 +302,7 @@ func TestListSerials(t *testing.T) {
},
{
desc: "list all certs with invalid token",
token: wrongValue,
token: invalid,
thingID: thingID,
offset: 0,
limit: certNum,
Expand Down Expand Up @@ -295,18 +330,25 @@ func TestListSerials(t *testing.T) {
}

for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
page, err := svc.ListSerials(context.Background(), tc.token, tc.thingID, tc.offset, tc.limit)
assert.Equal(t, tc.certs, page.Certs, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.certs, page.Certs))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
}
}

func TestViewCert(t *testing.T) {
svc, err := newService()
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
svc, auth, trepo := newService(t)

repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID, Credentials: clients.Credentials{Secret: thingKey}}, nil)
ic, err := svc.IssueCert(context.Background(), token, thingID, ttl)
require.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()

cert := certs.Cert{
ThingID: thingID,
Expand All @@ -331,30 +373,25 @@ func TestViewCert(t *testing.T) {
},
{
desc: "list cert with invalid token",
token: wrongValue,
token: invalid,
serialID: cert.Serial,
cert: certs.Cert{},
err: errors.ErrAuthentication,
},
{
desc: "list cert with invalid serial",
token: token,
serialID: wrongValue,
serialID: invalid,
cert: certs.Cert{},
err: errors.ErrNotFound,
},
}

for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
cert, err := svc.ViewCert(context.Background(), tc.token, tc.serialID)
assert.Equal(t, tc.cert, cert, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.cert, cert))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
}
}

func newThingsServer(svc things.Service) *httptest.Server {
logger := logger.NewMock()
mux := chi.NewMux()
httpapi.MakeHandler(svc, nil, mux, logger, instanceID)
return httptest.NewServer(mux)
}

0 comments on commit 0bc57b4

Please sign in to comment.