Skip to content

Commit

Permalink
Refactor CSR generation
Browse files Browse the repository at this point in the history
Signed-off-by: nyagamunene <[email protected]>
  • Loading branch information
nyagamunene committed Nov 28, 2024
1 parent f2b8a9a commit eba2f56
Show file tree
Hide file tree
Showing 20 changed files with 163 additions and 1,367 deletions.
41 changes: 3 additions & 38 deletions api/http/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ func createCSREndpoint(svc certs.Service) endpoint.Endpoint {
return createCSRRes{created: false}, err
}

csr, err := svc.CreateCSR(ctx, req.Metadata, req.Metadata.EntityID, req.privKey)
csr, err := svc.CreateCSR(ctx, req.Metadata, req.privKey)
if err != nil {
return createCSRRes{created: false}, err
}
Expand All @@ -336,49 +336,14 @@ func signCSREndpoint(svc certs.Service) endpoint.Endpoint {
return signCSRRes{signed: false}, err
}

err = svc.SignCSR(ctx, req.csrID, req.approve)
cert, err := svc.SignCSR(ctx, req.entityID, req.ttl, certs.CSR{CSR: req.CSR})
if err != nil {
return signCSRRes{signed: false}, err
}

return signCSRRes{
crt: cert,
signed: true,
}, nil
}
}

func retrieveCSREndpoint(svc certs.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
req := request.(retrieveCSRReq)
if err := req.validate(); err != nil {
return retrieveCSRRes{}, err
}

csr, err := svc.RetrieveCSR(ctx, req.csrID)
if err != nil {
return retrieveCSRRes{}, err
}

return retrieveCSRRes{
CSR: csr,
}, nil
}
}

func listCSRsEndpoint(svc certs.Service) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
req := request.(listCSRsReq)
if err := req.validate(); err != nil {
return listCSRsRes{}, err
}

cp, err := svc.ListCSRs(ctx, req.pm)
if err != nil {
return listCSRsRes{}, err
}

return listCSRsRes{
cp,
}, nil
}
}
4 changes: 2 additions & 2 deletions api/http/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ var (
// ErrMissingCN indicates missing common name.
ErrMissingCN = errors.New("missing common name")

// ErrMissingStatus indicates missing status.
ErrMissingStatus = errors.New("missing status")
// ErrMissingCSR indicates missing csr.
ErrMissingCSR = errors.New("missing CSR")
)
34 changes: 8 additions & 26 deletions api/http/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,43 +97,25 @@ type createCSRReq struct {
}

func (req createCSRReq) validate() error {
if req.Metadata.EntityID == "" {
return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID)
if req.Metadata.CommonName == "" {
return errors.Wrap(certs.ErrMalformedEntity, ErrMissingCN)
}
return nil
}

type SignCSRReq struct {
csrID string
approve bool
entityID string
ttl string
CSR []byte `json:"csr"`
}

func (req SignCSRReq) validate() error {
if req.csrID == "" {
if req.entityID == "" {
return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID)
}

return nil
}

type listCSRsReq struct {
pm certs.PageMetadata
}

func (req listCSRsReq) validate() error {
if req.pm.Status.String() == "" {
return errors.Wrap(certs.ErrMalformedEntity, ErrMissingStatus)
if len(req.CSR) == 0 {
return errors.Wrap(certs.ErrMalformedEntity, ErrMissingCSR)
}
return nil
}

type retrieveCSRReq struct {
csrID string
}

func (req retrieveCSRReq) validate() error {
if req.csrID == "" {
return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID)
}
return nil
}
3 changes: 2 additions & 1 deletion api/http/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ func (res createCSRRes) Empty() bool {
}

type signCSRRes struct {
crt certs.Certificate
signed bool
}

Expand All @@ -237,7 +238,7 @@ func (res signCSRRes) Headers() map[string]string {
}

func (res signCSRRes) Empty() bool {
return true
return false
}

