Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Certs - Add intermediate CA, CA rotation and CRL #17

Merged
merged 13 commits into from
Sep 25, 2024
6 changes: 4 additions & 2 deletions api/http/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) {
errors.Contains(err, ErrEmptyToken),
errors.Contains(err, ErrInvalidQueryParams),
errors.Contains(err, ErrValidation),
errors.Contains(err, ErrInvalidRequest):
errors.Contains(err, ErrInvalidRequest),
errors.Contains(err, ErrMissingCN):
err = unwrap(err)
w.WriteHeader(http.StatusBadRequest)

Expand All @@ -63,7 +64,8 @@ func EncodeError(_ context.Context, err error, w http.ResponseWriter) {
w.WriteHeader(http.StatusUnprocessableEntity)

case errors.Contains(err, certs.ErrNotFound),
errors.Contains(err, certs.ErrRootCANotFound):
errors.Contains(err, certs.ErrRootCANotFound),
errors.Contains(err, certs.ErrIntermediateCANotFound):
err = unwrap(err)
w.WriteHeader(http.StatusNotFound)

Expand Down
19 changes: 18 additions & 1 deletion api/http/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func issueCertEndpoint(svc certs.Service) endpoint.Endpoint {
return issueCertRes{}, err
}

serialNumber, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs)
serialNumber, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs, req.Options)
if err != nil {
return issueCertRes{}, err
}
Expand Down Expand Up @@ -219,3 +219,20 @@ func ocspEndpoint(svc certs.Service) endpoint.Endpoint {
}, nil
}
}

func generateCRLEndpoint(svc certs.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
req := request.(crlReq)
if err := req.validate(); err != nil {
return crlRes{}, err
}
crlBytes, err := svc.GenerateCRL(ctx, req.certtype)
if err != nil {
return crlRes{}, err
}

return crlRes{
CrlBytes: crlBytes,
}, nil
}
}
3 changes: 3 additions & 0 deletions api/http/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,7 @@ var (

// ErrInvalidRequest indicates that the request is invalid.
ErrInvalidRequest = errors.New("invalid request")

// ErrMissingCN indicates missing common name.
ErrMissingCN = errors.New("missing common name")
)
18 changes: 15 additions & 3 deletions api/http/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,22 @@ func (req viewReq) validate() error {
return nil
}

type crlReq struct {
certtype certs.CertType
}

func (req crlReq) validate() error {
if req.certtype != certs.IntermediateCA {
return errors.Wrap(certs.ErrMalformedEntity, errors.New("invalid CA type"))
}
return nil
}

type issueCertReq struct {
entityID string `json:"-"`
TTL string `json:"ttl"`
IpAddrs []string `json:"ip_addresses"`
entityID string `json:"-"`
TTL string `json:"ttl"`
IpAddrs []string `json:"ip_addresses"`
Options certs.SubjectOptions `json:"options"`
}

func (req issueCertReq) validate() error {
Expand Down
20 changes: 18 additions & 2 deletions api/http/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ func (res listCertsRes) Empty() bool {

type viewCertRes struct {
SerialNumber string `json:"serial_number"`
Certificate string `json:"certificate"`
Key string `json:"key"`
Certificate string `json:"certificate,omitempty"`
Key string `json:"key,omitempty"`
Revoked bool `json:"revoked"`
ExpiryTime time.Time `json:"expiry_time"`
EntityID string `json:"entity_id"`
Expand All @@ -154,6 +154,22 @@ func (res viewCertRes) Empty() bool {
return false
}

type crlRes struct {
CrlBytes []byte `json:"crl"`
}

func (res crlRes) Code() int {
return http.StatusOK
}

func (res crlRes) Headers() map[string]string {
return map[string]string{}
}

func (res crlRes) Empty() bool {
return false
}

type ocspRes struct {
template ocsp.Response
signer crypto.Signer
Expand Down
33 changes: 31 additions & 2 deletions api/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@ const (
offsetKey = "offset"
limitKey = "limit"
entityKey = "entity_id"
commonName = "common_name"
token = "token"
ocspStatusParam = "force_status"
entityIDParam = "entityID"
defOffset = 0
defLimit = 10
defType = 1
)

// MakeHandler returns a HTTP handler for API endpoints.
Expand Down Expand Up @@ -91,6 +94,12 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http
encodeOSCPResponse,
opts...,
), "ocsp").ServeHTTP)
r.Get("/crl", otelhttp.NewHandler(kithttp.NewServer(
generateCRLEndpoint(svc),
decodeCRL,
EncodeResponse,
opts...,
), "generate_crl").ServeHTTP)
})

