Skip to content

Commit

Permalink
Certs - Add intermediate CA, CA rotation and CRL (#17)
Browse files Browse the repository at this point in the history
* Add intermidiate certs

Signed-off-by: nyagamunene <[email protected]>

* Add options for issuing certs

Signed-off-by: nyagamunene <[email protected]>

* Implement generation CRL and fix tests

Signed-off-by: nyagamunene <[email protected]>

* Fix download certs SDK

Signed-off-by: nyagamunene <[email protected]>

* Fix list generate crl

Signed-off-by: nyagamunene <[email protected]>

* Fix tests for CRL

Signed-off-by: nyagamunene <[email protected]>

* Address comments

Signed-off-by: nyagamunene <[email protected]>

* Change root and intermediate cert validity period

Signed-off-by: nyagamunene <[email protected]>

* Update postgres methods

Signed-off-by: nyagamunene <[email protected]>

* Update json tags

Signed-off-by: nyagamunene <[email protected]>

* Address comments

Signed-off-by: nyagamunene <[email protected]>

* Rebase from main

Signed-off-by: nyagamunene <[email protected]>

* Fix failing CI

Signed-off-by: nyagamunene <[email protected]>

---------

Signed-off-by: nyagamunene <[email protected]>
  • Loading branch information
nyagamunene authored Sep 25, 2024
1 parent 7188f89 commit 0f074e3
Show file tree
Hide file tree
Showing 22 changed files with 1,232 additions and 210 deletions.
6 changes: 4 additions & 2 deletions api/http/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) {
errors.Contains(err, ErrEmptyToken),
errors.Contains(err, ErrInvalidQueryParams),
errors.Contains(err, ErrValidation),
errors.Contains(err, ErrInvalidRequest):
errors.Contains(err, ErrInvalidRequest),
errors.Contains(err, ErrMissingCN):
err = unwrap(err)
w.WriteHeader(http.StatusBadRequest)

Expand All @@ -63,7 +64,8 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) {
w.WriteHeader(http.StatusUnprocessableEntity)

case errors.Contains(err, certs.ErrNotFound),
errors.Contains(err, certs.ErrRootCANotFound):
errors.Contains(err, certs.ErrRootCANotFound),
errors.Contains(err, certs.ErrIntermediateCANotFound):
err = unwrap(err)
w.WriteHeader(http.StatusNotFound)

Expand Down
19 changes: 18 additions & 1 deletion api/http/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func issueCertEndpoint(svc certs.Service) endpoint.Endpoint {
return issueCertRes{}, err
}

serialNumber, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs)
serialNumber, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs, req.Options)
if err != nil {
return issueCertRes{}, err
}
Expand Down Expand Up @@ -219,3 +219,20 @@ func ocspEndpoint(svc certs.Service) endpoint.Endpoint {
}, nil
}
}

func generateCRLEndpoint(svc certs.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
req := request.(crlReq)
if err := req.validate(); err != nil {
return crlRes{}, err
}
crlBytes, err := svc.GenerateCRL(ctx, req.certtype)
if err != nil {
return crlRes{}, err
}

return crlRes{
CrlBytes: crlBytes,
}, nil
}
}
3 changes: 3 additions & 0 deletions api/http/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,7 @@ var (

// ErrInvalidRequest indicates that the request is invalid.
ErrInvalidRequest = errors.New("invalid request")

// ErrMissingCN indicates missing common name.
ErrMissingCN = errors.New("missing common name")
)
18 changes: 15 additions & 3 deletions api/http/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,22 @@ func (req viewReq) validate() error {
return nil
}

type crlReq struct {
certtype certs.CertType
}

func (req crlReq) validate() error {
if req.certtype != certs.IntermediateCA {
return errors.Wrap(certs.ErrMalformedEntity, errors.New("invalid CA type"))
}
return nil
}

type issueCertReq struct {
entityID string `json:"-"`
TTL string `json:"ttl"`
IpAddrs []string `json:"ip_addresses"`
entityID string `json:"-"`
TTL string `json:"ttl"`
IpAddrs []string `json:"ip_addresses"`
Options certs.SubjectOptions `json:"options"`
}

func (req issueCertReq) validate() error {
Expand Down
20 changes: 18 additions & 2 deletions api/http/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ func (res listCertsRes) Empty() bool {

type viewCertRes struct {
SerialNumber string `json:"serial_number"`
Certificate string `json:"certificate"`
Key string `json:"key"`
Certificate string `json:"certificate,omitempty"`
Key string `json:"key,omitempty"`
Revoked bool `json:"revoked"`
ExpiryTime time.Time `json:"expiry_time"`
EntityID string `json:"entity_id"`
Expand All @@ -154,6 +154,22 @@ func (res viewCertRes) Empty() bool {
return false
}

type crlRes struct {
CrlBytes []byte `json:"crl"`
}

func (res crlRes) Code() int {
return http.StatusOK
}

func (res crlRes) Headers() map[string]string {
return map[string]string{}
}

func (res crlRes) Empty() bool {
return false
}

type ocspRes struct {
template ocsp.Response
signer crypto.Signer
Expand Down
33 changes: 31 additions & 2 deletions api/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ const (
offsetKey = "offset"
limitKey = "limit"
entityKey = "entity_id"
commonName = "common_name"
token = "token"
ocspStatusParam = "force_status"
entityIDParam = "entityID"
defOffset = 0
defLimit = 10
defType = 1
)

// MakeHandler returns a HTTP handler for API endpoints.
Expand Down Expand Up @@ -91,6 +94,12 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http
encodeOSCPResponse,
opts...,
), "ocsp").ServeHTTP)
r.Get("/crl", otelhttp.NewHandler(kithttp.NewServer(
generateCRLEndpoint(svc),
decodeCRL,
EncodeResponse,
opts...,
), "generate_crl").ServeHTTP)
})

