Skip to content

Commit

Permalink
Add retrieve CA option
Browse files Browse the repository at this point in the history
Signed-off-by: nyagamunene <[email protected]>
  • Loading branch information
nyagamunene committed Oct 3, 2024
1 parent 0ea242d commit 7efccc9
Show file tree
Hide file tree
Showing 16 changed files with 1,058 additions and 76 deletions.
53 changes: 52 additions & 1 deletion api/http/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func downloadCertEndpoint(svc certs.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
req := request.(downloadReq)
if err := req.validate(); err != nil {
return downloadCertRes{}, err
return fileDownloadRes{}, err
}
cert, ca, err := svc.RetrieveCert(ctx, req.token, req.id)
if err != nil {
Expand Down Expand Up @@ -243,3 +243,54 @@ func generateCRLEndpoint(svc certs.Service) endpoint.Endpoint {
}, nil
}
}

func getDownloadCATokenEndpoint(svc certs.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
token, err := svc.RetrieveCertDownloadToken(ctx)
if err != nil {
return requestCertDownloadTokenRes{}, err
}

return requestCertDownloadTokenRes{Token: token}, nil
}
}

func downloadCAEndpoint(svc certs.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
req := request.(downloadReq)
if err := req.validate(); err != nil {
return fileDownloadRes{}, err
}

cert, err := svc.GetSigningCA(ctx, req.token)
if err != nil {
return fileDownloadRes{}, err
}

return fileDownloadRes{
Certificate: cert.Certificate,
PrivateKey: cert.Key,
Filename: "ca.zip",
ContentType: "application/zip",
}, nil
}
}

func viewCAEndpoint(svc certs.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
req := request.(downloadReq)
if err := req.validate(); err != nil {
return viewCertRes{}, err
}

cert, err := svc.GetSigningCA(ctx, req.token)
if err != nil {
return viewCertRes{}, err
}

return viewCertRes{
Certificate: string(cert.Certificate),
Key: string(cert.Key),
}, nil
}
}
3 changes: 0 additions & 3 deletions api/http/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ type downloadReq struct {
}

func (req downloadReq) validate() error {
if req.id == "" {
return errors.Wrap(certs.ErrMalformedEntity, ErrEmptySerialNo)
}
if req.token == "" {
return errors.Wrap(certs.ErrMalformedEntity, ErrEmptyToken)
}
Expand Down
10 changes: 5 additions & 5 deletions api/http/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,12 @@ func (res listCertsRes) Empty() bool {
}

type viewCertRes struct {
SerialNumber string `json:"serial_number"`
SerialNumber string `json:"serial_number,omitempty"`
Certificate string `json:"certificate,omitempty"`
Key string `json:"key,omitempty"`
Revoked bool `json:"revoked"`
ExpiryTime time.Time `json:"expiry_time"`
EntityID string `json:"entity_id"`
Key string `json:"key,omitempty,omitempty"`
Revoked bool `json:"revoked,omitempty"`
ExpiryTime time.Time `json:"expiry_time,omitempty"`
EntityID string `json:"entity_id,omitempty"`
}

func (res viewCertRes) Code() int {
Expand Down
65 changes: 65 additions & 0 deletions api/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,24 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http
EncodeResponse,
opts...,
), "generate_crl").ServeHTTP)
r.Get("/get-ca/token", otelhttp.NewHandler(kithttp.NewServer(
getDownloadCATokenEndpoint(svc),
decodeView,
EncodeResponse,
opts...,
), "get_ca_token").ServeHTTP)
r.Get("/view-ca", otelhttp.NewHandler(kithttp.NewServer(
viewCAEndpoint(svc),
decodeDownloadCA,
EncodeResponse,
opts...,
), "view_ca").ServeHTTP)
r.Get("/download-ca", otelhttp.NewHandler(kithttp.NewServer(
downloadCAEndpoint(svc),
decodeDownloadCA,
encodeCADownloadResponse,
opts...,
), "download_ca").ServeHTTP)
})

r.Get("/health", certs.Health("certs", instanceID))
Expand Down Expand Up @@ -139,6 +157,18 @@ func decodeDownloadCerts(_ context.Context, r *http.Request) (interface{}, error
return req, nil
}

func decodeDownloadCA(_ context.Context, r *http.Request) (interface{}, error) {
token, err := readStringQuery(r, token, "")
if err != nil {
return nil, err
}
req := downloadReq{
token: token,
}

return req, nil
}

