Skip to content
This repository has been archived by the owner on Oct 14, 2024. It is now read-only.

NOISSUE - Fix Certs Tests #51

Merged
merged 1 commit into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion certs/postgres/setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package postgres_test

import (
"database/sql"
"fmt"
"os"
"testing"
Expand Down Expand Up @@ -41,7 +42,7 @@ func TestMain(m *testing.M) {

if err := pool.Retry(func() error {
url := fmt.Sprintf("host=localhost port=%s user=test dbname=test password=test sslmode=disable", port)
db, err = sqlx.Open("postgres", url)
db, err := sql.Open("pgx", url)
if err != nil {
return err
}
Expand Down
129 changes: 83 additions & 46 deletions certs/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ import (
"testing"
"time"

"github.com/absmach/magistrala"
authmocks "github.com/absmach/magistrala/auth/mocks"
"github.com/absmach/magistrala/certs"
"github.com/absmach/magistrala/certs/mocks"
chmocks "github.com/absmach/magistrala/internal/groups/mocks"
"github.com/absmach/magistrala/logger"
"github.com/absmach/magistrala/pkg/clients"
"github.com/absmach/magistrala/pkg/errors"
mgsdk "github.com/absmach/magistrala/pkg/sdk/go"
"github.com/absmach/magistrala/pkg/uuid"
Expand All @@ -24,18 +26,20 @@ import (
thmocks "github.com/absmach/magistrala/things/mocks"
"github.com/go-chi/chi/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

const (
wrongValue = "wrong-value"
email = "[email protected]"
token = "token"
thingsNum = 1
thingKey = "thingKey"
thingID = "1"
ttl = "1h"
certNum = 10
invalid = "invalid"
email = "[email protected]"
token = "token"
thingsNum = 1
thingKey = "thingKey"
thingID = "1"
ttl = "1h"
certNum = 10
validID = "d4ebb847-5d0e-4e46-bdd9-b6aceaaa3a22"

cfgAuthTimeout = "1s"

Expand All @@ -45,8 +49,16 @@ const (
instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002"
)

func newService() (certs.Service, error) {
tsvc, auth := newThingsService()
func newThingsServer(svc things.Service) *httptest.Server {
logger := logger.NewMock()
mux := chi.NewMux()
httpapi.MakeHandler(svc, nil, mux, logger, instanceID)
return httptest.NewServer(mux)
}

func newService(t *testing.T) (certs.Service, *authmocks.Service, *thmocks.Repository) {
auth := new(authmocks.Service)
tsvc, trepo := newThingsService(auth)
server := newThingsServer(tsvc)

config := mgsdk.Config{
Expand All @@ -57,33 +69,27 @@ func newService() (certs.Service, error) {
repo := mocks.NewCertsRepository()

tlsCert, caCert, err := certs.LoadCertificates(caPath, caKeyPath)
if err != nil {
return nil, err
}
require.Nil(t, err, fmt.Sprintf("unexpected cert loading error: %s\n", err))

authTimeout, err := time.ParseDuration(cfgAuthTimeout)
if err != nil {
return nil, err
}
require.Nil(t, err, fmt.Sprintf("unexpected auth timeout parsing error: %s\n", err))

pki := mocks.NewPkiAgent(tlsCert, caCert, cfgSignHoursValid, authTimeout)

return certs.New(auth, repo, sdk, pki), nil
return certs.New(auth, repo, sdk, pki), auth, trepo
}

func newThingsService() (things.Service, *authmocks.Service) {
auth := new(authmocks.Service)
func newThingsService(auth *authmocks.Service) (things.Service, *thmocks.Repository) {
thingCache := thmocks.NewCache()
idProvider := uuid.NewMock()
cRepo := new(thmocks.Repository)
gRepo := new(chmocks.Repository)

return things.NewService(auth, cRepo, gRepo, thingCache, idProvider), auth
return things.NewService(auth, cRepo, gRepo, thingCache, idProvider), cRepo
}

func TestIssueCert(t *testing.T) {
svc, err := newService()
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
svc, auth, trepo := newService(t)

cases := []struct {
token string
Expand All @@ -109,29 +115,40 @@ func TestIssueCert(t *testing.T) {
},
{
desc: "issue new cert for non existing thing id",
token: wrongValue,
token: invalid,
thingID: thingID,
ttl: ttl,
err: errors.ErrAuthentication,
},
}

for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, tc.err)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: tc.thingID, Credentials: clients.Credentials{Secret: thingKey}}, tc.err)
c, err := svc.IssueCert(context.Background(), tc.token, tc.thingID, tc.ttl)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
cert, _ := certs.ReadCert([]byte(c.ClientCert))
if cert != nil {
assert.True(t, strings.Contains(cert.Subject.CommonName, thingKey), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
assert.True(t, strings.Contains(cert.Subject.CommonName, thingKey), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, thingKey, cert.Subject.CommonName))
}
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
}
}

func TestRevokeCert(t *testing.T) {
svc, err := newService()
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
svc, auth, trepo := newService(t)

_, err = svc.IssueCert(context.Background(), token, thingID, ttl)
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID, Credentials: clients.Credentials{Secret: thingKey}}, nil)
_, err := svc.IssueCert(context.Background(), token, thingID, ttl)
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()

cases := []struct {
token string
Expand All @@ -147,7 +164,7 @@ func TestRevokeCert(t *testing.T) {
},
{
desc: "revoke cert for invalid token",
token: wrongValue,
token: invalid,
thingID: thingID,
err: errors.ErrAuthentication,
},
Expand All @@ -160,18 +177,29 @@ func TestRevokeCert(t *testing.T) {
}

for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, tc.err)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: tc.thingID, Credentials: clients.Credentials{Secret: thingKey}}, tc.err)
_, err := svc.RevokeCert(context.Background(), tc.token, tc.thingID)
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
}
}

func TestListCerts(t *testing.T) {
svc, err := newService()
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
svc, auth, trepo := newService(t)

for i := 0; i < certNum; i++ {
_, err = svc.IssueCert(context.Background(), token, thingID, ttl)
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID, Credentials: clients.Credentials{Secret: thingKey}}, nil)
_, err := svc.IssueCert(context.Background(), token, thingID, ttl)
require.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()
}

cases := []struct {
Expand All @@ -194,7 +222,7 @@ func TestListCerts(t *testing.T) {
},
{
desc: "list all certs with invalid token",
token: wrongValue,
token: invalid,
thingID: thingID,
offset: 0,
limit: certNum,
Expand Down Expand Up @@ -222,21 +250,28 @@ func TestListCerts(t *testing.T) {
}

for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
page, err := svc.ListCerts(context.Background(), tc.token, tc.thingID, tc.offset, tc.limit)
size := uint64(len(page.Certs))
assert.Equal(t, tc.size, size, fmt.Sprintf("%s: expected %d got %d\n", tc.desc, tc.size, size))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
}
}

func TestListSerials(t *testing.T) {
svc, err := newService()
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
svc, auth, trepo := newService(t)

var issuedCerts []certs.Cert
for i := 0; i < certNum; i++ {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID, Credentials: clients.Credentials{Secret: thingKey}}, nil)
cert, err := svc.IssueCert(context.Background(), token, thingID, ttl)
assert.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()

crt := certs.Cert{
OwnerID: cert.OwnerID,
Expand Down Expand Up @@ -267,7 +302,7 @@ func TestListSerials(t *testing.T) {
},
{
desc: "list all certs with invalid token",
token: wrongValue,
token: invalid,
thingID: thingID,
offset: 0,
limit: certNum,
Expand Down Expand Up @@ -295,18 +330,25 @@ func TestListSerials(t *testing.T) {
}

for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
page, err := svc.ListSerials(context.Background(), tc.token, tc.thingID, tc.offset, tc.limit)
assert.Equal(t, tc.certs, page.Certs, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.certs, page.Certs))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
}
}

