Skip to content

Commit

Permalink
refactor issueCerts and issueforcsr
Browse files Browse the repository at this point in the history
Signed-off-by: nyagamunene <[email protected]>
  • Loading branch information
nyagamunene committed Dec 9, 2024
1 parent 85dedd2 commit 11972ae
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 69 deletions.
2 changes: 1 addition & 1 deletion api/http/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func issueCertEndpoint(svc certs.Service) endpoint.Endpoint {
return issueCertRes{}, err
}

cert, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs, req.Options, nil)
cert, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs, req.Options)
if err != nil {
return issueCertRes{}, err
}
Expand Down
5 changes: 2 additions & 3 deletions api/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package api

import (
"context"
"crypto"
"crypto/x509"
"fmt"
"log/slog"
Expand Down Expand Up @@ -86,7 +85,7 @@ func (lm *loggingMiddleware) RetrieveCAToken(ctx context.Context) (tokenString s
return lm.svc.RetrieveCAToken(ctx)
}

func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey crypto.PrivateKey) (cert certs.Certificate, err error) {
func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (cert certs.Certificate, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin))
if err != nil {
Expand All @@ -95,7 +94,7 @@ func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey)
return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options)
}

func (lm *loggingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (cp certs.CertificatePage, err error) {
Expand Down
5 changes: 2 additions & 3 deletions api/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package api

import (
"context"
"crypto"
"crypto/x509"
"time"

Expand Down Expand Up @@ -72,12 +71,12 @@ func (mm *metricsMiddleware) RetrieveCAToken(ctx context.Context) (string, error
return mm.svc.RetrieveCAToken(ctx)
}

func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey crypto.PrivateKey) (certs.Certificate, error) {
func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) {
defer func(begin time.Time) {
mm.counter.With("method", "issue_certificate").Add(1)
mm.latency.With("method", "issue_certificate").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey)
return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options)
}

func (mm *metricsMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) {
Expand Down
3 changes: 1 addition & 2 deletions certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package certs

import (
"context"
"crypto"
"crypto/rsa"
"crypto/x509"
"net"
Expand Down Expand Up @@ -159,7 +158,7 @@ type Service interface {
RetrieveCAToken(ctx context.Context) (string, error)

// IssueCert issues a certificate from the database.
IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions, privKey crypto.PrivateKey) (Certificate, error)
IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions) (Certificate, error)

// OCSP retrieves the OCSP response for a certificate.
OCSP(ctx context.Context, serialNumber string) (*Certificate, int, *x509.Certificate, error)
Expand Down
2 changes: 1 addition & 1 deletion certs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestIssueCert(t *testing.T) {
t.Run(tc.desc, func(t *testing.T) {
repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(tc.err)

_, err = svc.IssueCert(context.Background(), tc.backendId, tc.ttl, []string{}, certs.SubjectOptions{}, nil)
_, err = svc.IssueCert(context.Background(), tc.backendId, tc.ttl, []string{}, certs.SubjectOptions{})
require.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err)
repoCall1.Unset()
})
Expand Down
31 changes: 14 additions & 17 deletions mocks/service.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

73 changes: 34 additions & 39 deletions service.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ var (
ErrCertInvalidType = errors.New("invalid cert type")
ErrInvalidLength = errors.New("invalid length of serial numbers")
ErrPrivKeyType = errors.New("unsupported private key type")
ErrPubKeyType = errors.New("unsupported public key type")
ErrFailedParse = errors.New("failed to parse key PEM")
)

Expand Down Expand Up @@ -95,49 +96,45 @@ func NewService(ctx context.Context, repo Repository, config *Config) (Service,
// using the provided template and the generated private key.
// The certificate is then stored in the repository using the CreateCert method.
// If the root CA is not found, it returns an error.
// issueCert generates and issues a certificate for a given backendID.
// It uses the RSA algorithm to generate a private key, and then creates a certificate
// using the provided template and the generated private key.
// The certificate is then stored in the repository using the CreateCert method.
// If the root CA is not found, it returns an error.
func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key crypto.PrivateKey) (Certificate, error) {
var privKey crypto.PrivateKey
var pubKey crypto.PublicKey
var err error
func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions) (Certificate, error) {
pKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes)
if err != nil {
return Certificate{}, err
}

if key == nil {
pKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes)
if err != nil {
return Certificate{}, err
}
privKey = pKey
pubKey = pKey.Public()
} else {
switch k := key.(type) {
case *rsa.PrivateKey:
privKey = k
pubKey = k.Public()
case *ecdsa.PrivateKey:
privKey = k
pubKey = k.Public()
case ed25519.PrivateKey:
privKey = k
pubKey = k.Public()
case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey:
pubKey = k
privKey = nil
default:
return Certificate{}, errors.Wrap(ErrCreateEntity, errors.New("unsupported key type"))
}
if s.intermediateCA.Certificate == nil || s.intermediateCA.PrivateKey == nil {
return Certificate{}, ErrIntermediateCANotFound
}

cert, err := s.issue(ctx, entityID, ttl, ipAddrs, options, pKey.Public(), pKey)
if err != nil {
return Certificate{}, err
}

return cert, nil
}

func (s *service) issue(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, pubKey crypto.PublicKey, privKey crypto.PrivateKey) (Certificate, error) {
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
if err != nil {
return Certificate{}, err
}

if s.intermediateCA.Certificate == nil || s.intermediateCA.PrivateKey == nil {
return Certificate{}, ErrIntermediateCANotFound
subject := s.getSubject(options)
if privKey != nil {
switch privKey.(type) {
case *rsa.PrivateKey, *ecdsa.PrivateKey, *ed25519.PrivateKey:
break
default:
return Certificate{}, errors.Wrap(ErrCreateEntity, ErrPrivKeyType)
}
}

switch pubKey.(type) {
case *rsa.PublicKey, *ecdsa.PublicKey, *ed25519.PublicKey:
break
default:
return Certificate{}, errors.Wrap(ErrCreateEntity, ErrPubKeyType)
}

// Parse the TTL if provided, otherwise use the default certValidityPeriod.
Expand All @@ -149,8 +146,6 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [
}
}

subject := s.getSubject(options)

template := x509.Certificate{
SerialNumber: serialNumber,
Subject: subject,
Expand Down Expand Up @@ -476,7 +471,7 @@ func (s *service) IssueFromCSR(ctx context.Context, entityID, ttl string, csr CS
return Certificate{}, errors.Wrap(ErrMalformedEntity, err)
}

cert, err := s.IssueCert(ctx, entityID, ttl, nil, SubjectOptions{
cert, err := s.issue(ctx, entityID, ttl, nil, SubjectOptions{
CommonName: parsedCSR.Subject.CommonName,
Organization: parsedCSR.Subject.Organization,
OrganizationalUnit: parsedCSR.Subject.OrganizationalUnit,
Expand All @@ -485,7 +480,7 @@ func (s *service) IssueFromCSR(ctx context.Context, entityID, ttl string, csr CS
Locality: parsedCSR.Subject.Locality,
StreetAddress: parsedCSR.Subject.StreetAddress,
PostalCode: parsedCSR.Subject.PostalCode,
}, parsedCSR.PublicKey)
}, parsedCSR.PublicKey, nil)
if err != nil {
return Certificate{}, errors.Wrap(ErrCreateEntity, err)
}
Expand Down
5 changes: 2 additions & 3 deletions tracing/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package tracing

import (
"context"
"crypto"
"crypto/x509"

"github.com/absmach/certs"
Expand Down Expand Up @@ -54,10 +53,10 @@ func (tm *tracingMiddleware) RetrieveCAToken(ctx context.Context) (string, error
return tm.svc.RetrieveCAToken(ctx)
}

func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey crypto.PrivateKey) (certs.Certificate, error) {
func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) {
ctx, span := tm.tracer.Start(ctx, "issue_cert")
defer span.End()
return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey)
return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options)
}

func (tm *tracingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) {
Expand Down

0 comments on commit 11972ae

Please sign in to comment.