type listCSRsRes struct {

Check failure on line 244 in api/http/responses.go

View workflow job for this annotation

GitHub Actions / Lint and Build

type `listCSRsRes` is unused (unused)
Expand Down
103 changes: 16 additions & 87 deletions api/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ const (
token = "token"
ocspStatusParam = "force_status"
entityIDParam = "entityID"
ttl = "ttl"
defOffset = 0
defLimit = 10
defType = 1
Expand Down Expand Up @@ -142,30 +143,18 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http
opts...,
), "download_ca").ServeHTTP)
r.Route("/csrs", func(r chi.Router) {
r.Post("/{entityID}", otelhttp.NewHandler(kithttp.NewServer(
r.Post("/create", otelhttp.NewHandler(kithttp.NewServer(
createCSREndpoint(svc),
decodeCreateCSR,
EncodeResponse,
opts...,
), "create_csr").ServeHTTP)
r.Patch("/{id}", otelhttp.NewHandler(kithttp.NewServer(
r.Post("/{entityID}", otelhttp.NewHandler(kithttp.NewServer(
signCSREndpoint(svc),
decodeUpdateCSR,
decodeSignCSR,
EncodeResponse,
opts...,
), "sign_csr").ServeHTTP)
r.Get("/{id}", otelhttp.NewHandler(kithttp.NewServer(
retrieveCSREndpoint(svc),
decodeRetrieveCSR,
EncodeResponse,
opts...,
), "view_csr").ServeHTTP)
r.Get("/", otelhttp.NewHandler(kithttp.NewServer(
listCSRsEndpoint(svc),
decodeListCSR,
EncodeResponse,
opts...,
), "list_csrs").ServeHTTP)
})
})

Expand Down Expand Up @@ -293,83 +282,41 @@ func decodeListCerts(_ context.Context, r *http.Request) (interface{}, error) {

func decodeCreateCSR(_ context.Context, r *http.Request) (interface{}, error) {
req := createCSRReq{}
req.Metadata.EntityID = chi.URLParam(r, "entityID")
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, err
}

if len(req.PrivateKey) > 0 {
block, _ := pem.Decode(req.PrivateKey)
if block != nil {
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, errors.Wrap(ErrInvalidRequest, err)
}
req.privKey = privateKey
block, _ := pem.Decode(req.PrivateKey)
if block != nil {
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, errors.Wrap(ErrInvalidRequest, err)
}
req.privKey = privateKey
}

return req, nil
}

func decodeUpdateCSR(_ context.Context, r *http.Request) (interface{}, error) {
app, err := readBoolQuery(r, approve, false)
func decodeSignCSR(_ context.Context, r *http.Request) (interface{}, error) {
t, err := readStringQuery(r, ttl, "")
if err != nil {
return nil, err
}

req := SignCSRReq{
csrID: chi.URLParam(r, "id"),
approve: app,
entityID: chi.URLParam(r, "entityID"),
ttl: t,
}

return req, nil
}

func decodeRetrieveCSR(_ context.Context, r *http.Request) (interface{}, error) {
req := retrieveCSRReq{
csrID: chi.URLParam(r, "id"),
}

return req, nil
}

func decodeListCSR(_ context.Context, r *http.Request) (interface{}, error) {
o, err := readNumQuery(r, offsetKey, defOffset)
if err != nil {
return nil, err
}

l, err := readNumQuery(r, limitKey, defLimit)
if err != nil {
return nil, err
}

s, err := readStringQuery(r, status, "")
if err != nil {
return nil, err
}
e, err := readStringQuery(r, entityKey, "")
if err != nil {
return nil, err
}

stat, err := certs.ParseCSRStatus(strings.ToLower(s))
if err != nil {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
return nil, err
}

req := listCSRsReq{
pm: certs.PageMetadata{
Offset: o,
Limit: l,
EntityID: e,
Status: stat,
},
}
return req, nil
}


// EncodeResponse encodes successful response.
func EncodeResponse(_ context.Context, w http.ResponseWriter, response interface{}) error {
if ar, ok := response.(Response); ok {
Expand Down Expand Up @@ -539,21 +486,3 @@ func readNumQuery(r *http.Request, key string, def uint64) (uint64, error) {
}
return v, nil
}

func readBoolQuery(r *http.Request, key string, def bool) (bool, error) {
vals := r.URL.Query()[key]
if len(vals) > 1 {
return false, ErrInvalidQueryParams
}

if len(vals) == 0 {
return def, nil
}

b, err := strconv.ParseBool(vals[0])
if err != nil {
return false, errors.Wrap(ErrInvalidQueryParams, err)
}

return b, nil
}
32 changes: 4 additions & 28 deletions api/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func (lm *loggingMiddleware) GetChainCA(ctx context.Context, token string) (cert
return lm.svc.GetChainCA(ctx, token)
}

func (lm *loggingMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetadata, entityID string, key ...*rsa.PrivateKey) (csr certs.CSR, err error) {
func (lm *loggingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey) (csr certs.CSR, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method create_csr took %s to complete", time.Since(begin))
if err != nil {
Expand All @@ -191,10 +191,10 @@ func (lm *loggingMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetada
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.CreateCSR(ctx, meta, entityID, key...)
return lm.svc.CreateCSR(ctx, metadata, privKey)
}

func (lm *loggingMiddleware) SignCSR(ctx context.Context, csrID string, approve bool) (err error) {
func (lm *loggingMiddleware) SignCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (c certs.Certificate, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method sign_csr took %s to complete", time.Since(begin))
if err != nil {
Expand All @@ -203,29 +203,5 @@ func (lm *loggingMiddleware) SignCSR(ctx context.Context, csrID string, approve
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.SignCSR(ctx, csrID, approve)
}

func (lm *loggingMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (cp certs.CSRPage, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method list_csrs took %s to complete", time.Since(begin))
if err != nil {
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
return
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.ListCSRs(ctx, pm)
}

func (lm *loggingMiddleware) RetrieveCSR(ctx context.Context, csrID string) (csr certs.CSR, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method retrieve_csr took %s to complete", time.Since(begin))
if err != nil {
lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err))
return
}
lm.logger.Info(message)
}(time.Now())
return lm.svc.RetrieveCSR(ctx, csrID)
return lm.svc.SignCSR(ctx, entityID, ttl, csr)
}
24 changes: 4 additions & 20 deletions api/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,34 +137,18 @@ func (mm *metricsMiddleware) GetChainCA(ctx context.Context, token string) (cert
return mm.svc.GetChainCA(ctx, token)
}

func (mm *metricsMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetadata, entityID string, key ...*rsa.PrivateKey) (certs.CSR, error) {
func (mm *metricsMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey) (certs.CSR, error) {
defer func(begin time.Time) {
mm.counter.With("method", "create_csr").Add(1)
mm.latency.With("method", "create_csr").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.CreateCSR(ctx, meta, entityID, key...)
return mm.svc.CreateCSR(ctx, metadata, privKey)
}

func (mm *metricsMiddleware) SignCSR(ctx context.Context, csrID string, approve bool) error {
func (mm *metricsMiddleware) SignCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (certs.Certificate, error) {
defer func(begin time.Time) {
mm.counter.With("method", "sign_csr").Add(1)
mm.latency.With("method", "sign_csr").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.SignCSR(ctx, csrID, approve)
}

func (mm *metricsMiddleware) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) {
defer func(begin time.Time) {
mm.counter.With("method", "retrieve_csr").Add(1)
mm.latency.With("method", "retrieve_csr").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.RetrieveCSR(ctx, csrID)
}

func (mm *metricsMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) {
defer func(begin time.Time) {
mm.counter.With("method", "list_csrs").Add(1)
mm.latency.With("method", "list_csrs").Observe(time.Since(begin).Seconds())
}(time.Now())
return mm.svc.ListCSRs(ctx, pm)
return mm.svc.SignCSR(ctx, entityID, ttl, csr)
}
Loading

0 comments on commit eba2f56

Please sign in to comment.