r.Get("/health", certs.Health("certs", instanceID))
Expand All @@ -106,8 +115,19 @@ func decodeView(_ context.Context, r *http.Request) (interface{}, error) {
return req, nil
}

func decodeCRL(_ context.Context, r *http.Request) (interface{}, error) {
certType, err := readNumQuery(r, "", defType)
if err != nil {
return nil, err
}
req := crlReq{
certtype: certs.CertType(certType),
}
return req, nil
}

func decodeDownloadCerts(_ context.Context, r *http.Request) (interface{}, error) {
token, err := readStringQuery(r, "token", "")
token, err := readStringQuery(r, token, "")
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -140,13 +160,22 @@ func decodeIssueCert(_ context.Context, r *http.Request) (interface{}, error) {
if err != nil {
return nil, err
}
cn, err := readStringQuery(r, commonName, "")
if err != nil {
return nil, err
}
if cn == "" {
return nil, ErrMissingCN
}
req := issueCertReq{
entityID: chi.URLParam(r, entityIDParam),
Options: certs.SubjectOptions{
CommonName: cn,
},
}
if err := json.Unmarshal(body, &req); err != nil {
return nil, errors.Wrap(ErrInvalidRequest, err)
}

return req, nil
}

Expand Down
16 changes: 14 additions & 2 deletions api/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (lm *loggingMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri
return lm.svc.RetrieveCertDownloadToken(ctx, serialNumber)
}

func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string) (serialNumber string, err error) {
func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (serialNumber string, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin))
if err != nil {
Expand All @@ -82,7 +82,7 @@ func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs)
return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options)
}

func (lm *loggingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (cp certs.CertificatePage, err error) {
Expand Down Expand Up @@ -132,3 +132,15 @@ func (lm *loggingMiddleware) GetEntityID(ctx context.Context, serialNumber strin
}(time.Now())
return lm.svc.GetEntityID(ctx, serialNumber)
}

func (lm *loggingMiddleware) GenerateCRL(ctx context.Context, caType certs.CertType) (crl []byte, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method generate_crl took %s to complete", time.Since(begin))
if err != nil {
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
return
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.GenerateCRL(ctx, caType)
}
12 changes: 10 additions & 2 deletions api/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ func (mm *metricsMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri
return mm.svc.RetrieveCertDownloadToken(ctx, serialNumber)
}

func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string) (string, error) {
func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (string, error) {
defer func(begin time.Time) {
mm.counter.With("method", "issue_certificate").Add(1)
mm.latency.With("method", "issue_certificate").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs)
return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options)
}

func (mm *metricsMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) {
Expand Down Expand Up @@ -100,3 +100,11 @@ func (mm *metricsMiddleware) GetEntityID(ctx context.Context, serialNumber strin
}(time.Now())
return mm.svc.GetEntityID(ctx, serialNumber)
}

func (mm *metricsMiddleware) GenerateCRL(ctx context.Context, caType certs.CertType) ([]byte, error) {
defer func(begin time.Time) {
mm.counter.With("method", "generate_crl").Add(1)
mm.latency.With("method", "generate_crl").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.GenerateCRL(ctx, caType)
}
12 changes: 11 additions & 1 deletion certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Certificate struct {
Revoked bool `db:"revoked"`
ExpiryTime time.Time `db:"expiry_time"`
EntityID string `db:"entity_id"`
Type CertType `db:"type"`
DownloadUrl string `db:"-"`
}

Expand Down Expand Up @@ -51,13 +52,16 @@ type Service interface {
RetrieveCertDownloadToken(ctx context.Context, serialNumber string) (string, error)

// IssueCert issues a certificate from the database.
IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string) (string, error)
IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions) (string, error)

// OCSP retrieves the OCSP response for a certificate.
OCSP(ctx context.Context, serialNumber string) (*Certificate, int, *x509.Certificate, error)

// GetEntityID retrieves the entity ID for a certificate.
GetEntityID(ctx context.Context, serialNumber string) (string, error)

// GenerateCRL creates
GenerateCRL(ctx context.Context, caType CertType) ([]byte, error)
}

type Repository interface {
Expand All @@ -72,4 +76,10 @@ type Repository interface {

// ListCerts retrieves the certificates from the database while applying filters.
ListCerts(ctx context.Context, pm PageMetadata) (CertificatePage, error)

// GetCAs retrieves rootCA and intermediateCA from database.
GetCAs(ctx context.Context, caType ...CertType) ([]Certificate, error)

// ListRevokedCerts retrieves revoked lists from database.
ListRevokedCerts(ctx context.Context) ([]Certificate, error)
}
Loading

0 comments on commit 0f074e3

Please sign in to comment.