func decodeOCSPRequest(_ context.Context, r *http.Request) (interface{}, error) {
body, err := io.ReadAll(r.Body)
if err != nil {
Expand Down Expand Up @@ -280,6 +310,41 @@ func encodeFileDownloadResponse(_ context.Context, w http.ResponseWriter, respon
return err
}

func encodeCADownloadResponse(_ context.Context, w http.ResponseWriter, response interface{}) error {
resp := response.(fileDownloadRes)
var buffer bytes.Buffer
zw := zip.NewWriter(&buffer)

f, err := zw.Create("ca.crt")
if err != nil {
return err
}

if _, err = f.Write(resp.Certificate); err != nil {
return err
}

f, err = zw.Create("ca.key")
if err != nil {
return err
}

if _, err = f.Write(resp.PrivateKey); err != nil {
return err
}

if err := zw.Close(); err != nil {
return err
}

w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", resp.Filename))
w.Header().Set("Content-Type", resp.ContentType)

_, err = w.Write(buffer.Bytes())

return err
}

// loggingErrorEncoder is a go-kit error encoder logging decorator.
func loggingErrorEncoder(logger *slog.Logger, enc kithttp.ErrorEncoder) kithttp.ErrorEncoder {
return func(ctx context.Context, err error, w http.ResponseWriter) {
Expand Down
22 changes: 17 additions & 5 deletions api/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,16 @@ func (lm *loggingMiddleware) RevokeCert(ctx context.Context, serialNumber string
return lm.svc.RevokeCert(ctx, serialNumber)
}

func (lm *loggingMiddleware) RetrieveCertDownloadToken(ctx context.Context, serialNumber string) (tokenString string, err error) {
func (lm *loggingMiddleware) RetrieveCertDownloadToken(ctx context.Context, serialNumber ...string) (tokenString string, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method get_cert_download_token for cert %s took %s to complete", serialNumber, time.Since(begin))
message := fmt.Sprintf("Method get_cert_download_token for cert took %s to complete", time.Since(begin))
if err != nil {
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
return
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.RetrieveCertDownloadToken(ctx, serialNumber)
return lm.svc.RetrieveCertDownloadToken(ctx, serialNumber...)
}

func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (cert certs.Certificate, err error) {
Expand All @@ -97,7 +97,7 @@ func (lm *loggingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadat
return lm.svc.ListCerts(ctx, pm)
}

func (lm *loggingMiddleware) ViewCert(ctx context.Context, serialNumber string) (cert certs.Certificate, err error) {
func (lm *loggingMiddleware) ViewCert(ctx context.Context, serialNumber ...string) (cert certs.Certificate, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method view_cert for serial number %s took %s to complete", serialNumber, time.Since(begin))
if err != nil {
Expand All @@ -106,7 +106,7 @@ func (lm *loggingMiddleware) ViewCert(ctx context.Context, serialNumber string)
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.ViewCert(ctx, serialNumber)
return lm.svc.ViewCert(ctx, serialNumber...)
}

func (lm *loggingMiddleware) OCSP(ctx context.Context, serialNumber string) (cert *certs.Certificate, ocspStatus int, rootCACert *x509.Certificate, err error) {
Expand Down Expand Up @@ -144,3 +144,15 @@ func (lm *loggingMiddleware) GenerateCRL(ctx context.Context, caType certs.CertT
}(time.Now())
return lm.svc.GenerateCRL(ctx, caType)
}

func (lm *loggingMiddleware) GetSigningCA(ctx context.Context, token string) (cert certs.Certificate, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method get_signing_ca took %s to complete", time.Since(begin))
if err != nil {
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
return
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.GetSigningCA(ctx, token)
}
18 changes: 14 additions & 4 deletions api/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,13 @@ func (mm *metricsMiddleware) RevokeCert(ctx context.Context, serialNumber string
return mm.svc.RevokeCert(ctx, serialNumber)
}

func (mm *metricsMiddleware) RetrieveCertDownloadToken(ctx context.Context, serialNumber string) (string, error) {
func (mm *metricsMiddleware) RetrieveCertDownloadToken(ctx context.Context, serialNumber ...string) (string, error) {
defer func(begin time.Time) {
mm.counter.With("method", "get_certificate_download_token").Add(1)
mm.latency.With("method", "get_certificate_download_token").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.RetrieveCertDownloadToken(ctx, serialNumber)

return mm.svc.RetrieveCertDownloadToken(ctx, serialNumber...)
}

func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) {
Expand All @@ -77,12 +78,13 @@ func (mm *metricsMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadat
return mm.svc.ListCerts(ctx, pm)
}

func (mm *metricsMiddleware) ViewCert(ctx context.Context, serialNumber string) (certs.Certificate, error) {
func (mm *metricsMiddleware) ViewCert(ctx context.Context, serialNumber ...string) (certs.Certificate, error) {
defer func(begin time.Time) {
mm.counter.With("method", "view_certificate").Add(1)
mm.latency.With("method", "view_certificate").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.ViewCert(ctx, serialNumber)

return mm.svc.ViewCert(ctx, serialNumber...)
}

func (mm *metricsMiddleware) OCSP(ctx context.Context, serialNumber string) (*certs.Certificate, int, *x509.Certificate, error) {
Expand All @@ -108,3 +110,11 @@ func (mm *metricsMiddleware) GenerateCRL(ctx context.Context, caType certs.CertT
}(time.Now())
return mm.svc.GenerateCRL(ctx, caType)
}

func (mm *metricsMiddleware) GetSigningCA(ctx context.Context, token string) (certs.Certificate, error) {
defer func(begin time.Time) {
mm.counter.With("method", "get_signing_ca").Add(1)
mm.latency.With("method", "get_signing_ca").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.GetSigningCA(ctx, token)
}
11 changes: 7 additions & 4 deletions certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,16 @@ type Service interface {
RevokeCert(ctx context.Context, serialNumber string) error

// RetrieveCert retrieves a certificate record from the database.
RetrieveCert(ctx context.Context, token string, serialNumber string) (Certificate, []byte, error)
RetrieveCert(ctx context.Context, token, serialNumber string) (Certificate, []byte, error)

// ViewCert retrieves a certificate record from the database.
ViewCert(ctx context.Context, serialNumber string) (Certificate, error)
ViewCert(ctx context.Context, serialNumber ...string) (Certificate, error)

// ListCerts retrieves the certificates from the database while applying filters.
ListCerts(ctx context.Context, pm PageMetadata) (CertificatePage, error)

// RetrieveCertDownloadToken retrieves a certificate download token.
RetrieveCertDownloadToken(ctx context.Context, serialNumber string) (string, error)
RetrieveCertDownloadToken(ctx context.Context, serialNumber ...string) (string, error)

// IssueCert issues a certificate from the database.
IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions) (Certificate, error)
Expand All @@ -60,8 +60,11 @@ type Service interface {
// GetEntityID retrieves the entity ID for a certificate.
GetEntityID(ctx context.Context, serialNumber string) (string, error)

// GenerateCRL creates
// GenerateCRL creates cert revocation list.
GenerateCRL(ctx context.Context, caType CertType) ([]byte, error)

// Retrieves the signing CA.
GetSigningCA(ctx context.Context, token string) (Certificate, error)
}

type Repository interface {
Expand Down
53 changes: 52 additions & 1 deletion cli/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ var cmdCerts = []cobra.Command{
},
},
{
Use: "view <serial_number> ",
Use: "view <serial_number>",
Short: "View certificate",
Long: `Views a certificate for a given serial number.`,
Run: func(cmd *cobra.Command, args []string) {
Expand All @@ -155,6 +155,57 @@ var cmdCerts = []cobra.Command{
logJSONCmd(*cmd, cert)
},
},
{
Use: "view-ca <token>",
Short: "View-ca certificate",
Long: `Views ca certificate key with a given token.`,
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 1 {
logUsageCmd(*cmd, cmd.Use)
return
}
cert, err := sdk.ViewCA(args[0])
if err != nil {
logErrorCmd(*cmd, err)
return
}
logJSONCmd(*cmd, cert)
},
},
{
Use: "download-ca <token>",
Short: "Download signing CA",
Long: `Download intermediate cert and ca with a given token.`,
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 1 {
logUsageCmd(*cmd, cmd.Use)
return
}
bundle, err := sdk.DownloadCA(args[0])
if err != nil {
logErrorCmd(*cmd, err)
return
}
logSaveCAFiles(*cmd, bundle)
},
},
{
Use: "token-ca",
Short: "Get CA token",
Long: `Gets a download token for CA.`,
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 0 {
logUsageCmd(*cmd, cmd.Use)
return
}
token, err := sdk.GetCAToken()
if err != nil {
logErrorCmd(*cmd, err)
return
}
logJSONCmd(*cmd, token)
},
},
}

// NewCertsCmd returns certificate command.
Expand Down
Loading

0 comments on commit 7efccc9

Please sign in to comment.