func TestViewCert(t *testing.T) {
svc, err := newService()
require.Nil(t, err, fmt.Sprintf("unexpected service creation error: %s\n", err))
svc, auth, trepo := newService(t)

repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
repoCall1 := auth.On("Authorize", mock.Anything, mock.Anything).Return(&magistrala.AuthorizeRes{Authorized: true}, nil)
repoCall2 := trepo.On("RetrieveByID", mock.Anything, mock.Anything).Return(clients.Client{ID: thingID, Credentials: clients.Credentials{Secret: thingKey}}, nil)
ic, err := svc.IssueCert(context.Background(), token, thingID, ttl)
require.Nil(t, err, fmt.Sprintf("unexpected cert creation error: %s\n", err))
repoCall.Unset()
repoCall1.Unset()
repoCall2.Unset()

cert := certs.Cert{
ThingID: thingID,
Expand All @@ -331,30 +373,25 @@ func TestViewCert(t *testing.T) {
},
{
desc: "list cert with invalid token",
token: wrongValue,
token: invalid,
serialID: cert.Serial,
cert: certs.Cert{},
err: errors.ErrAuthentication,
},
{
desc: "list cert with invalid serial",
token: token,
serialID: wrongValue,
serialID: invalid,
cert: certs.Cert{},
err: errors.ErrNotFound,
},
}

for _, tc := range cases {
repoCall := auth.On("Identify", mock.Anything, &magistrala.IdentityReq{Token: tc.token}).Return(&magistrala.IdentityRes{Id: validID}, nil)
cert, err := svc.ViewCert(context.Background(), tc.token, tc.serialID)
assert.Equal(t, tc.cert, cert, fmt.Sprintf("%s: expected %v got %v\n", tc.desc, tc.cert, cert))
assert.True(t, errors.Contains(err, tc.err), fmt.Sprintf("%s: expected %s got %s\n", tc.desc, tc.err, err))
repoCall.Unset()
}
}

func newThingsServer(svc things.Service) *httptest.Server {
logger := logger.NewMock()
mux := chi.NewMux()
httpapi.MakeHandler(svc, nil, mux, logger, instanceID)
return httptest.NewServer(mux)
}