From 11972ae16c28204dccf49360852e0d90e7b43202 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 9 Dec 2024 16:29:45 +0300 Subject: [PATCH] refactor issueCerts and issueforcsr Signed-off-by: nyagamunene --- api/http/endpoint.go | 2 +- api/logging.go | 5 ++- api/metrics.go | 5 ++- certs.go | 3 +- certs_test.go | 2 +- mocks/service.go | 31 +++++++++---------- service.go | 73 +++++++++++++++++++++----------------------- tracing/certs.go | 5 ++- 8 files changed, 57 insertions(+), 69 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index c1fb611..72620ad 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -106,7 +106,7 @@ func issueCertEndpoint(svc certs.Service) endpoint.Endpoint { return issueCertRes{}, err } - cert, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs, req.Options, nil) + cert, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs, req.Options) if err != nil { return issueCertRes{}, err } diff --git a/api/logging.go b/api/logging.go index 83eaf47..487a903 100644 --- a/api/logging.go +++ b/api/logging.go @@ -5,7 +5,6 @@ package api import ( "context" - "crypto" "crypto/x509" "fmt" "log/slog" @@ -86,7 +85,7 @@ func (lm *loggingMiddleware) RetrieveCAToken(ctx context.Context) (tokenString s return lm.svc.RetrieveCAToken(ctx) } -func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey crypto.PrivateKey) (cert certs.Certificate, err error) { +func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (cert certs.Certificate, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin)) if err != nil { @@ -95,7 +94,7 @@ func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string } lm.logger.Info(message) }(time.Now()) - return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey) + return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) } func (lm *loggingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (cp certs.CertificatePage, err error) { diff --git a/api/metrics.go b/api/metrics.go index b50b16d..fe2f4cd 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -5,7 +5,6 @@ package api import ( "context" - "crypto" "crypto/x509" "time" @@ -72,12 +71,12 @@ func (mm *metricsMiddleware) RetrieveCAToken(ctx context.Context) (string, error return mm.svc.RetrieveCAToken(ctx) } -func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey crypto.PrivateKey) (certs.Certificate, error) { +func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { defer func(begin time.Time) { mm.counter.With("method", "issue_certificate").Add(1) mm.latency.With("method", "issue_certificate").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey) + return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) } func (mm *metricsMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { diff --git a/certs.go b/certs.go index 327bd6b..3d2ea0e 100644 --- a/certs.go +++ b/certs.go @@ -5,7 +5,6 @@ package certs import ( "context" - "crypto" "crypto/rsa" "crypto/x509" "net" @@ -159,7 +158,7 @@ type Service interface { RetrieveCAToken(ctx context.Context) (string, error) // IssueCert issues a certificate from the database. - IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions, privKey crypto.PrivateKey) (Certificate, error) + IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions) (Certificate, error) // OCSP retrieves the OCSP response for a certificate. OCSP(ctx context.Context, serialNumber string) (*Certificate, int, *x509.Certificate, error) diff --git a/certs_test.go b/certs_test.go index f9e9200..28616e2 100644 --- a/certs_test.go +++ b/certs_test.go @@ -67,7 +67,7 @@ func TestIssueCert(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(tc.err) - _, err = svc.IssueCert(context.Background(), tc.backendId, tc.ttl, []string{}, certs.SubjectOptions{}, nil) + _, err = svc.IssueCert(context.Background(), tc.backendId, tc.ttl, []string{}, certs.SubjectOptions{}) require.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) repoCall1.Unset() }) diff --git a/mocks/service.go b/mocks/service.go index 7460f4e..3351649 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -10,8 +10,6 @@ import ( certs "github.com/absmach/certs" - crypto "crypto" - mock "github.com/stretchr/testify/mock" x509 "crypto/x509" @@ -203,9 +201,9 @@ func (_c *MockService_GetEntityID_Call) RunAndReturn(run func(context.Context, s return _c } -// IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option, privKey -func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey crypto.PrivateKey) (certs.Certificate, error) { - ret := _m.Called(ctx, entityID, ttl, ipAddrs, option, privKey) +// IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option +func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions) (certs.Certificate, error) { + ret := _m.Called(ctx, entityID, ttl, ipAddrs, option) if len(ret) == 0 { panic("no return value specified for IssueCert") @@ -213,17 +211,17 @@ func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl strin var r0 certs.Certificate var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, crypto.PrivateKey) (certs.Certificate, error)); ok { - return rf(ctx, entityID, ttl, ipAddrs, option, privKey) + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions) (certs.Certificate, error)); ok { + return rf(ctx, entityID, ttl, ipAddrs, option) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, crypto.PrivateKey) certs.Certificate); ok { - r0 = rf(ctx, entityID, ttl, ipAddrs, option, privKey) + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions) certs.Certificate); ok { + r0 = rf(ctx, entityID, ttl, ipAddrs, option) } else { r0 = ret.Get(0).(certs.Certificate) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions, crypto.PrivateKey) error); ok { - r1 = rf(ctx, entityID, ttl, ipAddrs, option, privKey) + if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions) error); ok { + r1 = rf(ctx, entityID, ttl, ipAddrs, option) } else { r1 = ret.Error(1) } @@ -242,14 +240,13 @@ type MockService_IssueCert_Call struct { // - ttl string // - ipAddrs []string // - option certs.SubjectOptions -// - privKey crypto.PrivateKey -func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}, privKey interface{}) *MockService_IssueCert_Call { - return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", ctx, entityID, ttl, ipAddrs, option, privKey)} +func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}) *MockService_IssueCert_Call { + return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", ctx, entityID, ttl, ipAddrs, option)} } -func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey crypto.PrivateKey)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions)) *MockService_IssueCert_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string), args[4].(certs.SubjectOptions), args[5].(crypto.PrivateKey)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string), args[4].(certs.SubjectOptions)) }) return _c } @@ -259,7 +256,7 @@ func (_c *MockService_IssueCert_Call) Return(_a0 certs.Certificate, _a1 error) * return _c } -func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions, crypto.PrivateKey) (certs.Certificate, error)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions) (certs.Certificate, error)) *MockService_IssueCert_Call { _c.Call.Return(run) return _c } diff --git a/service.go b/service.go index 4950a14..1070d6c 100644 --- a/service.go +++ b/service.go @@ -55,6 +55,7 @@ var ( ErrCertInvalidType = errors.New("invalid cert type") ErrInvalidLength = errors.New("invalid length of serial numbers") ErrPrivKeyType = errors.New("unsupported private key type") + ErrPubKeyType = errors.New("unsupported public key type") ErrFailedParse = errors.New("failed to parse key PEM") ) @@ -95,49 +96,45 @@ func NewService(ctx context.Context, repo Repository, config *Config) (Service, // using the provided template and the generated private key. // The certificate is then stored in the repository using the CreateCert method. // If the root CA is not found, it returns an error. -// issueCert generates and issues a certificate for a given backendID. -// It uses the RSA algorithm to generate a private key, and then creates a certificate -// using the provided template and the generated private key. -// The certificate is then stored in the repository using the CreateCert method. -// If the root CA is not found, it returns an error. -func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key crypto.PrivateKey) (Certificate, error) { - var privKey crypto.PrivateKey - var pubKey crypto.PublicKey - var err error +func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions) (Certificate, error) { + pKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) + if err != nil { + return Certificate{}, err + } - if key == nil { - pKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) - if err != nil { - return Certificate{}, err - } - privKey = pKey - pubKey = pKey.Public() - } else { - switch k := key.(type) { - case *rsa.PrivateKey: - privKey = k - pubKey = k.Public() - case *ecdsa.PrivateKey: - privKey = k - pubKey = k.Public() - case ed25519.PrivateKey: - privKey = k - pubKey = k.Public() - case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: - pubKey = k - privKey = nil - default: - return Certificate{}, errors.Wrap(ErrCreateEntity, errors.New("unsupported key type")) - } + if s.intermediateCA.Certificate == nil || s.intermediateCA.PrivateKey == nil { + return Certificate{}, ErrIntermediateCANotFound } + cert, err := s.issue(ctx, entityID, ttl, ipAddrs, options, pKey.Public(), pKey) + if err != nil { + return Certificate{}, err + } + + return cert, nil +} + +func (s *service) issue(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, pubKey crypto.PublicKey, privKey crypto.PrivateKey) (Certificate, error) { serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { return Certificate{}, err } - if s.intermediateCA.Certificate == nil || s.intermediateCA.PrivateKey == nil { - return Certificate{}, ErrIntermediateCANotFound + subject := s.getSubject(options) + if privKey != nil { + switch privKey.(type) { + case *rsa.PrivateKey, *ecdsa.PrivateKey, *ed25519.PrivateKey: + break + default: + return Certificate{}, errors.Wrap(ErrCreateEntity, ErrPrivKeyType) + } + } + + switch pubKey.(type) { + case *rsa.PublicKey, *ecdsa.PublicKey, *ed25519.PublicKey: + break + default: + return Certificate{}, errors.Wrap(ErrCreateEntity, ErrPubKeyType) } // Parse the TTL if provided, otherwise use the default certValidityPeriod. @@ -149,8 +146,6 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ } } - subject := s.getSubject(options) - template := x509.Certificate{ SerialNumber: serialNumber, Subject: subject, @@ -476,7 +471,7 @@ func (s *service) IssueFromCSR(ctx context.Context, entityID, ttl string, csr CS return Certificate{}, errors.Wrap(ErrMalformedEntity, err) } - cert, err := s.IssueCert(ctx, entityID, ttl, nil, SubjectOptions{ + cert, err := s.issue(ctx, entityID, ttl, nil, SubjectOptions{ CommonName: parsedCSR.Subject.CommonName, Organization: parsedCSR.Subject.Organization, OrganizationalUnit: parsedCSR.Subject.OrganizationalUnit, @@ -485,7 +480,7 @@ func (s *service) IssueFromCSR(ctx context.Context, entityID, ttl string, csr CS Locality: parsedCSR.Subject.Locality, StreetAddress: parsedCSR.Subject.StreetAddress, PostalCode: parsedCSR.Subject.PostalCode, - }, parsedCSR.PublicKey) + }, parsedCSR.PublicKey, nil) if err != nil { return Certificate{}, errors.Wrap(ErrCreateEntity, err) } diff --git a/tracing/certs.go b/tracing/certs.go index 5ba671c..d766c67 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -5,7 +5,6 @@ package tracing import ( "context" - "crypto" "crypto/x509" "github.com/absmach/certs" @@ -54,10 +53,10 @@ func (tm *tracingMiddleware) RetrieveCAToken(ctx context.Context) (string, error return tm.svc.RetrieveCAToken(ctx) } -func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey crypto.PrivateKey) (certs.Certificate, error) { +func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { ctx, span := tm.tracer.Start(ctx, "issue_cert") defer span.End() - return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey) + return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) } func (tm *tracingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) {