r.Get("/health", certs.Health("certs", instanceID))
Expand All @@ -106,8 +115,19 @@ func decodeView(_ context.Context, r *http.Request) (interface{}, error) {
return req, nil
}

func decodeCRL(_ context.Context, r *http.Request) (interface{}, error) {
certType, err := readNumQuery(r, "", defType)
if err != nil {
return nil, err
}
req := crlReq{
certtype: certs.CertType(certType),
}
return req, nil
}

func decodeDownloadCerts(_ context.Context, r *http.Request) (interface{}, error) {
token, err := readStringQuery(r, "token", "")
token, err := readStringQuery(r, token, "")
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -140,13 +160,22 @@ func decodeIssueCert(_ context.Context, r *http.Request) (interface{}, error) {
if err != nil {
return nil, err
}
cn, err := readStringQuery(r, commonName, "")
if err != nil {
return nil, err
}
if cn == "" {
return nil, ErrMissingCN
}
req := issueCertReq{
entityID: chi.URLParam(r, entityIDParam),
Options: certs.SubjectOptions{
CommonName: cn,
},
}
if err := json.Unmarshal(body, &req); err != nil {
return nil, errors.Wrap(ErrInvalidRequest, err)
}

return req, nil
}

Expand Down
16 changes: 14 additions & 2 deletions api/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func (lm *loggingMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri
return lm.svc.RetrieveCertDownloadToken(ctx, serialNumber)
}

func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string) (serialNumber string, err error) {
func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (serialNumber string, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin))
if err != nil {
Expand All @@ -82,7 +82,7 @@ func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs)
return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options)
}

func (lm *loggingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (cp certs.CertificatePage, err error) {
Expand Down Expand Up @@ -132,3 +132,15 @@ func (lm *loggingMiddleware) GetEntityID(ctx context.Context, serialNumber strin
}(time.Now())
return lm.svc.GetEntityID(ctx, serialNumber)
}

func (lm *loggingMiddleware) GenerateCRL(ctx context.Context, caType certs.CertType) (crl []byte, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method generate_crl took %s to complete", time.Since(begin))
if err != nil {
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
return
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.GenerateCRL(ctx, caType)
}
12 changes: 10 additions & 2 deletions api/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ func (mm *metricsMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri
return mm.svc.RetrieveCertDownloadToken(ctx, serialNumber)
}

func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string) (string, error) {
func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (string, error) {
defer func(begin time.Time) {
mm.counter.With("method", "issue_certificate").Add(1)
mm.latency.With("method", "issue_certificate").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs)
return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options)
}

func (mm *metricsMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) {
Expand Down Expand Up @@ -100,3 +100,11 @@ func (mm *metricsMiddleware) GetEntityID(ctx context.Context, serialNumber strin
}(time.Now())
return mm.svc.GetEntityID(ctx, serialNumber)
}

func (mm *metricsMiddleware) GenerateCRL(ctx context.Context, caType certs.CertType) ([]byte, error) {
defer func(begin time.Time) {
mm.counter.With("method", "generate_crl").Add(1)
mm.latency.With("method", "generate_crl").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.GenerateCRL(ctx, caType)
}
12 changes: 11 additions & 1 deletion certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Certificate struct {
Revoked bool `db:"revoked"`
ExpiryTime time.Time `db:"expiry_time"`
EntityID string `db:"entity_id"`
Type CertType `db:"type"`
DownloadUrl string `db:"-"`
}

Expand Down Expand Up @@ -51,13 +52,16 @@ type Service interface {
RetrieveCertDownloadToken(ctx context.Context, serialNumber string) (string, error)

// IssueCert issues a certificate from the database.
IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string) (string, error)
IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions) (string, error)

// OCSP retrieves the OCSP response for a certificate.
OCSP(ctx context.Context, serialNumber string) (*Certificate, int, *x509.Certificate, error)

// GetEntityID retrieves the entity ID for a certificate.
GetEntityID(ctx context.Context, serialNumber string) (string, error)

// GenerateCRL creates
GenerateCRL(ctx context.Context, caType CertType) ([]byte, error)
}

type Repository interface {
Expand All @@ -72,4 +76,10 @@ type Repository interface {

// ListCerts retrieves the certificates from the database while applying filters.
ListCerts(ctx context.Context, pm PageMetadata) (CertificatePage, error)

// GetCAs retrieves rootCA and intermediateCA from database.
GetCAs(ctx context.Context, caType ...CertType) ([]Certificate, error)

// ListRevokedCerts retrieves revoked lists from database.
ListRevokedCerts(ctx context.Context) ([]Certificate, error)
}
Loading
Loading