From eba2f563ef88df3612b77de9a4d0f8413c06e6db Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Thu, 28 Nov 2024 19:29:19 +0300 Subject: [PATCH 01/22] Refactor CSR generation Signed-off-by: nyagamunene --- api/http/endpoint.go | 41 +------ api/http/errors.go | 4 +- api/http/requests.go | 34 ++---- api/http/responses.go | 3 +- api/http/transport.go | 103 +++-------------- api/logging.go | 32 +----- api/metrics.go | 24 +--- certs.go | 90 ++------------- certs_test.go | 32 ++---- cli/certs.go | 88 ++------------- cmd/certs/main.go | 7 +- mocks/csr.go | 249 ------------------------------------------ mocks/service.go | 204 ++++++++-------------------------- mocks/uuid.go | 35 ------ postgres/csr/csr.go | 202 ---------------------------------- postgres/csr/init.go | 34 ------ sdk/mocks/sdk.go | 163 +++++---------------------- sdk/sdk.go | 82 ++------------ service.go | 83 ++------------ tracing/certs.go | 20 +--- 20 files changed, 163 insertions(+), 1367 deletions(-) delete mode 100644 mocks/csr.go delete mode 100644 mocks/uuid.go delete mode 100644 postgres/csr/csr.go delete mode 100644 postgres/csr/init.go diff --git a/api/http/endpoint.go b/api/http/endpoint.go index 16010c7..43fa6ac 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -317,7 +317,7 @@ func createCSREndpoint(svc certs.Service) endpoint.Endpoint { return createCSRRes{created: false}, err } - csr, err := svc.CreateCSR(ctx, req.Metadata, req.Metadata.EntityID, req.privKey) + csr, err := svc.CreateCSR(ctx, req.Metadata, req.privKey) if err != nil { return createCSRRes{created: false}, err } @@ -336,49 +336,14 @@ func signCSREndpoint(svc certs.Service) endpoint.Endpoint { return signCSRRes{signed: false}, err } - err = svc.SignCSR(ctx, req.csrID, req.approve) + cert, err := svc.SignCSR(ctx, req.entityID, req.ttl, certs.CSR{CSR: req.CSR}) if err != nil { return signCSRRes{signed: false}, err } return signCSRRes{ + crt: cert, signed: true, }, nil } } - -func retrieveCSREndpoint(svc certs.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(retrieveCSRReq) - if err := req.validate(); err != nil { - return retrieveCSRRes{}, err - } - - csr, err := svc.RetrieveCSR(ctx, req.csrID) - if err != nil { - return retrieveCSRRes{}, err - } - - return retrieveCSRRes{ - CSR: csr, - }, nil - } -} - -func listCSRsEndpoint(svc certs.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(listCSRsReq) - if err := req.validate(); err != nil { - return listCSRsRes{}, err - } - - cp, err := svc.ListCSRs(ctx, req.pm) - if err != nil { - return listCSRsRes{}, err - } - - return listCSRsRes{ - cp, - }, nil - } -} diff --git a/api/http/errors.go b/api/http/errors.go index dcbfb14..8e6a5ae 100644 --- a/api/http/errors.go +++ b/api/http/errors.go @@ -33,6 +33,6 @@ var ( // ErrMissingCN indicates missing common name. ErrMissingCN = errors.New("missing common name") - // ErrMissingStatus indicates missing status. - ErrMissingStatus = errors.New("missing status") + // ErrMissingCSR indicates missing csr. + ErrMissingCSR = errors.New("missing CSR") ) diff --git a/api/http/requests.go b/api/http/requests.go index 2fbb7de..ed0f4cb 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -97,43 +97,25 @@ type createCSRReq struct { } func (req createCSRReq) validate() error { - if req.Metadata.EntityID == "" { - return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) + if req.Metadata.CommonName == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingCN) } return nil } type SignCSRReq struct { - csrID string - approve bool + entityID string + ttl string + CSR []byte `json:"csr"` } func (req SignCSRReq) validate() error { - if req.csrID == "" { + if req.entityID == "" { return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) } - - return nil -} - -type listCSRsReq struct { - pm certs.PageMetadata -} - -func (req listCSRsReq) validate() error { - if req.pm.Status.String() == "" { - return errors.Wrap(certs.ErrMalformedEntity, ErrMissingStatus) + if len(req.CSR) == 0 { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingCSR) } - return nil -} -type retrieveCSRReq struct { - csrID string -} - -func (req retrieveCSRReq) validate() error { - if req.csrID == "" { - return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) - } return nil } diff --git a/api/http/responses.go b/api/http/responses.go index d24709f..f6b394b 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -225,6 +225,7 @@ func (res createCSRRes) Empty() bool { } type signCSRRes struct { + crt certs.Certificate signed bool } @@ -237,7 +238,7 @@ func (res signCSRRes) Headers() map[string]string { } func (res signCSRRes) Empty() bool { - return true + return false } type listCSRsRes struct { diff --git a/api/http/transport.go b/api/http/transport.go index 87e22be..6b5e391 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -39,6 +39,7 @@ const ( token = "token" ocspStatusParam = "force_status" entityIDParam = "entityID" + ttl = "ttl" defOffset = 0 defLimit = 10 defType = 1 @@ -142,30 +143,18 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http opts..., ), "download_ca").ServeHTTP) r.Route("/csrs", func(r chi.Router) { - r.Post("/{entityID}", otelhttp.NewHandler(kithttp.NewServer( + r.Post("/create", otelhttp.NewHandler(kithttp.NewServer( createCSREndpoint(svc), decodeCreateCSR, EncodeResponse, opts..., ), "create_csr").ServeHTTP) - r.Patch("/{id}", otelhttp.NewHandler(kithttp.NewServer( + r.Post("/{entityID}", otelhttp.NewHandler(kithttp.NewServer( signCSREndpoint(svc), - decodeUpdateCSR, + decodeSignCSR, EncodeResponse, opts..., ), "sign_csr").ServeHTTP) - r.Get("/{id}", otelhttp.NewHandler(kithttp.NewServer( - retrieveCSREndpoint(svc), - decodeRetrieveCSR, - EncodeResponse, - opts..., - ), "view_csr").ServeHTTP) - r.Get("/", otelhttp.NewHandler(kithttp.NewServer( - listCSRsEndpoint(svc), - decodeListCSR, - EncodeResponse, - opts..., - ), "list_csrs").ServeHTTP) }) }) @@ -293,83 +282,41 @@ func decodeListCerts(_ context.Context, r *http.Request) (interface{}, error) { func decodeCreateCSR(_ context.Context, r *http.Request) (interface{}, error) { req := createCSRReq{} - req.Metadata.EntityID = chi.URLParam(r, "entityID") if err := json.NewDecoder(r.Body).Decode(&req); err != nil { return nil, err } - if len(req.PrivateKey) > 0 { - block, _ := pem.Decode(req.PrivateKey) - if block != nil { - privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) - if err != nil { - return nil, errors.Wrap(ErrInvalidRequest, err) - } - req.privKey = privateKey + block, _ := pem.Decode(req.PrivateKey) + if block != nil { + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, errors.Wrap(ErrInvalidRequest, err) } + req.privKey = privateKey } return req, nil } -func decodeUpdateCSR(_ context.Context, r *http.Request) (interface{}, error) { - app, err := readBoolQuery(r, approve, false) +func decodeSignCSR(_ context.Context, r *http.Request) (interface{}, error) { + t, err := readStringQuery(r, ttl, "") if err != nil { return nil, err } req := SignCSRReq{ - csrID: chi.URLParam(r, "id"), - approve: app, + entityID: chi.URLParam(r, "entityID"), + ttl: t, } - return req, nil -} - -func decodeRetrieveCSR(_ context.Context, r *http.Request) (interface{}, error) { - req := retrieveCSRReq{ - csrID: chi.URLParam(r, "id"), - } - - return req, nil -} - -func decodeListCSR(_ context.Context, r *http.Request) (interface{}, error) { - o, err := readNumQuery(r, offsetKey, defOffset) - if err != nil { - return nil, err - } - - l, err := readNumQuery(r, limitKey, defLimit) - if err != nil { - return nil, err - } - - s, err := readStringQuery(r, status, "") - if err != nil { - return nil, err - } - e, err := readStringQuery(r, entityKey, "") - if err != nil { - return nil, err - } - - stat, err := certs.ParseCSRStatus(strings.ToLower(s)) - if err != nil { + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { return nil, err } - req := listCSRsReq{ - pm: certs.PageMetadata{ - Offset: o, - Limit: l, - EntityID: e, - Status: stat, - }, - } return req, nil } + // EncodeResponse encodes successful response. func EncodeResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { if ar, ok := response.(Response); ok { @@ -539,21 +486,3 @@ func readNumQuery(r *http.Request, key string, def uint64) (uint64, error) { } return v, nil } - -func readBoolQuery(r *http.Request, key string, def bool) (bool, error) { - vals := r.URL.Query()[key] - if len(vals) > 1 { - return false, ErrInvalidQueryParams - } - - if len(vals) == 0 { - return def, nil - } - - b, err := strconv.ParseBool(vals[0]) - if err != nil { - return false, errors.Wrap(ErrInvalidQueryParams, err) - } - - return b, nil -} diff --git a/api/logging.go b/api/logging.go index 80144f6..dcc6f9d 100644 --- a/api/logging.go +++ b/api/logging.go @@ -182,7 +182,7 @@ func (lm *loggingMiddleware) GetChainCA(ctx context.Context, token string) (cert return lm.svc.GetChainCA(ctx, token) } -func (lm *loggingMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetadata, entityID string, key ...*rsa.PrivateKey) (csr certs.CSR, err error) { +func (lm *loggingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey) (csr certs.CSR, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method create_csr took %s to complete", time.Since(begin)) if err != nil { @@ -191,10 +191,10 @@ func (lm *loggingMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetada } lm.logger.Info(message) }(time.Now()) - return lm.svc.CreateCSR(ctx, meta, entityID, key...) + return lm.svc.CreateCSR(ctx, metadata, privKey) } -func (lm *loggingMiddleware) SignCSR(ctx context.Context, csrID string, approve bool) (err error) { +func (lm *loggingMiddleware) SignCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (c certs.Certificate, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method sign_csr took %s to complete", time.Since(begin)) if err != nil { @@ -203,29 +203,5 @@ func (lm *loggingMiddleware) SignCSR(ctx context.Context, csrID string, approve } lm.logger.Info(message) }(time.Now()) - return lm.svc.SignCSR(ctx, csrID, approve) -} - -func (lm *loggingMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (cp certs.CSRPage, err error) { - defer func(begin time.Time) { - message := fmt.Sprintf("Method list_csrs took %s to complete", time.Since(begin)) - if err != nil { - lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) - return - } - lm.logger.Info(message) - }(time.Now()) - return lm.svc.ListCSRs(ctx, pm) -} - -func (lm *loggingMiddleware) RetrieveCSR(ctx context.Context, csrID string) (csr certs.CSR, err error) { - defer func(begin time.Time) { - message := fmt.Sprintf("Method retrieve_csr took %s to complete", time.Since(begin)) - if err != nil { - lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) - return - } - lm.logger.Info(message) - }(time.Now()) - return lm.svc.RetrieveCSR(ctx, csrID) + return lm.svc.SignCSR(ctx, entityID, ttl, csr) } diff --git a/api/metrics.go b/api/metrics.go index b0c6f4e..f627b51 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -137,34 +137,18 @@ func (mm *metricsMiddleware) GetChainCA(ctx context.Context, token string) (cert return mm.svc.GetChainCA(ctx, token) } -func (mm *metricsMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetadata, entityID string, key ...*rsa.PrivateKey) (certs.CSR, error) { +func (mm *metricsMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey) (certs.CSR, error) { defer func(begin time.Time) { mm.counter.With("method", "create_csr").Add(1) mm.latency.With("method", "create_csr").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.CreateCSR(ctx, meta, entityID, key...) + return mm.svc.CreateCSR(ctx, metadata, privKey) } -func (mm *metricsMiddleware) SignCSR(ctx context.Context, csrID string, approve bool) error { +func (mm *metricsMiddleware) SignCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (certs.Certificate, error) { defer func(begin time.Time) { mm.counter.With("method", "sign_csr").Add(1) mm.latency.With("method", "sign_csr").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.SignCSR(ctx, csrID, approve) -} - -func (mm *metricsMiddleware) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) { - defer func(begin time.Time) { - mm.counter.With("method", "retrieve_csr").Add(1) - mm.latency.With("method", "retrieve_csr").Observe(time.Since(begin).Seconds()) - }(time.Now()) - return mm.svc.RetrieveCSR(ctx, csrID) -} - -func (mm *metricsMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { - defer func(begin time.Time) { - mm.counter.With("method", "list_csrs").Add(1) - mm.latency.With("method", "list_csrs").Observe(time.Since(begin).Seconds()) - }(time.Now()) - return mm.svc.ListCSRs(ctx, pm) + return mm.svc.SignCSR(ctx, entityID, ttl, csr) } diff --git a/certs.go b/certs.go index 6654c5b..71f0550 100644 --- a/certs.go +++ b/certs.go @@ -7,7 +7,6 @@ import ( "context" "crypto/rsa" "crypto/x509" - "encoding/json" "net" "time" @@ -55,56 +54,6 @@ func CertTypeFromString(s string) (CertType, error) { } } -type CSRStatus int - -const ( - Pending CSRStatus = iota - Signed - Rejected - All -) - -const ( - pending = "pending" - signed = "signed" - rejected = "rejected" - all = "all" -) - -func (c CSRStatus) String() string { - switch c { - case Pending: - return pending - case Signed: - return signed - case Rejected: - return rejected - case All: - return all - default: - return Unknown - } -} - -func ParseCSRStatus(s string) (CSRStatus, error) { - switch s { - case pending: - return Pending, nil - case signed: - return Signed, nil - case rejected: - return Rejected, nil - case all: - return All, nil - default: - return -1, errors.New("unknown CSR status") - } -} - -func (c CSRStatus) MarshalJSON() ([]byte, error) { - return json.Marshal(c.String()) -} - type CA struct { Type CertType Certificate *x509.Certificate @@ -129,15 +78,13 @@ type CertificatePage struct { } type PageMetadata struct { - Total uint64 `json:"total" db:"total"` - Offset uint64 `json:"offset,omitempty" db:"offset"` - Limit uint64 `json:"limit" db:"limit"` - EntityID string `json:"entity_id,omitempty" db:"entity_id"` - Status CSRStatus `json:"status,omitempty" db:"status"` + Total uint64 `json:"total" db:"total"` + Offset uint64 `json:"offset,omitempty" db:"offset"` + Limit uint64 `json:"limit" db:"limit"` + EntityID string `json:"entity_id,omitempty" db:"entity_id"` } type CSRMetadata struct { - EntityID string CommonName string `json:"common_name"` Organization []string `json:"organization"` OrganizationalUnit []string `json:"organizational_unit"` @@ -152,14 +99,8 @@ type CSRMetadata struct { } type CSR struct { - ID string `json:"id" db:"id"` - CSR []byte `json:"csr,omitempty" db:"csr"` - PrivateKey []byte `json:"private_key,omitempty" db:"private_key"` - EntityID string `json:"entity_id" db:"entity_id"` - Status CSRStatus `json:"status" db:"status"` - SubmittedAt time.Time `json:"submitted_at" db:"submitted_at"` - SignedAt time.Time `json:"signed_at,omitempty" db:"signed_at"` - SerialNumber string `json:"serial_number,omitempty" db:"serial_number"` + CSR []byte `json:"csr"` + PrivateKey []byte `json:"private_key"` } type CSRPage struct { @@ -235,16 +176,10 @@ type Service interface { RemoveCert(ctx context.Context, entityId string) error // CreateCSR creates a new Certificate Signing Request - CreateCSR(ctx context.Context, metadata CSRMetadata, entityID string, privKey ...*rsa.PrivateKey) (CSR, error) - - // SignCSR processes a pending CSR and either approves or rejects it - SignCSR(ctx context.Context, csrID string, approve bool) error + CreateCSR(ctx context.Context, metadata CSRMetadata, privKey *rsa.PrivateKey) (CSR, error) - // RetrieveCSR retrieves a specific CSR by ID - RetrieveCSR(ctx context.Context, csrID string) (CSR, error) - - // ListCSRs returns a list of CSRs based on filter criteria - ListCSRs(ctx context.Context, pm PageMetadata) (CSRPage, error) + // SignCSR parses and signs a CSR + SignCSR(ctx context.Context, entityID, ttl string, csr CSR) (Certificate, error) } type Repository interface { @@ -269,10 +204,3 @@ type Repository interface { // RemoveCert deletes cert from database. RemoveCert(ctx context.Context, entityId string) error } - -type CSRRepository interface { - CreateCSR(context.Context, CSR) error - UpdateCSR(context.Context, CSR) error - RetrieveCSR(context.Context, string) (CSR, error) - ListCSRs(context.Context, PageMetadata) (CSRPage, error) -} diff --git a/certs_test.go b/certs_test.go index 184773f..28616e2 100644 --- a/certs_test.go +++ b/certs_test.go @@ -34,12 +34,10 @@ var ( func TestIssueCert(t *testing.T) { cRepo := new(mocks.MockRepository) - csrRepo := new(mocks.MockCSRRepository) - idProvider := mocks.NewMock() repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) + svc, err := certs.NewService(context.Background(), cRepo, &config) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -78,14 +76,12 @@ func TestIssueCert(t *testing.T) { func TestRevokeCert(t *testing.T) { cRepo := new(mocks.MockRepository) - csrRepo := new(mocks.MockCSRRepository) - idProvider := mocks.NewMock() invalidSerialNumber := "invalid serial number" repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) + svc, err := certs.NewService(context.Background(), cRepo, &config) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -134,12 +130,10 @@ func TestRevokeCert(t *testing.T) { func TestGetCertDownloadToken(t *testing.T) { cRepo := new(mocks.MockRepository) - csrRepo := new(mocks.MockCSRRepository) - idProvider := mocks.NewMock() repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) + svc, err := certs.NewService(context.Background(), cRepo, &config) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -166,8 +160,6 @@ func TestGetCertDownloadToken(t *testing.T) { func TestGetCert(t *testing.T) { cRepo := new(mocks.MockRepository) - csrRepo := new(mocks.MockCSRRepository) - idProvider := mocks.NewMock() jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{ExpiresAt: time.Now().Add(time.Minute * 5).Unix(), Issuer: certs.Organization, Subject: "certs"}) validToken, err := jwtToken.SignedString([]byte(serialNumber)) @@ -175,7 +167,7 @@ func TestGetCert(t *testing.T) { repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) + svc, err := certs.NewService(context.Background(), cRepo, &config) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -219,8 +211,6 @@ func TestGetCert(t *testing.T) { func TestRenewCert(t *testing.T) { cRepo := new(mocks.MockRepository) - csrRepo := new(mocks.MockCSRRepository) - idProvider := mocks.NewMock() serialNumber := big.NewInt(1) expiredSerialNumber := big.NewInt(2) @@ -272,7 +262,7 @@ func TestRenewCert(t *testing.T) { repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) + svc, err := certs.NewService(context.Background(), cRepo, &config) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -361,12 +351,10 @@ func TestRenewCert(t *testing.T) { func TestGetEntityID(t *testing.T) { cRepo := new(mocks.MockRepository) - csrRepo := new(mocks.MockCSRRepository) - idProvider := mocks.NewMock() repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) + svc, err := certs.NewService(context.Background(), cRepo, &config) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -394,12 +382,10 @@ func TestGetEntityID(t *testing.T) { func TestListCerts(t *testing.T) { cRepo := new(mocks.MockRepository) - csrRepo := new(mocks.MockCSRRepository) - idProvider := mocks.NewMock() repoCall := cRepo.On("GetCAs", mock.Anything).Return([]certs.Certificate{}, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) + svc, err := certs.NewService(context.Background(), cRepo, &config) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() @@ -433,8 +419,6 @@ func TestListCerts(t *testing.T) { func TestGenerateCRL(t *testing.T) { cRepo := new(mocks.MockRepository) - csrRepo := new(mocks.MockCSRRepository) - idProvider := mocks.NewMock() privateKey, _ := rsa.GenerateKey(rand.Reader, 2048) template := &x509.Certificate{ @@ -456,7 +440,7 @@ func TestGenerateCRL(t *testing.T) { {Type: certs.IntermediateCA, Certificate: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)})}, }, nil) repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(nil) - svc, err := certs.NewService(context.Background(), cRepo, csrRepo, &config, idProvider) + svc, err := certs.NewService(context.Background(), cRepo, &config) require.NoError(t, err) repoCall.Unset() repoCall1.Unset() diff --git a/cli/certs.go b/cli/certs.go index 2cf131b..4a56b6b 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -5,10 +5,8 @@ package cli import ( "encoding/json" - "fmt" "os" - "github.com/absmach/certs/errors" ctxsdk "github.com/absmach/certs/sdk" "github.com/spf13/cobra" ) @@ -245,7 +243,7 @@ var cmdCerts = []cobra.Command{ Short: "Create CSR", Long: `Creates a CSR.`, Run: func(cmd *cobra.Command, args []string) { - if len(args) > 3 || len(args) == 0 { + if len(args) != 0 { logUsageCmd(*cmd, cmd.Use) return } @@ -256,25 +254,13 @@ var cmdCerts = []cobra.Command{ return } - var csr ctxsdk.CSR - var err error - if len(args) == 1 { - csr, err = sdk.CreateCSR(pm, []byte{}) - if err != nil { - logErrorCmd(*cmd, err) - return - } - logJSONCmd(*cmd, csr) - return - } - data, err := os.ReadFile(args[1]) if err != nil { logErrorCmd(*cmd, err) return } - csr, err = sdk.CreateCSR(pm, data) + csr, err := sdk.CreateCSR(pm, data) if err != nil { logErrorCmd(*cmd, err) return @@ -284,80 +270,22 @@ var cmdCerts = []cobra.Command{ }, }, { - Use: "sign ", + Use: "sign ", Short: "Sign CSR", Long: `Signs a CSR for a given csr id.`, Run: func(cmd *cobra.Command, args []string) { - if len(args) != 2 { + if len(args) != 3 { logUsageCmd(*cmd, cmd.Use) return } - var sign bool - switch args[1] { - case "true": - sign = true - case "false": - sign = false - default: - logErrorCmd(*cmd, errors.NewSDKError(fmt.Errorf("unknown type"))) - return - } - err := sdk.SignCSR(args[0], sign) + data, err := os.ReadFile(args[2]) if err != nil { logErrorCmd(*cmd, err) return } - logOKCmd(*cmd) - }, - }, - { - Use: "get-csr [all | ] ", - Short: "Get csr", - Long: `Gets CSRs for a given entity ID or all CSR.`, - Run: func(cmd *cobra.Command, args []string) { - if len(args) != 2 { - logUsageCmd(*cmd, cmd.Use) - return - } - if args[0] == "all" { - pm := ctxsdk.PageMetadata{ - Limit: Limit, - Offset: Offset, - Status: args[1], - } - page, err := sdk.ListCSRs(pm) - if err != nil { - logErrorCmd(*cmd, err) - return - } - logJSONCmd(*cmd, page) - return - } - pm := ctxsdk.PageMetadata{ - EntityID: args[0], - Limit: Limit, - Offset: Offset, - Status: args[1], - } - page, err := sdk.ListCSRs(pm) - if err != nil { - logErrorCmd(*cmd, err) - return - } - logJSONCmd(*cmd, page) - }, - }, - { - Use: "view-csr ", - Short: "View CSR", - Long: `Views a CSR for a given csr id.`, - Run: func(cmd *cobra.Command, args []string) { - if len(args) != 1 { - logUsageCmd(*cmd, cmd.Use) - return - } - cert, err := sdk.RetrieveCSR(args[0]) + + cert, err := sdk.SignCSR(args[0], args[1], data) if err != nil { logErrorCmd(*cmd, err) return @@ -407,7 +335,7 @@ func NewCertsCmd() *cobra.Command { issueCmd.Flags().StringVar(&ttl, "ttl", "8760h", "certificate time to live in duration") cmd := cobra.Command{ - Use: "certs [issue | get | revoke | renew | ocsp | token | download]", + Use: "certs [issue | get | revoke | renew | ocsp | token | download | download-ca | download-ca | csr | sign]", Short: "Certificates management", Long: `Certificates management: issue, get all, get by entity ID, revoke, renew, OCSP, token, download.`, } diff --git a/cmd/certs/main.go b/cmd/certs/main.go index c174da5..69deb97 100644 --- a/cmd/certs/main.go +++ b/cmd/certs/main.go @@ -25,7 +25,6 @@ import ( httpserver "github.com/absmach/certs/internal/server/http" "github.com/absmach/certs/internal/uuid" cpostgres "github.com/absmach/certs/postgres/certs" - csrpostgres "github.com/absmach/certs/postgres/csr" "github.com/absmach/certs/tracing" "github.com/caarlos0/env/v10" "github.com/jmoiron/sqlx" @@ -80,8 +79,6 @@ func main() { logger.Error(err.Error()) } cm := cpostgres.Migration() - sm := csrpostgres.Migration() - cm.Migrations = append(cm.Migrations, sm.Migrations...) db, err := pgClient.Setup(dbConfig, *cm) if err != nil { log.Fatalf(fmt.Sprintf("Failed to connect to %s database: %s", svcName, err)) @@ -150,9 +147,7 @@ func main() { func newService(ctx context.Context, db *sqlx.DB, tracer trace.Tracer, logger *slog.Logger, dbConfig pgClient.Config, config *certs.Config) (certs.Service, error) { database := postgres.NewDatabase(db, dbConfig, tracer) repo := cpostgres.NewRepository(database) - csrRepo := csrpostgres.NewRepository(database) - idp := uuid.New() - svc, err := certs.NewService(ctx, repo, csrRepo, config, idp) + svc, err := certs.NewService(ctx, repo, config) if err != nil { return nil, err } diff --git a/mocks/csr.go b/mocks/csr.go deleted file mode 100644 index 79340c1..0000000 --- a/mocks/csr.go +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -// Code generated by mockery v2.43.2. DO NOT EDIT. - -package mocks - -import ( - context "context" - - certs "github.com/absmach/certs" - - mock "github.com/stretchr/testify/mock" -) - -// MockCSRRepository is an autogenerated mock type for the CSRRepository type -type MockCSRRepository struct { - mock.Mock -} - -type MockCSRRepository_Expecter struct { - mock *mock.Mock -} - -func (_m *MockCSRRepository) EXPECT() *MockCSRRepository_Expecter { - return &MockCSRRepository_Expecter{mock: &_m.Mock} -} - -// CreateCSR provides a mock function with given fields: _a0, _a1 -func (_m *MockCSRRepository) CreateCSR(_a0 context.Context, _a1 certs.CSR) error { - ret := _m.Called(_a0, _a1) - - if len(ret) == 0 { - panic("no return value specified for CreateCSR") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, certs.CSR) error); ok { - r0 = rf(_a0, _a1) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// MockCSRRepository_CreateCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCSR' -type MockCSRRepository_CreateCSR_Call struct { - *mock.Call -} - -// CreateCSR is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 certs.CSR -func (_e *MockCSRRepository_Expecter) CreateCSR(_a0 interface{}, _a1 interface{}) *MockCSRRepository_CreateCSR_Call { - return &MockCSRRepository_CreateCSR_Call{Call: _e.mock.On("CreateCSR", _a0, _a1)} -} - -func (_c *MockCSRRepository_CreateCSR_Call) Run(run func(_a0 context.Context, _a1 certs.CSR)) *MockCSRRepository_CreateCSR_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(certs.CSR)) - }) - return _c -} - -func (_c *MockCSRRepository_CreateCSR_Call) Return(_a0 error) *MockCSRRepository_CreateCSR_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockCSRRepository_CreateCSR_Call) RunAndReturn(run func(context.Context, certs.CSR) error) *MockCSRRepository_CreateCSR_Call { - _c.Call.Return(run) - return _c -} - -// ListCSRs provides a mock function with given fields: _a0, _a1 -func (_m *MockCSRRepository) ListCSRs(_a0 context.Context, _a1 certs.PageMetadata) (certs.CSRPage, error) { - ret := _m.Called(_a0, _a1) - - if len(ret) == 0 { - panic("no return value specified for ListCSRs") - } - - var r0 certs.CSRPage - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, certs.PageMetadata) (certs.CSRPage, error)); ok { - return rf(_a0, _a1) - } - if rf, ok := ret.Get(0).(func(context.Context, certs.PageMetadata) certs.CSRPage); ok { - r0 = rf(_a0, _a1) - } else { - r0 = ret.Get(0).(certs.CSRPage) - } - - if rf, ok := ret.Get(1).(func(context.Context, certs.PageMetadata) error); ok { - r1 = rf(_a0, _a1) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockCSRRepository_ListCSRs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCSRs' -type MockCSRRepository_ListCSRs_Call struct { - *mock.Call -} - -// ListCSRs is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 certs.PageMetadata -func (_e *MockCSRRepository_Expecter) ListCSRs(_a0 interface{}, _a1 interface{}) *MockCSRRepository_ListCSRs_Call { - return &MockCSRRepository_ListCSRs_Call{Call: _e.mock.On("ListCSRs", _a0, _a1)} -} - -func (_c *MockCSRRepository_ListCSRs_Call) Run(run func(_a0 context.Context, _a1 certs.PageMetadata)) *MockCSRRepository_ListCSRs_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(certs.PageMetadata)) - }) - return _c -} - -func (_c *MockCSRRepository_ListCSRs_Call) Return(_a0 certs.CSRPage, _a1 error) *MockCSRRepository_ListCSRs_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockCSRRepository_ListCSRs_Call) RunAndReturn(run func(context.Context, certs.PageMetadata) (certs.CSRPage, error)) *MockCSRRepository_ListCSRs_Call { - _c.Call.Return(run) - return _c -} - -// RetrieveCSR provides a mock function with given fields: _a0, _a1 -func (_m *MockCSRRepository) RetrieveCSR(_a0 context.Context, _a1 string) (certs.CSR, error) { - ret := _m.Called(_a0, _a1) - - if len(ret) == 0 { - panic("no return value specified for RetrieveCSR") - } - - var r0 certs.CSR - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (certs.CSR, error)); ok { - return rf(_a0, _a1) - } - if rf, ok := ret.Get(0).(func(context.Context, string) certs.CSR); ok { - r0 = rf(_a0, _a1) - } else { - r0 = ret.Get(0).(certs.CSR) - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(_a0, _a1) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockCSRRepository_RetrieveCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveCSR' -type MockCSRRepository_RetrieveCSR_Call struct { - *mock.Call -} - -// RetrieveCSR is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 string -func (_e *MockCSRRepository_Expecter) RetrieveCSR(_a0 interface{}, _a1 interface{}) *MockCSRRepository_RetrieveCSR_Call { - return &MockCSRRepository_RetrieveCSR_Call{Call: _e.mock.On("RetrieveCSR", _a0, _a1)} -} - -func (_c *MockCSRRepository_RetrieveCSR_Call) Run(run func(_a0 context.Context, _a1 string)) *MockCSRRepository_RetrieveCSR_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string)) - }) - return _c -} - -func (_c *MockCSRRepository_RetrieveCSR_Call) Return(_a0 certs.CSR, _a1 error) *MockCSRRepository_RetrieveCSR_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockCSRRepository_RetrieveCSR_Call) RunAndReturn(run func(context.Context, string) (certs.CSR, error)) *MockCSRRepository_RetrieveCSR_Call { - _c.Call.Return(run) - return _c -} - -// UpdateCSR provides a mock function with given fields: _a0, _a1 -func (_m *MockCSRRepository) UpdateCSR(_a0 context.Context, _a1 certs.CSR) error { - ret := _m.Called(_a0, _a1) - - if len(ret) == 0 { - panic("no return value specified for UpdateCSR") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, certs.CSR) error); ok { - r0 = rf(_a0, _a1) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// MockCSRRepository_UpdateCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateCSR' -type MockCSRRepository_UpdateCSR_Call struct { - *mock.Call -} - -// UpdateCSR is a helper method to define mock.On call -// - _a0 context.Context -// - _a1 certs.CSR -func (_e *MockCSRRepository_Expecter) UpdateCSR(_a0 interface{}, _a1 interface{}) *MockCSRRepository_UpdateCSR_Call { - return &MockCSRRepository_UpdateCSR_Call{Call: _e.mock.On("UpdateCSR", _a0, _a1)} -} - -func (_c *MockCSRRepository_UpdateCSR_Call) Run(run func(_a0 context.Context, _a1 certs.CSR)) *MockCSRRepository_UpdateCSR_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(certs.CSR)) - }) - return _c -} - -func (_c *MockCSRRepository_UpdateCSR_Call) Return(_a0 error) *MockCSRRepository_UpdateCSR_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockCSRRepository_UpdateCSR_Call) RunAndReturn(run func(context.Context, certs.CSR) error) *MockCSRRepository_UpdateCSR_Call { - _c.Call.Return(run) - return _c -} - -// NewMockCSRRepository creates a new instance of MockCSRRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockCSRRepository(t interface { - mock.TestingT - Cleanup(func()) -}) *MockCSRRepository { - mock := &MockCSRRepository{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/mocks/service.go b/mocks/service.go index b7f2c1f..978ab36 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -30,16 +30,9 @@ func (_m *MockService) EXPECT() *MockService_Expecter { return &MockService_Expecter{mock: &_m.Mock} } -// CreateCSR provides a mock function with given fields: ctx, metadata, entityID, privKey -func (_m *MockService) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, entityID string, privKey ...*rsa.PrivateKey) (certs.CSR, error) { - _va := make([]interface{}, len(privKey)) - for _i := range privKey { - _va[_i] = privKey[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, metadata, entityID) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +// CreateCSR provides a mock function with given fields: ctx, metadata, privKey +func (_m *MockService) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey) (certs.CSR, error) { + ret := _m.Called(ctx, metadata, privKey) if len(ret) == 0 { panic("no return value specified for CreateCSR") @@ -47,17 +40,17 @@ func (_m *MockService) CreateCSR(ctx context.Context, metadata certs.CSRMetadata var r0 certs.CSR var r1 error - if rf, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, string, ...*rsa.PrivateKey) (certs.CSR, error)); ok { - return rf(ctx, metadata, entityID, privKey...) + if rf, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, *rsa.PrivateKey) (certs.CSR, error)); ok { + return rf(ctx, metadata, privKey) } - if rf, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, string, ...*rsa.PrivateKey) certs.CSR); ok { - r0 = rf(ctx, metadata, entityID, privKey...) + if rf, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, *rsa.PrivateKey) certs.CSR); ok { + r0 = rf(ctx, metadata, privKey) } else { r0 = ret.Get(0).(certs.CSR) } - if rf, ok := ret.Get(1).(func(context.Context, certs.CSRMetadata, string, ...*rsa.PrivateKey) error); ok { - r1 = rf(ctx, metadata, entityID, privKey...) + if rf, ok := ret.Get(1).(func(context.Context, certs.CSRMetadata, *rsa.PrivateKey) error); ok { + r1 = rf(ctx, metadata, privKey) } else { r1 = ret.Error(1) } @@ -73,22 +66,14 @@ type MockService_CreateCSR_Call struct { // CreateCSR is a helper method to define mock.On call // - ctx context.Context // - metadata certs.CSRMetadata -// - entityID string -// - privKey ...*rsa.PrivateKey -func (_e *MockService_Expecter) CreateCSR(ctx interface{}, metadata interface{}, entityID interface{}, privKey ...interface{}) *MockService_CreateCSR_Call { - return &MockService_CreateCSR_Call{Call: _e.mock.On("CreateCSR", - append([]interface{}{ctx, metadata, entityID}, privKey...)...)} +// - privKey *rsa.PrivateKey +func (_e *MockService_Expecter) CreateCSR(ctx interface{}, metadata interface{}, privKey interface{}) *MockService_CreateCSR_Call { + return &MockService_CreateCSR_Call{Call: _e.mock.On("CreateCSR", ctx, metadata, privKey)} } -func (_c *MockService_CreateCSR_Call) Run(run func(ctx context.Context, metadata certs.CSRMetadata, entityID string, privKey ...*rsa.PrivateKey)) *MockService_CreateCSR_Call { +func (_c *MockService_CreateCSR_Call) Run(run func(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey)) *MockService_CreateCSR_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]*rsa.PrivateKey, len(args)-3) - for i, a := range args[3:] { - if a != nil { - variadicArgs[i] = a.(*rsa.PrivateKey) - } - } - run(args[0].(context.Context), args[1].(certs.CSRMetadata), args[2].(string), variadicArgs...) + run(args[0].(context.Context), args[1].(certs.CSRMetadata), args[2].(*rsa.PrivateKey)) }) return _c } @@ -98,7 +83,7 @@ func (_c *MockService_CreateCSR_Call) Return(_a0 certs.CSR, _a1 error) *MockServ return _c } -func (_c *MockService_CreateCSR_Call) RunAndReturn(run func(context.Context, certs.CSRMetadata, string, ...*rsa.PrivateKey) (certs.CSR, error)) *MockService_CreateCSR_Call { +func (_c *MockService_CreateCSR_Call) RunAndReturn(run func(context.Context, certs.CSRMetadata, *rsa.PrivateKey) (certs.CSR, error)) *MockService_CreateCSR_Call { _c.Call.Return(run) return _c } @@ -351,63 +336,6 @@ func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, str return _c } -// ListCSRs provides a mock function with given fields: ctx, pm -func (_m *MockService) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { - ret := _m.Called(ctx, pm) - - if len(ret) == 0 { - panic("no return value specified for ListCSRs") - } - - var r0 certs.CSRPage - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, certs.PageMetadata) (certs.CSRPage, error)); ok { - return rf(ctx, pm) - } - if rf, ok := ret.Get(0).(func(context.Context, certs.PageMetadata) certs.CSRPage); ok { - r0 = rf(ctx, pm) - } else { - r0 = ret.Get(0).(certs.CSRPage) - } - - if rf, ok := ret.Get(1).(func(context.Context, certs.PageMetadata) error); ok { - r1 = rf(ctx, pm) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockService_ListCSRs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCSRs' -type MockService_ListCSRs_Call struct { - *mock.Call -} - -// ListCSRs is a helper method to define mock.On call -// - ctx context.Context -// - pm certs.PageMetadata -func (_e *MockService_Expecter) ListCSRs(ctx interface{}, pm interface{}) *MockService_ListCSRs_Call { - return &MockService_ListCSRs_Call{Call: _e.mock.On("ListCSRs", ctx, pm)} -} - -func (_c *MockService_ListCSRs_Call) Run(run func(ctx context.Context, pm certs.PageMetadata)) *MockService_ListCSRs_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(certs.PageMetadata)) - }) - return _c -} - -func (_c *MockService_ListCSRs_Call) Return(_a0 certs.CSRPage, _a1 error) *MockService_ListCSRs_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockService_ListCSRs_Call) RunAndReturn(run func(context.Context, certs.PageMetadata) (certs.CSRPage, error)) *MockService_ListCSRs_Call { - _c.Call.Return(run) - return _c -} - // ListCerts provides a mock function with given fields: ctx, pm func (_m *MockService) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { ret := _m.Called(ctx, pm) @@ -690,63 +618,6 @@ func (_c *MockService_RetrieveCAToken_Call) RunAndReturn(run func(context.Contex return _c } -// RetrieveCSR provides a mock function with given fields: ctx, csrID -func (_m *MockService) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) { - ret := _m.Called(ctx, csrID) - - if len(ret) == 0 { - panic("no return value specified for RetrieveCSR") - } - - var r0 certs.CSR - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (certs.CSR, error)); ok { - return rf(ctx, csrID) - } - if rf, ok := ret.Get(0).(func(context.Context, string) certs.CSR); ok { - r0 = rf(ctx, csrID) - } else { - r0 = ret.Get(0).(certs.CSR) - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, csrID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockService_RetrieveCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveCSR' -type MockService_RetrieveCSR_Call struct { - *mock.Call -} - -// RetrieveCSR is a helper method to define mock.On call -// - ctx context.Context -// - csrID string -func (_e *MockService_Expecter) RetrieveCSR(ctx interface{}, csrID interface{}) *MockService_RetrieveCSR_Call { - return &MockService_RetrieveCSR_Call{Call: _e.mock.On("RetrieveCSR", ctx, csrID)} -} - -func (_c *MockService_RetrieveCSR_Call) Run(run func(ctx context.Context, csrID string)) *MockService_RetrieveCSR_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string)) - }) - return _c -} - -func (_c *MockService_RetrieveCSR_Call) Return(_a0 certs.CSR, _a1 error) *MockService_RetrieveCSR_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockService_RetrieveCSR_Call) RunAndReturn(run func(context.Context, string) (certs.CSR, error)) *MockService_RetrieveCSR_Call { - _c.Call.Return(run) - return _c -} - // RetrieveCert provides a mock function with given fields: ctx, token, serialNumber func (_m *MockService) RetrieveCert(ctx context.Context, token string, serialNumber string) (certs.Certificate, []byte, error) { ret := _m.Called(ctx, token, serialNumber) @@ -918,22 +789,32 @@ func (_c *MockService_RevokeCert_Call) RunAndReturn(run func(context.Context, st return _c } -// SignCSR provides a mock function with given fields: ctx, csrID, approve -func (_m *MockService) SignCSR(ctx context.Context, csrID string, approve bool) error { - ret := _m.Called(ctx, csrID, approve) +// SignCSR provides a mock function with given fields: ctx, entityID, ttl, csr +func (_m *MockService) SignCSR(ctx context.Context, entityID string, ttl string, csr certs.CSR) (certs.Certificate, error) { + ret := _m.Called(ctx, entityID, ttl, csr) if len(ret) == 0 { panic("no return value specified for SignCSR") } - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, bool) error); ok { - r0 = rf(ctx, csrID, approve) + var r0 certs.Certificate + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, certs.CSR) (certs.Certificate, error)); ok { + return rf(ctx, entityID, ttl, csr) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, certs.CSR) certs.Certificate); ok { + r0 = rf(ctx, entityID, ttl, csr) } else { - r0 = ret.Error(0) + r0 = ret.Get(0).(certs.Certificate) } - return r0 + if rf, ok := ret.Get(1).(func(context.Context, string, string, certs.CSR) error); ok { + r1 = rf(ctx, entityID, ttl, csr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 } // MockService_SignCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignCSR' @@ -943,25 +824,26 @@ type MockService_SignCSR_Call struct { // SignCSR is a helper method to define mock.On call // - ctx context.Context -// - csrID string -// - approve bool -func (_e *MockService_Expecter) SignCSR(ctx interface{}, csrID interface{}, approve interface{}) *MockService_SignCSR_Call { - return &MockService_SignCSR_Call{Call: _e.mock.On("SignCSR", ctx, csrID, approve)} +// - entityID string +// - ttl string +// - csr certs.CSR +func (_e *MockService_Expecter) SignCSR(ctx interface{}, entityID interface{}, ttl interface{}, csr interface{}) *MockService_SignCSR_Call { + return &MockService_SignCSR_Call{Call: _e.mock.On("SignCSR", ctx, entityID, ttl, csr)} } -func (_c *MockService_SignCSR_Call) Run(run func(ctx context.Context, csrID string, approve bool)) *MockService_SignCSR_Call { +func (_c *MockService_SignCSR_Call) Run(run func(ctx context.Context, entityID string, ttl string, csr certs.CSR)) *MockService_SignCSR_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(bool)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(certs.CSR)) }) return _c } -func (_c *MockService_SignCSR_Call) Return(_a0 error) *MockService_SignCSR_Call { - _c.Call.Return(_a0) +func (_c *MockService_SignCSR_Call) Return(_a0 certs.Certificate, _a1 error) *MockService_SignCSR_Call { + _c.Call.Return(_a0, _a1) return _c } -func (_c *MockService_SignCSR_Call) RunAndReturn(run func(context.Context, string, bool) error) *MockService_SignCSR_Call { +func (_c *MockService_SignCSR_Call) RunAndReturn(run func(context.Context, string, string, certs.CSR) (certs.Certificate, error)) *MockService_SignCSR_Call { _c.Call.Return(run) return _c } diff --git a/mocks/uuid.go b/mocks/uuid.go deleted file mode 100644 index 065daba..0000000 --- a/mocks/uuid.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package mocks - -import ( - "fmt" - "sync" - - "github.com/absmach/certs/internal/uuid" -) - -// Prefix represents the prefix used to generate UUID mocks. -const Prefix = "123e4567-e89b-12d3-a456-" - -var _ uuid.IDProvider = (*uuidProviderMock)(nil) - -type uuidProviderMock struct { - mu sync.Mutex - counter int -} - -func (up *uuidProviderMock) ID() (string, error) { - up.mu.Lock() - defer up.mu.Unlock() - - up.counter++ - return fmt.Sprintf("%s%012d", Prefix, up.counter), nil -} - -// NewMock creates "mirror" uuid provider, i.e. generated -// token will hold value provided by the caller. -func NewMock() uuid.IDProvider { - return &uuidProviderMock{} -} diff --git a/postgres/csr/csr.go b/postgres/csr/csr.go deleted file mode 100644 index f1a2008..0000000 --- a/postgres/csr/csr.go +++ /dev/null @@ -1,202 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package postgres - -import ( - "context" - "database/sql" - "fmt" - "strings" - "time" - - "github.com/absmach/certs" - "github.com/absmach/certs/errors" - "github.com/absmach/certs/internal/postgres" - "github.com/jackc/pgx/v5/pgconn" -) - -// Postgres error codes: -// https://www.postgresql.org/docs/current/errcodes-appendix.html -const ( - errDuplicate = "23505" // unique_violation - errTruncation = "22001" // string_data_right_truncation - errFK = "23503" // foreign_key_violation - errInvalid = "22P02" // invalid_text_representation - errUntranslatable = "22P05" // untranslatable_character - errInvalidChar = "22021" // character_not_in_repertoire -) - -var ( - ErrConflict = errors.New("entity already exists") - ErrMalformedEntity = errors.New("malformed entity") - ErrCreateEntity = errors.New("failed to create entity") -) - -type CSRRepo struct { - db postgres.Database -} - -func NewRepository(db postgres.Database) certs.CSRRepository { - return CSRRepo{ - db: db, - } -} - -func (repo CSRRepo) CreateCSR(ctx context.Context, csr certs.CSR) error { - q := ` - INSERT INTO csrs (id, serial_number, csr, private_key, entity_id, status, submitted_at, signed_at) - VALUES (:id, :serial_number, :csr, :private_key, :entity_id, :status, :submitted_at, :signed_at)` - _, err := repo.db.NamedExecContext(ctx, q, csr) - if err != nil { - return handleError(certs.ErrCreateEntity, err) - } - return nil -} - -func (repo CSRRepo) UpdateCSR(ctx context.Context, csr certs.CSR) error { - updateData := rawCSR{ - ID: csr.ID, - SerialNumber: csr.SerialNumber, - Status: csr.Status.String(), - PrivateKey: csr.PrivateKey, - SubmittedAt: csr.SubmittedAt, - SignedAt: csr.SignedAt, - } - - q := `UPDATE csrs SET serial_number = :serial_number, status = :status, private_key = :private_key, submitted_at = :submitted_at, signed_at = :signed_at WHERE id = :id` - res, err := repo.db.NamedExecContext(ctx, q, updateData) - if err != nil { - return handleError(certs.ErrUpdateEntity, err) - } - count, err := res.RowsAffected() - if err != nil { - return errors.Wrap(certs.ErrUpdateEntity, err) - } - if count == 0 { - return certs.ErrNotFound - } - return nil -} - -func (repo CSRRepo) RetrieveCSR(ctx context.Context, id string) (certs.CSR, error) { - q := `SELECT id, serial_number, csr, private_key, entity_id, status, submitted_at, signed_at FROM csrs WHERE id = $1` - var csrRaw rawCSR - if err := repo.db.QueryRowxContext(ctx, q, id).StructScan(&csrRaw); err != nil { - if err == sql.ErrNoRows { - return certs.CSR{}, errors.Wrap(certs.ErrNotFound, err) - } - return certs.CSR{}, errors.Wrap(certs.ErrViewEntity, err) - } - - status, err := certs.ParseCSRStatus(csrRaw.Status) - if err != nil { - return certs.CSR{}, errors.Wrap(certs.ErrViewEntity, fmt.Errorf("invalid status: %s", csrRaw.Status)) - } - return certs.CSR{ - ID: csrRaw.ID, - SerialNumber: csrRaw.SerialNumber, - CSR: csrRaw.CSR, - PrivateKey: csrRaw.PrivateKey, - EntityID: csrRaw.EntityID, - Status: status, - SubmittedAt: csrRaw.SubmittedAt, - SignedAt: csrRaw.SignedAt, - }, nil -} - -func (repo CSRRepo) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { - var query []string - params := map[string]interface{}{ - "limit": pm.Limit, - "offset": pm.Offset, - } - if pm.EntityID != "" { - query = append(query, `c.entity_id = :entity_id`) - params["entity_id"] = pm.EntityID - } - if pm.Status != certs.All { - query = append(query, `c.status = :status`) - params["status"] = pm.Status - } - - var str string - if len(query) > 0 { - str = fmt.Sprintf(`WHERE %s`, strings.Join(query, ` AND `)) - } - - q := fmt.Sprintf(` - SELECT - c.id, - c.serial_number, - c.submitted_at, - c.signed_at, - c.entity_id - FROM csrs c %s LIMIT :limit OFFSET :offset;`, str) - - rows, err := repo.db.NamedQueryContext(ctx, q, pm) - if err != nil { - return certs.CSRPage{}, handleError(certs.ErrViewEntity, err) - } - defer rows.Close() - var csrs []certs.CSR - for rows.Next() { - csr := certs.CSR{} - if err := rows.StructScan(&csr); err != nil { - return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) - } - csrs = append(csrs, csr) - } - - cq := fmt.Sprintf(`SELECT COUNT(*) FROM csrs c %s;`, str) - pm.Total, err = repo.total(ctx, cq, pm) - if err != nil { - return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) - } - return certs.CSRPage{ - PageMetadata: pm, - CSRs: csrs, - }, nil -} - -func (repo CSRRepo) total(ctx context.Context, query string, params interface{}) (uint64, error) { - rows, err := repo.db.NamedQueryContext(ctx, query, params) - if err != nil { - return 0, err - } - defer rows.Close() - total := uint64(0) - if rows.Next() { - if err := rows.Scan(&total); err != nil { - return 0, err - } - } - return total, nil -} - -func handleError(wrapper, err error) error { - pqErr, ok := err.(*pgconn.PgError) - if ok { - switch pqErr.Code { - case errDuplicate: - return errors.Wrap(ErrConflict, err) - case errInvalid, errInvalidChar, errTruncation, errUntranslatable: - return errors.Wrap(ErrMalformedEntity, err) - case errFK: - return errors.Wrap(ErrCreateEntity, err) - } - } - - return errors.Wrap(wrapper, err) -} - -type rawCSR struct { - ID string `db:"id"` - SerialNumber string `db:"serial_number"` - CSR []byte `db:"csr"` - PrivateKey []byte `db:"private_key"` - EntityID string `db:"entity_id"` - Status string `db:"status"` - SubmittedAt time.Time `db:"submitted_at"` - SignedAt time.Time `db:"signed_at"` -} diff --git a/postgres/csr/init.go b/postgres/csr/init.go deleted file mode 100644 index 9ef8156..0000000 --- a/postgres/csr/init.go +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) Abstract Machines -// SPDX-License-Identifier: Apache-2.0 - -package postgres - -import ( - _ "github.com/jackc/pgx/v5/stdlib" - migrate "github.com/rubenv/sql-migrate" -) - -func Migration() *migrate.MemoryMigrationSource { - return &migrate.MemoryMigrationSource{ - Migrations: []*migrate.Migration{ - { - Id: "csrs_1", - Up: []string{ - `CREATE TABLE IF NOT EXISTS csrs ( - id VARCHAR(36) PRIMARY KEY, - serial_number VARCHAR(40), - csr TEXT, - private_key TEXT, - entity_id VARCHAR(36), - status TEXT CHECK (status IN ('pending', 'signed', 'rejected')), - submitted_at TIMESTAMP, - signed_at TIMESTAMP - )`, - }, - Down: []string{ - "DROP TABLE csr", - }, - }, - }, - } -} diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index 23be435..8384bc5 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -367,64 +367,6 @@ func (_c *MockSDK_IssueCert_Call) RunAndReturn(run func(string, string, []string return _c } -// ListCSRs provides a mock function with given fields: pm -func (_m *MockSDK) ListCSRs(pm sdk.PageMetadata) (sdk.CSRPage, errors.SDKError) { - ret := _m.Called(pm) - - if len(ret) == 0 { - panic("no return value specified for ListCSRs") - } - - var r0 sdk.CSRPage - var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(sdk.PageMetadata) (sdk.CSRPage, errors.SDKError)); ok { - return rf(pm) - } - if rf, ok := ret.Get(0).(func(sdk.PageMetadata) sdk.CSRPage); ok { - r0 = rf(pm) - } else { - r0 = ret.Get(0).(sdk.CSRPage) - } - - if rf, ok := ret.Get(1).(func(sdk.PageMetadata) errors.SDKError); ok { - r1 = rf(pm) - } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).(errors.SDKError) - } - } - - return r0, r1 -} - -// MockSDK_ListCSRs_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListCSRs' -type MockSDK_ListCSRs_Call struct { - *mock.Call -} - -// ListCSRs is a helper method to define mock.On call -// - pm sdk.PageMetadata -func (_e *MockSDK_Expecter) ListCSRs(pm interface{}) *MockSDK_ListCSRs_Call { - return &MockSDK_ListCSRs_Call{Call: _e.mock.On("ListCSRs", pm)} -} - -func (_c *MockSDK_ListCSRs_Call) Run(run func(pm sdk.PageMetadata)) *MockSDK_ListCSRs_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.PageMetadata)) - }) - return _c -} - -func (_c *MockSDK_ListCSRs_Call) Return(_a0 sdk.CSRPage, _a1 errors.SDKError) *MockSDK_ListCSRs_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockSDK_ListCSRs_Call) RunAndReturn(run func(sdk.PageMetadata) (sdk.CSRPage, errors.SDKError)) *MockSDK_ListCSRs_Call { - _c.Call.Return(run) - return _c -} - // ListCerts provides a mock function with given fields: pm func (_m *MockSDK) ListCerts(pm sdk.PageMetadata) (sdk.CertificatePage, errors.SDKError) { ret := _m.Called(pm) @@ -590,64 +532,6 @@ func (_c *MockSDK_RenewCert_Call) RunAndReturn(run func(string) errors.SDKError) return _c } -// RetrieveCSR provides a mock function with given fields: csrID -func (_m *MockSDK) RetrieveCSR(csrID string) (sdk.CSR, errors.SDKError) { - ret := _m.Called(csrID) - - if len(ret) == 0 { - panic("no return value specified for RetrieveCSR") - } - - var r0 sdk.CSR - var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(string) (sdk.CSR, errors.SDKError)); ok { - return rf(csrID) - } - if rf, ok := ret.Get(0).(func(string) sdk.CSR); ok { - r0 = rf(csrID) - } else { - r0 = ret.Get(0).(sdk.CSR) - } - - if rf, ok := ret.Get(1).(func(string) errors.SDKError); ok { - r1 = rf(csrID) - } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).(errors.SDKError) - } - } - - return r0, r1 -} - -// MockSDK_RetrieveCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveCSR' -type MockSDK_RetrieveCSR_Call struct { - *mock.Call -} - -// RetrieveCSR is a helper method to define mock.On call -// - csrID string -func (_e *MockSDK_Expecter) RetrieveCSR(csrID interface{}) *MockSDK_RetrieveCSR_Call { - return &MockSDK_RetrieveCSR_Call{Call: _e.mock.On("RetrieveCSR", csrID)} -} - -func (_c *MockSDK_RetrieveCSR_Call) Run(run func(csrID string)) *MockSDK_RetrieveCSR_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string)) - }) - return _c -} - -func (_c *MockSDK_RetrieveCSR_Call) Return(_a0 sdk.CSR, _a1 errors.SDKError) *MockSDK_RetrieveCSR_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockSDK_RetrieveCSR_Call) RunAndReturn(run func(string) (sdk.CSR, errors.SDKError)) *MockSDK_RetrieveCSR_Call { - _c.Call.Return(run) - return _c -} - // RetrieveCertDownloadToken provides a mock function with given fields: serialNumber func (_m *MockSDK) RetrieveCertDownloadToken(serialNumber string) (sdk.Token, errors.SDKError) { ret := _m.Called(serialNumber) @@ -754,24 +638,34 @@ func (_c *MockSDK_RevokeCert_Call) RunAndReturn(run func(string) errors.SDKError return _c } -// SignCSR provides a mock function with given fields: csrID, sign -func (_m *MockSDK) SignCSR(csrID string, sign bool) errors.SDKError { - ret := _m.Called(csrID, sign) +// SignCSR provides a mock function with given fields: entityID, ttl, csr +func (_m *MockSDK) SignCSR(entityID string, ttl string, csr []byte) (sdk.Certificate, errors.SDKError) { + ret := _m.Called(entityID, ttl, csr) if len(ret) == 0 { panic("no return value specified for SignCSR") } - var r0 errors.SDKError - if rf, ok := ret.Get(0).(func(string, bool) errors.SDKError); ok { - r0 = rf(csrID, sign) + var r0 sdk.Certificate + var r1 errors.SDKError + if rf, ok := ret.Get(0).(func(string, string, []byte) (sdk.Certificate, errors.SDKError)); ok { + return rf(entityID, ttl, csr) + } + if rf, ok := ret.Get(0).(func(string, string, []byte) sdk.Certificate); ok { + r0 = rf(entityID, ttl, csr) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(errors.SDKError) + r0 = ret.Get(0).(sdk.Certificate) + } + + if rf, ok := ret.Get(1).(func(string, string, []byte) errors.SDKError); ok { + r1 = rf(entityID, ttl, csr) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) } } - return r0 + return r0, r1 } // MockSDK_SignCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignCSR' @@ -780,25 +674,26 @@ type MockSDK_SignCSR_Call struct { } // SignCSR is a helper method to define mock.On call -// - csrID string -// - sign bool -func (_e *MockSDK_Expecter) SignCSR(csrID interface{}, sign interface{}) *MockSDK_SignCSR_Call { - return &MockSDK_SignCSR_Call{Call: _e.mock.On("SignCSR", csrID, sign)} +// - entityID string +// - ttl string +// - csr []byte +func (_e *MockSDK_Expecter) SignCSR(entityID interface{}, ttl interface{}, csr interface{}) *MockSDK_SignCSR_Call { + return &MockSDK_SignCSR_Call{Call: _e.mock.On("SignCSR", entityID, ttl, csr)} } -func (_c *MockSDK_SignCSR_Call) Run(run func(csrID string, sign bool)) *MockSDK_SignCSR_Call { +func (_c *MockSDK_SignCSR_Call) Run(run func(entityID string, ttl string, csr []byte)) *MockSDK_SignCSR_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(bool)) + run(args[0].(string), args[1].(string), args[2].([]byte)) }) return _c } -func (_c *MockSDK_SignCSR_Call) Return(_a0 errors.SDKError) *MockSDK_SignCSR_Call { - _c.Call.Return(_a0) +func (_c *MockSDK_SignCSR_Call) Return(_a0 sdk.Certificate, _a1 errors.SDKError) *MockSDK_SignCSR_Call { + _c.Call.Return(_a0, _a1) return _c } -func (_c *MockSDK_SignCSR_Call) RunAndReturn(run func(string, bool) errors.SDKError) *MockSDK_SignCSR_Call { +func (_c *MockSDK_SignCSR_Call) RunAndReturn(run func(string, string, []byte) (sdk.Certificate, errors.SDKError)) *MockSDK_SignCSR_Call { _c.Call.Return(run) return _c } diff --git a/sdk/sdk.go b/sdk/sdk.go index 6ca2fca..6616b20 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -93,7 +93,7 @@ type PageMetadata struct { IPAddresses []string `json:"ip_addresses,omitempty"` EmailAddresses []string `json:"email_addresses,omitempty"` Status string `json:"status,omitempty"` - Sign bool `json:"sign,omitempty"` + TTL string `json:"ttl,omitempty"` } type Options struct { @@ -176,19 +176,8 @@ type CSRMetadata struct { } type CSR struct { - ID string `json:"id,omitempty"` CSR []byte `json:"csr,omitempty"` PrivateKey []byte `json:"private_key,omitempty"` - EntityID string `json:"entity_id,omitempty"` - Status string `json:"status,omitempty"` - SubmittedAt time.Time `json:"submitted_at,omitempty"` - SignedAt time.Time `json:"signed_at,omitempty"` - SerialNumber string `json:"serial_number,omitempty"` -} - -type CSRPage struct { - PageMetadata - CSRs []CSR `json:"csrs,omitempty"` } type SDK interface { @@ -287,21 +276,9 @@ type SDK interface { // SignCSR processes a pending CSR and either signs or rejects it // // example: - // err := sdk.SignCSR( "csr_id", "privKeyPath") + // certs, err := sdk.SignCSR( "entityID", "ttl", []bytes("csrFile")) // fmt.Println(err) - SignCSR(csrID string, sign bool) errors.SDKError - - // RetrieveCSR retrieves a specific CSR by ID - // - // response, _ := sdk.RetrieveCSR("csr_id") - // fmt.Println(response) - RetrieveCSR(csrID string) (CSR, errors.SDKError) - - // ListCSRs returns a list of CSRs based on filter criteria - // - // response, _ := sdk.ListCSRs(sdk.PageMetadata{EntityID: "entity_id", Status: "pending"}) - // fmt.Println(response) - ListCSRs(pm PageMetadata) (CSRPage, errors.SDKError) + SignCSR(entityID, ttl string, csr []byte) (Certificate,errors.SDKError) } func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (Certificate, errors.SDKError) { @@ -588,7 +565,7 @@ func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKErro if err != nil { return CSR{}, errors.NewSDKError(err) } - url := fmt.Sprintf("%s/%s/%s/%s", sdk.certsURL, certsEndpoint, csrEndpoint, pm.EntityID) + url := fmt.Sprintf("%s/%s/%s", sdk.certsURL, certsEndpoint, csrEndpoint) _, body, sdkerr := sdk.processRequest(http.MethodPost, url, d, nil, http.StatusOK) if sdkerr != nil { return CSR{}, sdkerr @@ -601,52 +578,20 @@ func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKErro return csr, nil } -func (sdk mgSDK) SignCSR(csrID string, sign bool) errors.SDKError { +func (sdk mgSDK) SignCSR(entityID, ttl string, csr []byte) (Certificate,errors.SDKError) { pm := PageMetadata{ - Sign: sign, + TTL: ttl, } - url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s/%s", certsEndpoint, csrEndpoint, csrID), pm) + url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s/%s", certsEndpoint, csrEndpoint, entityID), pm) if err != nil { - return errors.NewSDKError(err) + return Certificate{},errors.NewSDKError(err) } _, _, sdkerr := sdk.processRequest(http.MethodPatch, url, nil, nil, http.StatusOK) if sdkerr != nil { - return sdkerr - } - return nil -} - -func (sdk mgSDK) ListCSRs(pm PageMetadata) (CSRPage, errors.SDKError) { - url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s", certsEndpoint, csrEndpoint), pm) - if err != nil { - return CSRPage{}, errors.NewSDKError(err) - } - _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusOK) - if sdkerr != nil { - return CSRPage{}, sdkerr - } - - var cp CSRPage - if err := json.Unmarshal(body, &cp); err != nil { - return CSRPage{}, errors.NewSDKError(err) - } - return cp, nil -} - -func (sdk mgSDK) RetrieveCSR(csrID string) (CSR, errors.SDKError) { - url := fmt.Sprintf("%s/%s/%s/%s", sdk.certsURL, certsEndpoint, csrEndpoint, csrID) - - _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusCreated) - if sdkerr != nil { - return CSR{}, sdkerr - } - - var csr CSR - if err := json.Unmarshal(body, &csr); err != nil { - return CSR{}, errors.NewSDKError(err) + return Certificate{},sdkerr } - return csr, nil + return Certificate{},nil } func NewSDK(conf Config) SDK { @@ -735,11 +680,8 @@ func (pm PageMetadata) query() (string, error) { if pm.CommonName != "" { q.Add("common_name", pm.CommonName) } - if pm.Sign { - q.Add("status", "true") - } - if pm.Status != "" { - q.Add("status", pm.Status) + if pm.TTL != "" { + q.Add("ttl", pm.TTL) } return q.Encode(), nil diff --git a/service.go b/service.go index a4ad34c..f36c3db 100644 --- a/service.go +++ b/service.go @@ -16,7 +16,6 @@ import ( "time" "github.com/absmach/certs/errors" - "github.com/absmach/certs/internal/uuid" "github.com/golang-jwt/jwt" "golang.org/x/crypto/ocsp" ) @@ -52,20 +51,16 @@ var ( type service struct { repo Repository - csrRepo CSRRepository rootCA *CA intermediateCA *CA - idProvider uuid.IDProvider } var _ Service = (*service)(nil) -func NewService(ctx context.Context, repo Repository, csrRepo CSRRepository, config *Config, idp uuid.IDProvider) (Service, error) { +func NewService(ctx context.Context, repo Repository, config *Config) (Service, error) { var svc service svc.repo = repo - svc.csrRepo = csrRepo - svc.idProvider = idp if err := svc.loadCACerts(ctx); err != nil { return &svc, err } @@ -407,25 +402,7 @@ func (s *service) GetChainCA(ctx context.Context, token string) (Certificate, er return s.getConcatCAs(ctx) } -func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, entityID string, privateKey ...*rsa.PrivateKey) (CSR, error) { - var privKey *rsa.PrivateKey - var err error - - // Check if a private key is provided else generate a new private key. - if len(privateKey) > 0 && privateKey[0] != nil { - privKey = privateKey[0] - } else { - privKey, err = rsa.GenerateKey(rand.Reader, PrivateKeyBytes) - if err != nil { - return CSR{}, errors.Wrap(ErrCreateEntity, err) - } - } - - csrID, err := s.idProvider.ID() - if err != nil { - return CSR{}, err - } - +func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, privKey *rsa.PrivateKey) (CSR, error) { template := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: metadata.CommonName, @@ -464,57 +441,34 @@ func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, entityID }) csr := CSR{ - ID: csrID, CSR: csrPEM, PrivateKey: privKeyPEM, - EntityID: entityID, - Status: Pending, - SubmittedAt: time.Now(), - } - - if err := s.csrRepo.CreateCSR(ctx, csr); err != nil { - return CSR{}, errors.Wrap(ErrCreateEntity, err) } return csr, nil } -func (s *service) SignCSR(ctx context.Context, csrID string, approve bool) error { - csr, err := s.csrRepo.RetrieveCSR(ctx, csrID) - if err != nil { - return errors.Wrap(ErrViewEntity, err) - } - - if csr.Status != Pending { - return ErrConflict - } - - if !approve { - csr.Status = Rejected - csr.SignedAt = time.Now() - return s.csrRepo.UpdateCSR(ctx, csr) - } - +func (s *service) SignCSR(ctx context.Context, entityID, ttl string, csr CSR) (Certificate, error) { block, _ := pem.Decode(csr.CSR) if block == nil { - return errors.New("failed to parse CSR PEM") + return Certificate{}, errors.New("failed to parse CSR PEM") } parsedCSR, err := x509.ParseCertificateRequest(block.Bytes) if err != nil { - return errors.Wrap(ErrMalformedEntity, err) + return Certificate{}, errors.Wrap(ErrMalformedEntity, err) } if err := parsedCSR.CheckSignature(); err != nil { - return errors.Wrap(ErrMalformedEntity, err) + return Certificate{}, errors.Wrap(ErrMalformedEntity, err) } privKey, err := extractPrivateKey(csr.PrivateKey) if err != nil { - return errors.Wrap(ErrMalformedEntity, err) + return Certificate{}, errors.Wrap(ErrMalformedEntity, err) } - cert, err := s.IssueCert(ctx, csr.EntityID, "", nil, SubjectOptions{ + cert, err := s.IssueCert(ctx, entityID, ttl, nil, SubjectOptions{ CommonName: parsedCSR.Subject.CommonName, Organization: parsedCSR.Subject.Organization, OrganizationalUnit: parsedCSR.Subject.OrganizationalUnit, @@ -525,27 +479,10 @@ func (s *service) SignCSR(ctx context.Context, csrID string, approve bool) error PostalCode: parsedCSR.Subject.PostalCode, }, privKey) if err != nil { - return errors.Wrap(ErrCreateEntity, err) - } - - csr.Status = Signed - csr.SignedAt = time.Now() - csr.SerialNumber = cert.SerialNumber - - return s.csrRepo.UpdateCSR(ctx, csr) -} - -func (s *service) ListCSRs(ctx context.Context, pm PageMetadata) (CSRPage, error) { - cp, err := s.csrRepo.ListCSRs(ctx, pm) - if err != nil { - return CSRPage{}, errors.Wrap(ErrViewEntity, err) + return Certificate{}, errors.Wrap(ErrCreateEntity, err) } - return cp, nil -} - -func (s *service) RetrieveCSR(ctx context.Context, csrID string) (CSR, error) { - return s.csrRepo.RetrieveCSR(ctx, csrID) + return cert, nil } func (s *service) getConcatCAs(ctx context.Context) (Certificate, error) { diff --git a/tracing/certs.go b/tracing/certs.go index 4faeeb5..67e7314 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -102,26 +102,14 @@ func (tm *tracingMiddleware) GetChainCA(ctx context.Context, token string) (cert return tm.svc.GetChainCA(ctx, token) } -func (tm *tracingMiddleware) CreateCSR(ctx context.Context, meta certs.CSRMetadata, entityID string, key ...*rsa.PrivateKey) (certs.CSR, error) { +func (tm *tracingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey) (certs.CSR, error) { ctx, span := tm.tracer.Start(ctx, "create_csr") defer span.End() - return tm.svc.CreateCSR(ctx, meta, entityID, key...) + return tm.svc.CreateCSR(ctx, metadata, privKey) } -func (tm *tracingMiddleware) SignCSR(ctx context.Context, csrID string, approve bool) error { +func (tm *tracingMiddleware) SignCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (certs.Certificate, error) { ctx, span := tm.tracer.Start(ctx, "sign_csr") defer span.End() - return tm.svc.SignCSR(ctx, csrID, approve) -} - -func (tm *tracingMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { - ctx, span := tm.tracer.Start(ctx, "list_csrs") - defer span.End() - return tm.svc.ListCSRs(ctx, pm) -} - -func (tm *tracingMiddleware) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) { - ctx, span := tm.tracer.Start(ctx, "retrieve_csr") - defer span.End() - return tm.svc.RetrieveCSR(ctx, csrID) + return tm.svc.SignCSR(ctx, entityID,ttl, csr) } From 290d256aed98db9d364cdfd7a1e4f3faf36f2d17 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Thu, 28 Nov 2024 19:41:46 +0300 Subject: [PATCH 02/22] failing linter Signed-off-by: nyagamunene --- api/http/errors.go | 3 +++ api/http/requests.go | 8 ++++++-- api/http/responses.go | 32 -------------------------------- api/http/transport.go | 7 +++---- sdk/sdk.go | 16 ++++++++-------- service.go | 4 ++-- tracing/certs.go | 2 +- 7 files changed, 23 insertions(+), 49 deletions(-) diff --git a/api/http/errors.go b/api/http/errors.go index 8e6a5ae..f69e03b 100644 --- a/api/http/errors.go +++ b/api/http/errors.go @@ -35,4 +35,7 @@ var ( // ErrMissingCSR indicates missing csr. ErrMissingCSR = errors.New("missing CSR") + + // ErrMissingPrivKey indicates missing csr. + ErrMissingPrivKey = errors.New("missing private key") ) diff --git a/api/http/requests.go b/api/http/requests.go index ed0f4cb..d148f37 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -100,13 +100,17 @@ func (req createCSRReq) validate() error { if req.Metadata.CommonName == "" { return errors.Wrap(certs.ErrMalformedEntity, ErrMissingCN) } + + if len(req.PrivateKey) == 0 { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingPrivKey) + } return nil } type SignCSRReq struct { entityID string - ttl string - CSR []byte `json:"csr"` + ttl string + CSR []byte `json:"csr"` } func (req SignCSRReq) validate() error { diff --git a/api/http/responses.go b/api/http/responses.go index f6b394b..cc8ce13 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -240,35 +240,3 @@ func (res signCSRRes) Headers() map[string]string { func (res signCSRRes) Empty() bool { return false } - -type listCSRsRes struct { - certs.CSRPage -} - -func (res listCSRsRes) Code() int { - return http.StatusOK -} - -func (res listCSRsRes) Headers() map[string]string { - return map[string]string{} -} - -func (res listCSRsRes) Empty() bool { - return false -} - -type retrieveCSRRes struct { - certs.CSR -} - -func (res retrieveCSRRes) Code() int { - return http.StatusOK -} - -func (res retrieveCSRRes) Headers() map[string]string { - return map[string]string{} -} - -func (res retrieveCSRRes) Empty() bool { - return false -} diff --git a/api/http/transport.go b/api/http/transport.go index 6b5e391..8a8574d 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -39,7 +39,7 @@ const ( token = "token" ocspStatusParam = "force_status" entityIDParam = "entityID" - ttl = "ttl" + ttl = "ttl" defOffset = 0 defLimit = 10 defType = 1 @@ -305,8 +305,8 @@ func decodeSignCSR(_ context.Context, r *http.Request) (interface{}, error) { } req := SignCSRReq{ - entityID: chi.URLParam(r, "entityID"), - ttl: t, + entityID: chi.URLParam(r, "entityID"), + ttl: t, } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { @@ -316,7 +316,6 @@ func decodeSignCSR(_ context.Context, r *http.Request) (interface{}, error) { return req, nil } - // EncodeResponse encodes successful response. func EncodeResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { if ar, ok := response.(Response); ok { diff --git a/sdk/sdk.go b/sdk/sdk.go index 6616b20..e1438db 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -93,7 +93,7 @@ type PageMetadata struct { IPAddresses []string `json:"ip_addresses,omitempty"` EmailAddresses []string `json:"email_addresses,omitempty"` Status string `json:"status,omitempty"` - TTL string `json:"ttl,omitempty"` + TTL string `json:"ttl,omitempty"` } type Options struct { @@ -176,8 +176,8 @@ type CSRMetadata struct { } type CSR struct { - CSR []byte `json:"csr,omitempty"` - PrivateKey []byte `json:"private_key,omitempty"` + CSR []byte `json:"csr,omitempty"` + PrivateKey []byte `json:"private_key,omitempty"` } type SDK interface { @@ -278,7 +278,7 @@ type SDK interface { // example: // certs, err := sdk.SignCSR( "entityID", "ttl", []bytes("csrFile")) // fmt.Println(err) - SignCSR(entityID, ttl string, csr []byte) (Certificate,errors.SDKError) + SignCSR(entityID, ttl string, csr []byte) (Certificate, errors.SDKError) } func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (Certificate, errors.SDKError) { @@ -578,20 +578,20 @@ func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKErro return csr, nil } -func (sdk mgSDK) SignCSR(entityID, ttl string, csr []byte) (Certificate,errors.SDKError) { +func (sdk mgSDK) SignCSR(entityID, ttl string, csr []byte) (Certificate, errors.SDKError) { pm := PageMetadata{ TTL: ttl, } url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s/%s", certsEndpoint, csrEndpoint, entityID), pm) if err != nil { - return Certificate{},errors.NewSDKError(err) + return Certificate{}, errors.NewSDKError(err) } _, _, sdkerr := sdk.processRequest(http.MethodPatch, url, nil, nil, http.StatusOK) if sdkerr != nil { - return Certificate{},sdkerr + return Certificate{}, sdkerr } - return Certificate{},nil + return Certificate{}, nil } func NewSDK(conf Config) SDK { diff --git a/service.go b/service.go index f36c3db..81ea9d8 100644 --- a/service.go +++ b/service.go @@ -441,8 +441,8 @@ func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, privKey * }) csr := CSR{ - CSR: csrPEM, - PrivateKey: privKeyPEM, + CSR: csrPEM, + PrivateKey: privKeyPEM, } return csr, nil diff --git a/tracing/certs.go b/tracing/certs.go index 67e7314..d1fc2d9 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -111,5 +111,5 @@ func (tm *tracingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMe func (tm *tracingMiddleware) SignCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (certs.Certificate, error) { ctx, span := tm.tracer.Start(ctx, "sign_csr") defer span.End() - return tm.svc.SignCSR(ctx, entityID,ttl, csr) + return tm.svc.SignCSR(ctx, entityID, ttl, csr) } From 459ffb397d562c12b5035e29b15268b956abad56 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Fri, 29 Nov 2024 03:33:16 +0300 Subject: [PATCH 03/22] handle multiple private key types Signed-off-by: nyagamunene --- api/http/endpoint.go | 15 ++++++++++----- api/http/requests.go | 6 ++---- api/http/responses.go | 14 +++++++++----- api/http/transport.go | 10 +++++++--- api/logging.go | 2 +- api/metrics.go | 2 +- certs.go | 2 +- cli/certs.go | 5 +++-- sdk/sdk.go | 41 ++++++++++++++++++++++++++--------------- service.go | 28 +++++++++++++++++++++++++--- tracing/certs.go | 2 +- 11 files changed, 86 insertions(+), 41 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index 43fa6ac..aa5ddfa 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -8,6 +8,7 @@ import ( "crypto" "crypto/x509" "encoding/pem" + "fmt" "math/rand" "strings" "time" @@ -316,15 +317,15 @@ func createCSREndpoint(svc certs.Service) endpoint.Endpoint { if err := req.validate(); err != nil { return createCSRRes{created: false}, err } - csr, err := svc.CreateCSR(ctx, req.Metadata, req.privKey) if err != nil { return createCSRRes{created: false}, err } return createCSRRes{ - created: true, - CSR: csr, + created: true, + CSR: csr.CSR, + PrivateKey: csr.PrivateKey, }, nil } } @@ -342,8 +343,12 @@ func signCSREndpoint(svc certs.Service) endpoint.Endpoint { } return signCSRRes{ - crt: cert, - signed: true, + SerialNumber: cert.SerialNumber, + Certificate: string(cert.Certificate), + Revoked: cert.Revoked, + ExpiryTime: cert.ExpiryTime, + EntityID: cert.EntityID, + signed: true, }, nil } } diff --git a/api/http/requests.go b/api/http/requests.go index d148f37..4b5a0f9 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -4,8 +4,6 @@ package http import ( - "crypto/rsa" - "github.com/absmach/certs" "github.com/absmach/certs/errors" "golang.org/x/crypto/ocsp" @@ -92,8 +90,8 @@ func (req ocspReq) validate() error { type createCSRReq struct { Metadata certs.CSRMetadata `json:"metadata"` - PrivateKey []byte `json:"private_Key"` - privKey *rsa.PrivateKey + PrivateKey string `json:"private_key"` + privKey any } func (req createCSRReq) validate() error { diff --git a/api/http/responses.go b/api/http/responses.go index cc8ce13..471ff1b 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -9,7 +9,6 @@ import ( "net/http" "time" - "github.com/absmach/certs" "golang.org/x/crypto/ocsp" ) @@ -204,8 +203,9 @@ type fileDownloadRes struct { } type createCSRRes struct { - certs.CSR - created bool + CSR []byte `json:"csr"` + PrivateKey []byte `json:"private_key"` + created bool } func (res createCSRRes) Code() int { @@ -225,8 +225,12 @@ func (res createCSRRes) Empty() bool { } type signCSRRes struct { - crt certs.Certificate - signed bool + SerialNumber string `json:"serial_number"` + Certificate string `json:"certificate,omitempty"` + Revoked bool `json:"revoked"` + ExpiryTime time.Time `json:"expiry_time"` + EntityID string `json:"entity_id"` + signed bool } func (res signCSRRes) Code() int { diff --git a/api/http/transport.go b/api/http/transport.go index 8a8574d..036d155 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -143,7 +143,7 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http opts..., ), "download_ca").ServeHTTP) r.Route("/csrs", func(r chi.Router) { - r.Post("/create", otelhttp.NewHandler(kithttp.NewServer( + r.Post("/", otelhttp.NewHandler(kithttp.NewServer( createCSREndpoint(svc), decodeCreateCSR, EncodeResponse, @@ -286,12 +286,16 @@ func decodeCreateCSR(_ context.Context, r *http.Request) (interface{}, error) { return nil, err } - block, _ := pem.Decode(req.PrivateKey) + block, _ := pem.Decode([]byte(req.PrivateKey)) if block != nil { - privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + + privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) + fmt.Println(err) + if err != nil { return nil, errors.Wrap(ErrInvalidRequest, err) } + req.privKey = privateKey } diff --git a/api/logging.go b/api/logging.go index dcc6f9d..c249e4d 100644 --- a/api/logging.go +++ b/api/logging.go @@ -182,7 +182,7 @@ func (lm *loggingMiddleware) GetChainCA(ctx context.Context, token string) (cert return lm.svc.GetChainCA(ctx, token) } -func (lm *loggingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey) (csr certs.CSR, err error) { +func (lm *loggingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey any) (csr certs.CSR, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method create_csr took %s to complete", time.Since(begin)) if err != nil { diff --git a/api/metrics.go b/api/metrics.go index f627b51..fcbb650 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -137,7 +137,7 @@ func (mm *metricsMiddleware) GetChainCA(ctx context.Context, token string) (cert return mm.svc.GetChainCA(ctx, token) } -func (mm *metricsMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey) (certs.CSR, error) { +func (mm *metricsMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey any) (certs.CSR, error) { defer func(begin time.Time) { mm.counter.With("method", "create_csr").Add(1) mm.latency.With("method", "create_csr").Observe(time.Since(begin).Seconds()) diff --git a/certs.go b/certs.go index 71f0550..1208ae6 100644 --- a/certs.go +++ b/certs.go @@ -176,7 +176,7 @@ type Service interface { RemoveCert(ctx context.Context, entityId string) error // CreateCSR creates a new Certificate Signing Request - CreateCSR(ctx context.Context, metadata CSRMetadata, privKey *rsa.PrivateKey) (CSR, error) + CreateCSR(ctx context.Context, metadata CSRMetadata, privKey any) (CSR, error) // SignCSR parses and signs a CSR SignCSR(ctx context.Context, entityID, ttl string, csr CSR) (Certificate, error) diff --git a/cli/certs.go b/cli/certs.go index 4a56b6b..cce9234 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -7,6 +7,7 @@ import ( "encoding/json" "os" + "github.com/absmach/certs/errors" ctxsdk "github.com/absmach/certs/sdk" "github.com/spf13/cobra" ) @@ -243,14 +244,14 @@ var cmdCerts = []cobra.Command{ Short: "Create CSR", Long: `Creates a CSR.`, Run: func(cmd *cobra.Command, args []string) { - if len(args) != 0 { + if len(args) != 2 { logUsageCmd(*cmd, cmd.Use) return } var pm ctxsdk.PageMetadata if err := json.Unmarshal([]byte(args[0]), &pm); err != nil { - logErrorCmd(*cmd, err) + logErrorCmd(*cmd, errors.New("here 1")) return } diff --git a/sdk/sdk.go b/sdk/sdk.go index e1438db..23088e9 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -549,17 +549,19 @@ func (sdk mgSDK) GetCAToken() (Token, errors.SDKError) { func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKError) { r := csrReq{ - Organization: pm.Organization, - OrganizationalUnit: pm.OrganizationalUnit, - Country: pm.Country, - Province: pm.Province, - Locality: pm.Locality, - StreetAddress: pm.StreetAddress, - PostalCode: pm.PostalCode, - DNSNames: pm.DNSNames, - IPAddresses: pm.IPAddresses, - EmailAddresses: pm.EmailAddresses, - PrivateKey: privKey, + Metadata: meta{ + Organization: pm.Organization, + OrganizationalUnit: pm.OrganizationalUnit, + Country: pm.Country, + Province: pm.Province, + Locality: pm.Locality, + StreetAddress: pm.StreetAddress, + PostalCode: pm.PostalCode, + DNSNames: pm.DNSNames, + IPAddresses: pm.IPAddresses, + EmailAddresses: pm.EmailAddresses, + }, + PrivateKey: privKey, } d, err := json.Marshal(r) if err != nil { @@ -568,7 +570,7 @@ func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKErro url := fmt.Sprintf("%s/%s/%s", sdk.certsURL, certsEndpoint, csrEndpoint) _, body, sdkerr := sdk.processRequest(http.MethodPost, url, d, nil, http.StatusOK) if sdkerr != nil { - return CSR{}, sdkerr + return CSR{}, errors.NewSDKError(err) } var csr CSR @@ -587,11 +589,16 @@ func (sdk mgSDK) SignCSR(entityID, ttl string, csr []byte) (Certificate, errors. return Certificate{}, errors.NewSDKError(err) } - _, _, sdkerr := sdk.processRequest(http.MethodPatch, url, nil, nil, http.StatusOK) + _, body, sdkerr := sdk.processRequest(http.MethodPost, url, nil, nil, http.StatusOK) if sdkerr != nil { return Certificate{}, sdkerr } - return Certificate{}, nil + + var cert Certificate + if err := json.Unmarshal(body, &cert); err != nil { + return Certificate{}, errors.NewSDKError(err) + } + return cert, nil } func NewSDK(conf Config) SDK { @@ -703,6 +710,11 @@ type certReq struct { } type csrReq struct { + Metadata meta `json:"metadata"` + PrivateKey []byte `json:"private_key"` +} + +type meta struct { Organization []string `json:"organization"` OrganizationalUnit []string `json:"organizational_unit"` Country []string `json:"country"` @@ -713,5 +725,4 @@ type csrReq struct { DNSNames []string `json:"dns_names"` IPAddresses []string `json:"ip_addresses"` EmailAddresses []string `json:"email_addresses"` - PrivateKey []byte `json:"private_key"` } diff --git a/service.go b/service.go index 81ea9d8..28c97eb 100644 --- a/service.go +++ b/service.go @@ -5,6 +5,8 @@ package certs import ( "context" + "crypto/ecdsa" + "crypto/ed25519" "crypto/rand" "crypto/rsa" "crypto/x509" @@ -402,7 +404,7 @@ func (s *service) GetChainCA(ctx context.Context, token string) (Certificate, er return s.getConcatCAs(ctx) } -func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, privKey *rsa.PrivateKey) (CSR, error) { +func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, privKey any) (CSR, error) { template := &x509.CertificateRequest{ Subject: pkix.Name{ CommonName: metadata.CommonName, @@ -435,9 +437,29 @@ func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, privKey * Bytes: csrBytes, }) + var privKeyBytes []byte + var privKeyType string + switch key := privKey.(type) { + case *rsa.PrivateKey: + privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) + privKeyType = "RSA PRIVATE KEY" + case *ecdsa.PrivateKey: + privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) + privKeyType = "EC PRIVATE KEY" + case ed25519.PrivateKey: + privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) + privKeyType = "PRIVATE KEY" + default: + return CSR{}, errors.Wrap(ErrCreateEntity, errors.New("unsupported private key type")) + } + + if err != nil { + return CSR{}, errors.Wrap(ErrCreateEntity, err) + } + privKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(privKey), + Type: privKeyType, + Bytes: privKeyBytes, }) csr := CSR{ diff --git a/tracing/certs.go b/tracing/certs.go index d1fc2d9..4a720db 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -102,7 +102,7 @@ func (tm *tracingMiddleware) GetChainCA(ctx context.Context, token string) (cert return tm.svc.GetChainCA(ctx, token) } -func (tm *tracingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey) (certs.CSR, error) { +func (tm *tracingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey any) (certs.CSR, error) { ctx, span := tm.tracer.Start(ctx, "create_csr") defer span.End() return tm.svc.CreateCSR(ctx, metadata, privKey) From 67e51e211a1e51e741c1bf1a5a9c5c875f124186 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Fri, 29 Nov 2024 03:41:00 +0300 Subject: [PATCH 04/22] fix failing linter Signed-off-by: nyagamunene --- api/http/endpoint.go | 1 - mocks/service.go | 16 ++++++++-------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index aa5ddfa..c3017d5 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -8,7 +8,6 @@ import ( "crypto" "crypto/x509" "encoding/pem" - "fmt" "math/rand" "strings" "time" diff --git a/mocks/service.go b/mocks/service.go index 978ab36..84daa16 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -31,7 +31,7 @@ func (_m *MockService) EXPECT() *MockService_Expecter { } // CreateCSR provides a mock function with given fields: ctx, metadata, privKey -func (_m *MockService) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey) (certs.CSR, error) { +func (_m *MockService) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey interface{}) (certs.CSR, error) { ret := _m.Called(ctx, metadata, privKey) if len(ret) == 0 { @@ -40,16 +40,16 @@ func (_m *MockService) CreateCSR(ctx context.Context, metadata certs.CSRMetadata var r0 certs.CSR var r1 error - if rf, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, *rsa.PrivateKey) (certs.CSR, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, interface{}) (certs.CSR, error)); ok { return rf(ctx, metadata, privKey) } - if rf, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, *rsa.PrivateKey) certs.CSR); ok { + if rf, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, interface{}) certs.CSR); ok { r0 = rf(ctx, metadata, privKey) } else { r0 = ret.Get(0).(certs.CSR) } - if rf, ok := ret.Get(1).(func(context.Context, certs.CSRMetadata, *rsa.PrivateKey) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, certs.CSRMetadata, interface{}) error); ok { r1 = rf(ctx, metadata, privKey) } else { r1 = ret.Error(1) @@ -66,14 +66,14 @@ type MockService_CreateCSR_Call struct { // CreateCSR is a helper method to define mock.On call // - ctx context.Context // - metadata certs.CSRMetadata -// - privKey *rsa.PrivateKey +// - privKey interface{} func (_e *MockService_Expecter) CreateCSR(ctx interface{}, metadata interface{}, privKey interface{}) *MockService_CreateCSR_Call { return &MockService_CreateCSR_Call{Call: _e.mock.On("CreateCSR", ctx, metadata, privKey)} } -func (_c *MockService_CreateCSR_Call) Run(run func(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey)) *MockService_CreateCSR_Call { +func (_c *MockService_CreateCSR_Call) Run(run func(ctx context.Context, metadata certs.CSRMetadata, privKey interface{})) *MockService_CreateCSR_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(certs.CSRMetadata), args[2].(*rsa.PrivateKey)) + run(args[0].(context.Context), args[1].(certs.CSRMetadata), args[2].(interface{})) }) return _c } @@ -83,7 +83,7 @@ func (_c *MockService_CreateCSR_Call) Return(_a0 certs.CSR, _a1 error) *MockServ return _c } -func (_c *MockService_CreateCSR_Call) RunAndReturn(run func(context.Context, certs.CSRMetadata, *rsa.PrivateKey) (certs.CSR, error)) *MockService_CreateCSR_Call { +func (_c *MockService_CreateCSR_Call) RunAndReturn(run func(context.Context, certs.CSRMetadata, interface{}) (certs.CSR, error)) *MockService_CreateCSR_Call { _c.Call.Return(run) return _c } From cd6f2fd2c22441fc737d517c3cad8944e3346683 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Fri, 29 Nov 2024 03:55:13 +0300 Subject: [PATCH 05/22] fix failing linter Signed-off-by: nyagamunene --- api/http/transport.go | 1 - 1 file changed, 1 deletion(-) diff --git a/api/http/transport.go b/api/http/transport.go index 036d155..a6e91dd 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -288,7 +288,6 @@ func decodeCreateCSR(_ context.Context, r *http.Request) (interface{}, error) { block, _ := pem.Decode([]byte(req.PrivateKey)) if block != nil { - privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) fmt.Println(err) From c795d21af4f83095e869dc4b79f908c15073a035 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Fri, 29 Nov 2024 10:47:29 +0300 Subject: [PATCH 06/22] fix parsing in transport layer Signed-off-by: nyagamunene --- api/http/transport.go | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/api/http/transport.go b/api/http/transport.go index a6e91dd..b7f837d 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -287,16 +287,32 @@ func decodeCreateCSR(_ context.Context, r *http.Request) (interface{}, error) { } block, _ := pem.Decode([]byte(req.PrivateKey)) - if block != nil { - privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes) - fmt.Println(err) + if block == nil { + return nil, errors.Wrap(ErrInvalidRequest, errors.New("invalid PEM format")) + } - if err != nil { - return nil, errors.Wrap(ErrInvalidRequest, err) - } + var ( + privateKey any + err error + ) - req.privKey = privateKey + switch block.Type { + case "RSA PRIVATE KEY": + privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) + case "EC PRIVATE KEY": + privateKey, err = x509.ParseECPrivateKey(block.Bytes) + case "PRIVATE KEY", "PKCS8 PRIVATE KEY": + privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) + case "ED25519 PRIVATE KEY": + privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) + default: + err = errors.New("unsupported private key type") + } + + if err != nil { + return nil, errors.Wrap(ErrInvalidRequest, err) } + req.privKey = privateKey return req, nil } From 5374de7e32e5d14f491d3759246f5ba2e831e988 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 2 Dec 2024 00:54:19 +0300 Subject: [PATCH 07/22] Remove logging Signed-off-by: nyagamunene --- api/http/endpoint.go | 9 ++++---- api/http/responses.go | 10 ++------- cli/certs.go | 9 ++++---- cli/utils.go | 16 ++++++++++++++ sdk/mocks/sdk.go | 32 ++++++++++++++-------------- sdk/sdk.go | 49 +++++++++++++++++++++++++++++++------------ service.go | 1 - 7 files changed, 78 insertions(+), 48 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index c3017d5..6554d97 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -314,17 +314,16 @@ func createCSREndpoint(svc certs.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { req := request.(createCSRReq) if err := req.validate(); err != nil { - return createCSRRes{created: false}, err + return createCSRRes{}, err } + csr, err := svc.CreateCSR(ctx, req.Metadata, req.privKey) if err != nil { - return createCSRRes{created: false}, err + return createCSRRes{}, err } return createCSRRes{ - created: true, - CSR: csr.CSR, - PrivateKey: csr.PrivateKey, + CSR: string(csr.CSR), }, nil } } diff --git a/api/http/responses.go b/api/http/responses.go index 471ff1b..2be1618 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -203,17 +203,11 @@ type fileDownloadRes struct { } type createCSRRes struct { - CSR []byte `json:"csr"` - PrivateKey []byte `json:"private_key"` - created bool + CSR string `json:"csr"` } func (res createCSRRes) Code() int { - if res.created { - return http.StatusCreated - } - - return http.StatusNoContent + return http.StatusCreated } func (res createCSRRes) Headers() map[string]string { diff --git a/cli/certs.go b/cli/certs.go index cce9234..0f18d92 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -7,7 +7,6 @@ import ( "encoding/json" "os" - "github.com/absmach/certs/errors" ctxsdk "github.com/absmach/certs/sdk" "github.com/spf13/cobra" ) @@ -251,7 +250,7 @@ var cmdCerts = []cobra.Command{ var pm ctxsdk.PageMetadata if err := json.Unmarshal([]byte(args[0]), &pm); err != nil { - logErrorCmd(*cmd, errors.New("here 1")) + logErrorCmd(*cmd, err) return } @@ -261,13 +260,13 @@ var cmdCerts = []cobra.Command{ return } - csr, err := sdk.CreateCSR(pm, data) + csr, err := sdk.CreateCSR(pm, string(data)) if err != nil { logErrorCmd(*cmd, err) return } - logJSONCmd(*cmd, csr) + logSaveCSRFiles(*cmd, csr) }, }, { @@ -286,7 +285,7 @@ var cmdCerts = []cobra.Command{ return } - cert, err := sdk.SignCSR(args[0], args[1], data) + cert, err := sdk.SignCSR(args[0], args[1], string(data)) if err != nil { logErrorCmd(*cmd, err) return diff --git a/cli/utils.go b/cli/utils.go index d5e413e..58e8b14 100644 --- a/cli/utils.go +++ b/cli/utils.go @@ -98,6 +98,22 @@ func logSaveCAFiles(cmd cobra.Command, certBundle ctxsdk.CertificateBundle) { fmt.Fprintf(cmd.OutOrStdout(), "\nAll certificate files have been saved successfully.\n") } +func logSaveCSRFiles(cmd cobra.Command, csr ctxsdk.CSR) { + files := map[string][]byte{ + "file.csr": []byte(csr.CSR), + } + + for filename, content := range files { + err := saveToFile(filename, content) + if err != nil { + logErrorCmd(cmd, err) + return + } + fmt.Fprintf(cmd.OutOrStdout(), "Saved %s\n", filename) + } + fmt.Fprintf(cmd.OutOrStdout(), "\nCSR file have been saved successfully.\n") +} + func saveToFile(filename string, content []byte) error { cwd, err := os.Getwd() if err != nil { diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index 8384bc5..ce5cd68 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -26,7 +26,7 @@ func (_m *MockSDK) EXPECT() *MockSDK_Expecter { } // CreateCSR provides a mock function with given fields: pm, privKey -func (_m *MockSDK) CreateCSR(pm sdk.PageMetadata, privKey []byte) (sdk.CSR, errors.SDKError) { +func (_m *MockSDK) CreateCSR(pm sdk.PageMetadata, privKey string) (sdk.CSR, errors.SDKError) { ret := _m.Called(pm, privKey) if len(ret) == 0 { @@ -35,16 +35,16 @@ func (_m *MockSDK) CreateCSR(pm sdk.PageMetadata, privKey []byte) (sdk.CSR, erro var r0 sdk.CSR var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(sdk.PageMetadata, []byte) (sdk.CSR, errors.SDKError)); ok { + if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) (sdk.CSR, errors.SDKError)); ok { return rf(pm, privKey) } - if rf, ok := ret.Get(0).(func(sdk.PageMetadata, []byte) sdk.CSR); ok { + if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) sdk.CSR); ok { r0 = rf(pm, privKey) } else { r0 = ret.Get(0).(sdk.CSR) } - if rf, ok := ret.Get(1).(func(sdk.PageMetadata, []byte) errors.SDKError); ok { + if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string) errors.SDKError); ok { r1 = rf(pm, privKey) } else { if ret.Get(1) != nil { @@ -62,14 +62,14 @@ type MockSDK_CreateCSR_Call struct { // CreateCSR is a helper method to define mock.On call // - pm sdk.PageMetadata -// - privKey []byte +// - privKey string func (_e *MockSDK_Expecter) CreateCSR(pm interface{}, privKey interface{}) *MockSDK_CreateCSR_Call { return &MockSDK_CreateCSR_Call{Call: _e.mock.On("CreateCSR", pm, privKey)} } -func (_c *MockSDK_CreateCSR_Call) Run(run func(pm sdk.PageMetadata, privKey []byte)) *MockSDK_CreateCSR_Call { +func (_c *MockSDK_CreateCSR_Call) Run(run func(pm sdk.PageMetadata, privKey string)) *MockSDK_CreateCSR_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.PageMetadata), args[1].([]byte)) + run(args[0].(sdk.PageMetadata), args[1].(string)) }) return _c } @@ -79,7 +79,7 @@ func (_c *MockSDK_CreateCSR_Call) Return(_a0 sdk.CSR, _a1 errors.SDKError) *Mock return _c } -func (_c *MockSDK_CreateCSR_Call) RunAndReturn(run func(sdk.PageMetadata, []byte) (sdk.CSR, errors.SDKError)) *MockSDK_CreateCSR_Call { +func (_c *MockSDK_CreateCSR_Call) RunAndReturn(run func(sdk.PageMetadata, string) (sdk.CSR, errors.SDKError)) *MockSDK_CreateCSR_Call { _c.Call.Return(run) return _c } @@ -639,7 +639,7 @@ func (_c *MockSDK_RevokeCert_Call) RunAndReturn(run func(string) errors.SDKError } // SignCSR provides a mock function with given fields: entityID, ttl, csr -func (_m *MockSDK) SignCSR(entityID string, ttl string, csr []byte) (sdk.Certificate, errors.SDKError) { +func (_m *MockSDK) SignCSR(entityID string, ttl string, csr string) (sdk.Certificate, errors.SDKError) { ret := _m.Called(entityID, ttl, csr) if len(ret) == 0 { @@ -648,16 +648,16 @@ func (_m *MockSDK) SignCSR(entityID string, ttl string, csr []byte) (sdk.Certifi var r0 sdk.Certificate var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(string, string, []byte) (sdk.Certificate, errors.SDKError)); ok { + if rf, ok := ret.Get(0).(func(string, string, string) (sdk.Certificate, errors.SDKError)); ok { return rf(entityID, ttl, csr) } - if rf, ok := ret.Get(0).(func(string, string, []byte) sdk.Certificate); ok { + if rf, ok := ret.Get(0).(func(string, string, string) sdk.Certificate); ok { r0 = rf(entityID, ttl, csr) } else { r0 = ret.Get(0).(sdk.Certificate) } - if rf, ok := ret.Get(1).(func(string, string, []byte) errors.SDKError); ok { + if rf, ok := ret.Get(1).(func(string, string, string) errors.SDKError); ok { r1 = rf(entityID, ttl, csr) } else { if ret.Get(1) != nil { @@ -676,14 +676,14 @@ type MockSDK_SignCSR_Call struct { // SignCSR is a helper method to define mock.On call // - entityID string // - ttl string -// - csr []byte +// - csr string func (_e *MockSDK_Expecter) SignCSR(entityID interface{}, ttl interface{}, csr interface{}) *MockSDK_SignCSR_Call { return &MockSDK_SignCSR_Call{Call: _e.mock.On("SignCSR", entityID, ttl, csr)} } -func (_c *MockSDK_SignCSR_Call) Run(run func(entityID string, ttl string, csr []byte)) *MockSDK_SignCSR_Call { +func (_c *MockSDK_SignCSR_Call) Run(run func(entityID string, ttl string, csr string)) *MockSDK_SignCSR_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].([]byte)) + run(args[0].(string), args[1].(string), args[2].(string)) }) return _c } @@ -693,7 +693,7 @@ func (_c *MockSDK_SignCSR_Call) Return(_a0 sdk.Certificate, _a1 errors.SDKError) return _c } -func (_c *MockSDK_SignCSR_Call) RunAndReturn(run func(string, string, []byte) (sdk.Certificate, errors.SDKError)) *MockSDK_SignCSR_Call { +func (_c *MockSDK_SignCSR_Call) RunAndReturn(run func(string, string, string) (sdk.Certificate, errors.SDKError)) *MockSDK_SignCSR_Call { _c.Call.Return(run) return _c } diff --git a/sdk/sdk.go b/sdk/sdk.go index 23088e9..93e1d1d 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -176,8 +176,7 @@ type CSRMetadata struct { } type CSR struct { - CSR []byte `json:"csr,omitempty"` - PrivateKey []byte `json:"private_key,omitempty"` + CSR []byte `json:"csr,omitempty"` } type SDK interface { @@ -268,17 +267,17 @@ type SDK interface { // CreateCSR creates a new Certificate Signing Request // // example: - // pm = sdk.CSRMetadata{CommonName: "common_name", EntityID: "entity_id" } - // response, _ := sdk.CreateCSR(pm, []bytes("privKey")) + // pm = sdk.CSRMetadata{CommonName: "common_name"} + // response, _ := sdk.CreateCSR(pm, "privKey") // fmt.Println(response) - CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKError) + CreateCSR(pm PageMetadata, privKey string) (CSR, errors.SDKError) // SignCSR processes a pending CSR and either signs or rejects it // // example: - // certs, err := sdk.SignCSR( "entityID", "ttl", []bytes("csrFile")) + // certs, err := sdk.SignCSR( "entityID", "ttl", "csrFile") // fmt.Println(err) - SignCSR(entityID, ttl string, csr []byte) (Certificate, errors.SDKError) + SignCSR(entityID, ttl string, csr string) (Certificate, errors.SDKError) } func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (Certificate, errors.SDKError) { @@ -547,9 +546,10 @@ func (sdk mgSDK) GetCAToken() (Token, errors.SDKError) { return tk, nil } -func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKError) { +func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey string) (CSR, errors.SDKError) { r := csrReq{ Metadata: meta{ + CommonName: pm.CommonName, Organization: pm.Organization, OrganizationalUnit: pm.OrganizationalUnit, Country: pm.Country, @@ -573,23 +573,44 @@ func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKErro return CSR{}, errors.NewSDKError(err) } - var csr CSR - if err := json.Unmarshal(body, &csr); err != nil { + zipReader, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) + if err != nil { return CSR{}, errors.NewSDKError(err) } + + var csr CSR + for _, file := range zipReader.File { + fileContent, err := readZipFile(file) + if err != nil { + return CSR{}, errors.NewSDKError(err) + } + + csr.CSR = fileContent + } + return csr, nil } -func (sdk mgSDK) SignCSR(entityID, ttl string, csr []byte) (Certificate, errors.SDKError) { +func (sdk mgSDK) SignCSR(entityID, ttl string, csr string) (Certificate, errors.SDKError) { pm := PageMetadata{ TTL: ttl, } + + r := csrReq{ + CSR: csr, + } + + d, err := json.Marshal(r) + if err != nil { + return Certificate{}, errors.NewSDKError(err) + } + url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s/%s", certsEndpoint, csrEndpoint, entityID), pm) if err != nil { return Certificate{}, errors.NewSDKError(err) } - _, body, sdkerr := sdk.processRequest(http.MethodPost, url, nil, nil, http.StatusOK) + _, body, sdkerr := sdk.processRequest(http.MethodPost, url, d, nil, http.StatusOK) if sdkerr != nil { return Certificate{}, sdkerr } @@ -711,10 +732,12 @@ type certReq struct { type csrReq struct { Metadata meta `json:"metadata"` - PrivateKey []byte `json:"private_key"` + PrivateKey string `json:"private_key"` + CSR string `json:"csr"` } type meta struct { + CommonName string `json:"common_name"` Organization []string `json:"organization"` OrganizationalUnit []string `json:"organizational_unit"` Country []string `json:"country"` diff --git a/service.go b/service.go index 28c97eb..6eaa1dc 100644 --- a/service.go +++ b/service.go @@ -426,7 +426,6 @@ func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, privKey a template.IPAddresses = append(template.IPAddresses, parsedIP) } } - csrBytes, err := x509.CreateCertificateRequest(rand.Reader, template, privKey) if err != nil { return CSR{}, errors.Wrap(ErrCreateEntity, err) From 8588a8fefca7433dd5d5962d1e43d20f5f3336d5 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 2 Dec 2024 02:16:35 +0300 Subject: [PATCH 08/22] add multiple key type support Signed-off-by: nyagamunene --- api/http/endpoint.go | 2 +- api/http/requests.go | 7 ++++--- api/http/transport.go | 10 ++++++++-- cli/certs.go | 16 +++++++++++----- sdk/sdk.go | 15 ++++++++------- service.go | 24 ++++++++++++++++++++---- 6 files changed, 52 insertions(+), 22 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index 6554d97..5743c7b 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -335,7 +335,7 @@ func signCSREndpoint(svc certs.Service) endpoint.Endpoint { return signCSRRes{signed: false}, err } - cert, err := svc.SignCSR(ctx, req.entityID, req.ttl, certs.CSR{CSR: req.CSR}) + cert, err := svc.SignCSR(ctx, req.entityID, req.ttl, certs.CSR{CSR: []byte(req.CSR), PrivateKey: []byte(req.PrivateKey)}) if err != nil { return signCSRRes{signed: false}, err } diff --git a/api/http/requests.go b/api/http/requests.go index 4b5a0f9..6119aae 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -106,9 +106,10 @@ func (req createCSRReq) validate() error { } type SignCSRReq struct { - entityID string - ttl string - CSR []byte `json:"csr"` + entityID string + ttl string + CSR string `json:"csr"` + PrivateKey string `json:"private_key"` } func (req SignCSRReq) validate() error { diff --git a/api/http/transport.go b/api/http/transport.go index b7f837d..72df1af 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -328,8 +328,14 @@ func decodeSignCSR(_ context.Context, r *http.Request) (interface{}, error) { ttl: t, } - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, err + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, errors.Wrap(ErrInvalidRequest, errors.New("failed to read request body")) + } + defer r.Body.Close() + + if err := json.Unmarshal(body, &req); err != nil { + return nil, errors.Wrap(ErrInvalidRequest, errors.New("failed to decode JSON")) } return req, nil diff --git a/cli/certs.go b/cli/certs.go index 0f18d92..ca97a97 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -270,22 +270,28 @@ var cmdCerts = []cobra.Command{ }, }, { - Use: "sign ", + Use: "sign ", Short: "Sign CSR", - Long: `Signs a CSR for a given csr id.`, + Long: `Signs a CSR for a given csr.`, Run: func(cmd *cobra.Command, args []string) { - if len(args) != 3 { + if len(args) != 4 { logUsageCmd(*cmd, cmd.Use) return } - data, err := os.ReadFile(args[2]) + csrData, err := os.ReadFile(args[2]) if err != nil { logErrorCmd(*cmd, err) return } - cert, err := sdk.SignCSR(args[0], args[1], string(data)) + privData, err := os.ReadFile(args[3]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + cert, err := sdk.SignCSR(args[0], args[1], string(csrData), string(privData)) if err != nil { logErrorCmd(*cmd, err) return diff --git a/sdk/sdk.go b/sdk/sdk.go index 93e1d1d..82e66f8 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -275,9 +275,9 @@ type SDK interface { // SignCSR processes a pending CSR and either signs or rejects it // // example: - // certs, err := sdk.SignCSR( "entityID", "ttl", "csrFile") + // certs, err := sdk.SignCSR( "entityID", "ttl", "csrFile", "privKey") // fmt.Println(err) - SignCSR(entityID, ttl string, csr string) (Certificate, errors.SDKError) + SignCSR(entityID, ttl string, csr, privKey string) (Certificate, errors.SDKError) } func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (Certificate, errors.SDKError) { @@ -591,13 +591,14 @@ func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey string) (CSR, errors.SDKErro return csr, nil } -func (sdk mgSDK) SignCSR(entityID, ttl string, csr string) (Certificate, errors.SDKError) { +func (sdk mgSDK) SignCSR(entityID, ttl string, csr, privKey string) (Certificate, errors.SDKError) { pm := PageMetadata{ TTL: ttl, } r := csrReq{ - CSR: csr, + CSR: csr, + PrivateKey: privKey, } d, err := json.Marshal(r) @@ -731,9 +732,9 @@ type certReq struct { } type csrReq struct { - Metadata meta `json:"metadata"` - PrivateKey string `json:"private_key"` - CSR string `json:"csr"` + Metadata meta `json:"metadata,omitempty"` + PrivateKey string `json:"private_key,omitempty"` + CSR string `json:"csr,omitempty"` } type meta struct { diff --git a/service.go b/service.go index 6eaa1dc..88de439 100644 --- a/service.go +++ b/service.go @@ -829,16 +829,32 @@ func (s *service) loadCACerts(ctx context.Context) error { return nil } -func extractPrivateKey(pemKey []byte) (*rsa.PrivateKey, error) { +func extractPrivateKey(pemKey []byte) (any, error) { block, _ := pem.Decode(pemKey) if block == nil { return nil, errors.New("failed to parse private key PEM") } - privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + var ( + privateKey any + err error + ) + + switch block.Type { + case "RSA PRIVATE KEY": + privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) + case "EC PRIVATE KEY": + privateKey, err = x509.ParseECPrivateKey(block.Bytes) + case "PRIVATE KEY", "PKCS8 PRIVATE KEY": + privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) + case "ED25519 PRIVATE KEY": + privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) + default: + err = errors.New("unsupported private key type") + } if err != nil { - return nil, err + return nil, errors.New("failed to parse key") } - return privKey, nil + return privateKey, nil } From a3fcde7bb602930cfdfde2b5dcb9aacf5d847ad3 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 2 Dec 2024 11:20:31 +0300 Subject: [PATCH 09/22] update issue cert method Signed-off-by: nyagamunene --- api/logging.go | 3 +-- api/metrics.go | 3 +-- certs.go | 2 +- mocks/service.go | 26 ++++++++++---------------- sdk/mocks/sdk.go | 29 +++++++++++++++-------------- service.go | 38 ++++++++++++++++++++++++++++++++------ tracing/certs.go | 3 +-- 7 files changed, 61 insertions(+), 43 deletions(-) diff --git a/api/logging.go b/api/logging.go index c249e4d..4e797ea 100644 --- a/api/logging.go +++ b/api/logging.go @@ -5,7 +5,6 @@ package api import ( "context" - "crypto/rsa" "crypto/x509" "fmt" "log/slog" @@ -86,7 +85,7 @@ func (lm *loggingMiddleware) RetrieveCAToken(ctx context.Context) (tokenString s return lm.svc.RetrieveCAToken(ctx) } -func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...*rsa.PrivateKey) (cert certs.Certificate, err error) { +func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...any) (cert certs.Certificate, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin)) if err != nil { diff --git a/api/metrics.go b/api/metrics.go index fcbb650..d9db09c 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -5,7 +5,6 @@ package api import ( "context" - "crypto/rsa" "crypto/x509" "time" @@ -72,7 +71,7 @@ func (mm *metricsMiddleware) RetrieveCAToken(ctx context.Context) (string, error return mm.svc.RetrieveCAToken(ctx) } -func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...*rsa.PrivateKey) (certs.Certificate, error) { +func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...any) (certs.Certificate, error) { defer func(begin time.Time) { mm.counter.With("method", "issue_certificate").Add(1) mm.latency.With("method", "issue_certificate").Observe(time.Since(begin).Seconds()) diff --git a/certs.go b/certs.go index 1208ae6..8cc1a2f 100644 --- a/certs.go +++ b/certs.go @@ -158,7 +158,7 @@ type Service interface { RetrieveCAToken(ctx context.Context) (string, error) // IssueCert issues a certificate from the database. - IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions, privKey ...*rsa.PrivateKey) (Certificate, error) + IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions, privKey ...any) (Certificate, error) // OCSP retrieves the OCSP response for a certificate. OCSP(ctx context.Context, serialNumber string) (*Certificate, int, *x509.Certificate, error) diff --git a/mocks/service.go b/mocks/service.go index 84daa16..5e91bc7 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -12,8 +12,6 @@ import ( mock "github.com/stretchr/testify/mock" - rsa "crypto/rsa" - x509 "crypto/x509" ) @@ -262,14 +260,10 @@ func (_c *MockService_GetEntityID_Call) RunAndReturn(run func(context.Context, s } // IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option, privKey -func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...*rsa.PrivateKey) (certs.Certificate, error) { - _va := make([]interface{}, len(privKey)) - for _i := range privKey { - _va[_i] = privKey[_i] - } +func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...interface{}) (certs.Certificate, error) { var _ca []interface{} _ca = append(_ca, ctx, entityID, ttl, ipAddrs, option) - _ca = append(_ca, _va...) + _ca = append(_ca, privKey...) ret := _m.Called(_ca...) if len(ret) == 0 { @@ -278,16 +272,16 @@ func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl strin var r0 certs.Certificate var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) (certs.Certificate, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...interface{}) (certs.Certificate, error)); ok { return rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) certs.Certificate); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...interface{}) certs.Certificate); ok { r0 = rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } else { r0 = ret.Get(0).(certs.Certificate) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions, ...interface{}) error); ok { r1 = rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } else { r1 = ret.Error(1) @@ -307,18 +301,18 @@ type MockService_IssueCert_Call struct { // - ttl string // - ipAddrs []string // - option certs.SubjectOptions -// - privKey ...*rsa.PrivateKey +// - privKey ...interface{} func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}, privKey ...interface{}) *MockService_IssueCert_Call { return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", append([]interface{}{ctx, entityID, ttl, ipAddrs, option}, privKey...)...)} } -func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...*rsa.PrivateKey)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...interface{})) *MockService_IssueCert_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]*rsa.PrivateKey, len(args)-5) + variadicArgs := make([]interface{}, len(args)-5) for i, a := range args[5:] { if a != nil { - variadicArgs[i] = a.(*rsa.PrivateKey) + variadicArgs[i] = a.(interface{}) } } run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string), args[4].(certs.SubjectOptions), variadicArgs...) @@ -331,7 +325,7 @@ func (_c *MockService_IssueCert_Call) Return(_a0 certs.Certificate, _a1 error) * return _c } -func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions, ...*rsa.PrivateKey) (certs.Certificate, error)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions, ...interface{}) (certs.Certificate, error)) *MockService_IssueCert_Call { _c.Call.Return(run) return _c } diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index ce5cd68..ad87e77 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -638,9 +638,9 @@ func (_c *MockSDK_RevokeCert_Call) RunAndReturn(run func(string) errors.SDKError return _c } -// SignCSR provides a mock function with given fields: entityID, ttl, csr -func (_m *MockSDK) SignCSR(entityID string, ttl string, csr string) (sdk.Certificate, errors.SDKError) { - ret := _m.Called(entityID, ttl, csr) +// SignCSR provides a mock function with given fields: entityID, ttl, csr, privKey +func (_m *MockSDK) SignCSR(entityID string, ttl string, csr string, privKey string) (sdk.Certificate, errors.SDKError) { + ret := _m.Called(entityID, ttl, csr, privKey) if len(ret) == 0 { panic("no return value specified for SignCSR") @@ -648,17 +648,17 @@ func (_m *MockSDK) SignCSR(entityID string, ttl string, csr string) (sdk.Certifi var r0 sdk.Certificate var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(string, string, string) (sdk.Certificate, errors.SDKError)); ok { - return rf(entityID, ttl, csr) + if rf, ok := ret.Get(0).(func(string, string, string, string) (sdk.Certificate, errors.SDKError)); ok { + return rf(entityID, ttl, csr, privKey) } - if rf, ok := ret.Get(0).(func(string, string, string) sdk.Certificate); ok { - r0 = rf(entityID, ttl, csr) + if rf, ok := ret.Get(0).(func(string, string, string, string) sdk.Certificate); ok { + r0 = rf(entityID, ttl, csr, privKey) } else { r0 = ret.Get(0).(sdk.Certificate) } - if rf, ok := ret.Get(1).(func(string, string, string) errors.SDKError); ok { - r1 = rf(entityID, ttl, csr) + if rf, ok := ret.Get(1).(func(string, string, string, string) errors.SDKError); ok { + r1 = rf(entityID, ttl, csr, privKey) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -677,13 +677,14 @@ type MockSDK_SignCSR_Call struct { // - entityID string // - ttl string // - csr string -func (_e *MockSDK_Expecter) SignCSR(entityID interface{}, ttl interface{}, csr interface{}) *MockSDK_SignCSR_Call { - return &MockSDK_SignCSR_Call{Call: _e.mock.On("SignCSR", entityID, ttl, csr)} +// - privKey string +func (_e *MockSDK_Expecter) SignCSR(entityID interface{}, ttl interface{}, csr interface{}, privKey interface{}) *MockSDK_SignCSR_Call { + return &MockSDK_SignCSR_Call{Call: _e.mock.On("SignCSR", entityID, ttl, csr, privKey)} } -func (_c *MockSDK_SignCSR_Call) Run(run func(entityID string, ttl string, csr string)) *MockSDK_SignCSR_Call { +func (_c *MockSDK_SignCSR_Call) Run(run func(entityID string, ttl string, csr string, privKey string)) *MockSDK_SignCSR_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string)) + run(args[0].(string), args[1].(string), args[2].(string), args[3].(string)) }) return _c } @@ -693,7 +694,7 @@ func (_c *MockSDK_SignCSR_Call) Return(_a0 sdk.Certificate, _a1 errors.SDKError) return _c } -func (_c *MockSDK_SignCSR_Call) RunAndReturn(run func(string, string, string) (sdk.Certificate, errors.SDKError)) *MockSDK_SignCSR_Call { +func (_c *MockSDK_SignCSR_Call) RunAndReturn(run func(string, string, string, string) (sdk.Certificate, errors.SDKError)) *MockSDK_SignCSR_Call { _c.Call.Return(run) return _c } diff --git a/service.go b/service.go index 88de439..cfb1a4a 100644 --- a/service.go +++ b/service.go @@ -5,6 +5,7 @@ package certs import ( "context" + "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rand" @@ -88,17 +89,17 @@ func NewService(ctx context.Context, repo Repository, config *Config) (Service, // using the provided template and the generated private key. // The certificate is then stored in the repository using the CreateCert method. // If the root CA is not found, it returns an error. -func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key ...*rsa.PrivateKey) (Certificate, error) { - var privKey rsa.PrivateKey +func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key ...any) (Certificate, error) { + var privKey any var err error if len(key) == 0 { pKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) - privKey = *pKey + privKey = pKey if err != nil { return Certificate{}, err } } else { - privKey = *key[0] + privKey = key[0] } serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { @@ -132,12 +133,37 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ DNSNames: append(s.intermediateCA.Certificate.DNSNames, ipAddrs...), } - certBytes, err := x509.CreateCertificate(rand.Reader, &template, s.intermediateCA.Certificate, &privKey.PublicKey, s.intermediateCA.PrivateKey) + var pubKey crypto.PublicKey + var privKeyBytes []byte + var privKeyType string + + switch key := privKey.(type) { + case *rsa.PrivateKey: + pubKey = key.Public() + privKeyBytes = x509.MarshalPKCS1PrivateKey(key) + privKeyType = "RSA PRIVATE KEY" + case *ecdsa.PrivateKey: + pubKey = key.Public() + privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) + privKeyType = "EC PRIVATE KEY" + case ed25519.PrivateKey: + pubKey = key.Public() + privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) + privKeyType = "PRIVATE KEY" + default: + return Certificate{}, errors.Wrap(ErrCreateEntity, errors.New("unsupported private key type")) + } + + if err != nil { + return Certificate{}, err + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, s.intermediateCA.Certificate, pubKey, s.intermediateCA.PrivateKey) if err != nil { return Certificate{}, err } dbCert := Certificate{ - Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(&privKey)}), + Key: pem.EncodeToMemory(&pem.Block{Type: privKeyType, Bytes: privKeyBytes}), Certificate: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}), SerialNumber: template.SerialNumber.String(), EntityID: entityID, diff --git a/tracing/certs.go b/tracing/certs.go index 4a720db..b9fd6ce 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -5,7 +5,6 @@ package tracing import ( "context" - "crypto/rsa" "crypto/x509" "github.com/absmach/certs" @@ -54,7 +53,7 @@ func (tm *tracingMiddleware) RetrieveCAToken(ctx context.Context) (string, error return tm.svc.RetrieveCAToken(ctx) } -func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...*rsa.PrivateKey) (certs.Certificate, error) { +func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...any) (certs.Certificate, error) { ctx, span := tm.tracer.Start(ctx, "issue_cert") defer span.End() return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey...) From 670bc5293bcd05430af50f7c5f7c45df0ae6f28f Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 2 Dec 2024 11:30:53 +0300 Subject: [PATCH 10/22] make the strings to constants Signed-off-by: nyagamunene --- service.go | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/service.go b/service.go index cfb1a4a..f847e5c 100644 --- a/service.go +++ b/service.go @@ -33,6 +33,9 @@ const ( rCertExpiryThreshold = time.Hour * 24 * 30 // 30 days iCertExpiryThreshold = time.Hour * 24 * 10 // 10 days downloadTokenExpiry = time.Minute * 5 + PrivateKey = "PRIVATE KEY" + RSAPrivateKey = "RSA PRIVATE KEY" + ECPrivateKey = "EC PRIVATE KEY" ) var ( @@ -141,15 +144,15 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ case *rsa.PrivateKey: pubKey = key.Public() privKeyBytes = x509.MarshalPKCS1PrivateKey(key) - privKeyType = "RSA PRIVATE KEY" + privKeyType = RSAPrivateKey case *ecdsa.PrivateKey: pubKey = key.Public() privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) - privKeyType = "EC PRIVATE KEY" + privKeyType = ECPrivateKey case ed25519.PrivateKey: pubKey = key.Public() privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) - privKeyType = "PRIVATE KEY" + privKeyType = PrivateKey default: return Certificate{}, errors.Wrap(ErrCreateEntity, errors.New("unsupported private key type")) } @@ -467,13 +470,13 @@ func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, privKey a switch key := privKey.(type) { case *rsa.PrivateKey: privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) - privKeyType = "RSA PRIVATE KEY" + privKeyType = RSAPrivateKey case *ecdsa.PrivateKey: privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) - privKeyType = "EC PRIVATE KEY" + privKeyType = ECPrivateKey case ed25519.PrivateKey: privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) - privKeyType = "PRIVATE KEY" + privKeyType = PrivateKey default: return CSR{}, errors.Wrap(ErrCreateEntity, errors.New("unsupported private key type")) } @@ -615,7 +618,7 @@ func (s *service) generateRootCA(ctx context.Context, config Config) (*CA, error func (s *service) saveCA(ctx context.Context, cert *x509.Certificate, privateKey *rsa.PrivateKey, CertType CertType) error { dbCert := Certificate{ - Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}), + Key: pem.EncodeToMemory(&pem.Block{Type: RSAPrivateKey, Bytes: x509.MarshalPKCS1PrivateKey(privateKey)}), Certificate: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}), SerialNumber: cert.SerialNumber.String(), ExpiryTime: cert.NotAfter, @@ -867,11 +870,11 @@ func extractPrivateKey(pemKey []byte) (any, error) { ) switch block.Type { - case "RSA PRIVATE KEY": + case RSAPrivateKey: privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) - case "EC PRIVATE KEY": + case ECPrivateKey: privateKey, err = x509.ParseECPrivateKey(block.Bytes) - case "PRIVATE KEY", "PKCS8 PRIVATE KEY": + case PrivateKey, "PKCS8 PRIVATE KEY": privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) case "ED25519 PRIVATE KEY": privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) From 5ccc6f5369ad643ba698c6fa41cdb0d1a354069f Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 2 Dec 2024 12:03:09 +0300 Subject: [PATCH 11/22] rename signcsr to issuefromcsr Signed-off-by: nyagamunene --- api/http/endpoint.go | 13 +++-- api/http/requests.go | 4 +- api/http/responses.go | 9 ++-- api/http/transport.go | 10 ++-- api/logging.go | 6 +-- api/metrics.go | 8 +-- certs.go | 6 +-- cli/certs.go | 10 ++-- mocks/service.go | 118 ++++++++++++++++++++-------------------- sdk/mocks/sdk.go | 122 +++++++++++++++++++++--------------------- sdk/sdk.go | 8 +-- service.go | 2 +- tracing/certs.go | 6 +-- 13 files changed, 160 insertions(+), 162 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index 5743c7b..709a4e7 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -328,25 +328,24 @@ func createCSREndpoint(svc certs.Service) endpoint.Endpoint { } } -func signCSREndpoint(svc certs.Service) endpoint.Endpoint { +func issueFromCSREndpoint(svc certs.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(SignCSRReq) + req := request.(IssueFromCSRReq) if err := req.validate(); err != nil { - return signCSRRes{signed: false}, err + return issueFromCSRRes{}, err } - cert, err := svc.SignCSR(ctx, req.entityID, req.ttl, certs.CSR{CSR: []byte(req.CSR), PrivateKey: []byte(req.PrivateKey)}) + cert, err := svc.IssueFromCSR(ctx, req.entityID, req.ttl, certs.CSR{CSR: []byte(req.CSR), PrivateKey: []byte(req.PrivateKey)}) if err != nil { - return signCSRRes{signed: false}, err + return issueFromCSRRes{}, err } - return signCSRRes{ + return issueFromCSRRes{ SerialNumber: cert.SerialNumber, Certificate: string(cert.Certificate), Revoked: cert.Revoked, ExpiryTime: cert.ExpiryTime, EntityID: cert.EntityID, - signed: true, }, nil } } diff --git a/api/http/requests.go b/api/http/requests.go index 6119aae..a6c5b31 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -105,14 +105,14 @@ func (req createCSRReq) validate() error { return nil } -type SignCSRReq struct { +type IssueFromCSRReq struct { entityID string ttl string CSR string `json:"csr"` PrivateKey string `json:"private_key"` } -func (req SignCSRReq) validate() error { +func (req IssueFromCSRReq) validate() error { if req.entityID == "" { return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) } diff --git a/api/http/responses.go b/api/http/responses.go index 2be1618..ae401fc 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -218,23 +218,22 @@ func (res createCSRRes) Empty() bool { return false } -type signCSRRes struct { +type issueFromCSRRes struct { SerialNumber string `json:"serial_number"` Certificate string `json:"certificate,omitempty"` Revoked bool `json:"revoked"` ExpiryTime time.Time `json:"expiry_time"` EntityID string `json:"entity_id"` - signed bool } -func (res signCSRRes) Code() int { +func (res issueFromCSRRes) Code() int { return http.StatusOK } -func (res signCSRRes) Headers() map[string]string { +func (res issueFromCSRRes) Headers() map[string]string { return map[string]string{} } -func (res signCSRRes) Empty() bool { +func (res issueFromCSRRes) Empty() bool { return false } diff --git a/api/http/transport.go b/api/http/transport.go index 72df1af..a589cb9 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -150,11 +150,11 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http opts..., ), "create_csr").ServeHTTP) r.Post("/{entityID}", otelhttp.NewHandler(kithttp.NewServer( - signCSREndpoint(svc), - decodeSignCSR, + issueFromCSREndpoint(svc), + decodeIssueFromCSR, EncodeResponse, opts..., - ), "sign_csr").ServeHTTP) + ), "issue_from_csr").ServeHTTP) }) }) @@ -317,13 +317,13 @@ func decodeCreateCSR(_ context.Context, r *http.Request) (interface{}, error) { return req, nil } -func decodeSignCSR(_ context.Context, r *http.Request) (interface{}, error) { +func decodeIssueFromCSR(_ context.Context, r *http.Request) (interface{}, error) { t, err := readStringQuery(r, ttl, "") if err != nil { return nil, err } - req := SignCSRReq{ + req := IssueFromCSRReq{ entityID: chi.URLParam(r, "entityID"), ttl: t, } diff --git a/api/logging.go b/api/logging.go index 4e797ea..00ca13f 100644 --- a/api/logging.go +++ b/api/logging.go @@ -193,14 +193,14 @@ func (lm *loggingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMe return lm.svc.CreateCSR(ctx, metadata, privKey) } -func (lm *loggingMiddleware) SignCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (c certs.Certificate, err error) { +func (lm *loggingMiddleware) IssueFromCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (c certs.Certificate, err error) { defer func(begin time.Time) { - message := fmt.Sprintf("Method sign_csr took %s to complete", time.Since(begin)) + message := fmt.Sprintf("Method issue_from_csr took %s to complete", time.Since(begin)) if err != nil { lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) return } lm.logger.Info(message) }(time.Now()) - return lm.svc.SignCSR(ctx, entityID, ttl, csr) + return lm.svc.IssueFromCSR(ctx, entityID, ttl, csr) } diff --git a/api/metrics.go b/api/metrics.go index d9db09c..06053a5 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -144,10 +144,10 @@ func (mm *metricsMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMe return mm.svc.CreateCSR(ctx, metadata, privKey) } -func (mm *metricsMiddleware) SignCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (certs.Certificate, error) { +func (mm *metricsMiddleware) IssueFromCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (certs.Certificate, error) { defer func(begin time.Time) { - mm.counter.With("method", "sign_csr").Add(1) - mm.latency.With("method", "sign_csr").Observe(time.Since(begin).Seconds()) + mm.counter.With("method", "issue_from_csr").Add(1) + mm.latency.With("method", "issue_from_csr").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.SignCSR(ctx, entityID, ttl, csr) + return mm.svc.IssueFromCSR(ctx, entityID, ttl, csr) } diff --git a/certs.go b/certs.go index 8cc1a2f..b530f46 100644 --- a/certs.go +++ b/certs.go @@ -175,11 +175,11 @@ type Service interface { // RemoveCert deletes a cert for a provided entityID. RemoveCert(ctx context.Context, entityId string) error - // CreateCSR creates a new Certificate Signing Request + // CreateCSR creates a new Certificate Signing Request. CreateCSR(ctx context.Context, metadata CSRMetadata, privKey any) (CSR, error) - // SignCSR parses and signs a CSR - SignCSR(ctx context.Context, entityID, ttl string, csr CSR) (Certificate, error) + // IssueFromCSR creates a certificate from a given CSR. + IssueFromCSR(ctx context.Context, entityID, ttl string, csr CSR) (Certificate, error) } type Repository interface { diff --git a/cli/certs.go b/cli/certs.go index ca97a97..9955444 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -270,9 +270,9 @@ var cmdCerts = []cobra.Command{ }, }, { - Use: "sign ", - Short: "Sign CSR", - Long: `Signs a CSR for a given csr.`, + Use: "issue-csr ", + Short: "Issue from CSR", + Long: `issues a certificate for a given csr.`, Run: func(cmd *cobra.Command, args []string) { if len(args) != 4 { logUsageCmd(*cmd, cmd.Use) @@ -291,7 +291,7 @@ var cmdCerts = []cobra.Command{ return } - cert, err := sdk.SignCSR(args[0], args[1], string(csrData), string(privData)) + cert, err := sdk.IssueFromCSR(args[0], args[1], string(csrData), string(privData)) if err != nil { logErrorCmd(*cmd, err) return @@ -341,7 +341,7 @@ func NewCertsCmd() *cobra.Command { issueCmd.Flags().StringVar(&ttl, "ttl", "8760h", "certificate time to live in duration") cmd := cobra.Command{ - Use: "certs [issue | get | revoke | renew | ocsp | token | download | download-ca | download-ca | csr | sign]", + Use: "certs [issue | get | revoke | renew | ocsp | token | download | download-ca | download-ca | csr | issue-csr]", Short: "Certificates management", Long: `Certificates management: issue, get all, get by entity ID, revoke, renew, OCSP, token, download.`, } diff --git a/mocks/service.go b/mocks/service.go index 5e91bc7..e805c6e 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -330,6 +330,65 @@ func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, str return _c } +// IssueFromCSR provides a mock function with given fields: ctx, entityID, ttl, csr +func (_m *MockService) IssueFromCSR(ctx context.Context, entityID string, ttl string, csr certs.CSR) (certs.Certificate, error) { + ret := _m.Called(ctx, entityID, ttl, csr) + + if len(ret) == 0 { + panic("no return value specified for IssueFromCSR") + } + + var r0 certs.Certificate + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, certs.CSR) (certs.Certificate, error)); ok { + return rf(ctx, entityID, ttl, csr) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, certs.CSR) certs.Certificate); ok { + r0 = rf(ctx, entityID, ttl, csr) + } else { + r0 = ret.Get(0).(certs.Certificate) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, certs.CSR) error); ok { + r1 = rf(ctx, entityID, ttl, csr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockService_IssueFromCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IssueFromCSR' +type MockService_IssueFromCSR_Call struct { + *mock.Call +} + +// IssueFromCSR is a helper method to define mock.On call +// - ctx context.Context +// - entityID string +// - ttl string +// - csr certs.CSR +func (_e *MockService_Expecter) IssueFromCSR(ctx interface{}, entityID interface{}, ttl interface{}, csr interface{}) *MockService_IssueFromCSR_Call { + return &MockService_IssueFromCSR_Call{Call: _e.mock.On("IssueFromCSR", ctx, entityID, ttl, csr)} +} + +func (_c *MockService_IssueFromCSR_Call) Run(run func(ctx context.Context, entityID string, ttl string, csr certs.CSR)) *MockService_IssueFromCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(certs.CSR)) + }) + return _c +} + +func (_c *MockService_IssueFromCSR_Call) Return(_a0 certs.Certificate, _a1 error) *MockService_IssueFromCSR_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockService_IssueFromCSR_Call) RunAndReturn(run func(context.Context, string, string, certs.CSR) (certs.Certificate, error)) *MockService_IssueFromCSR_Call { + _c.Call.Return(run) + return _c +} + // ListCerts provides a mock function with given fields: ctx, pm func (_m *MockService) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { ret := _m.Called(ctx, pm) @@ -783,65 +842,6 @@ func (_c *MockService_RevokeCert_Call) RunAndReturn(run func(context.Context, st return _c } -// SignCSR provides a mock function with given fields: ctx, entityID, ttl, csr -func (_m *MockService) SignCSR(ctx context.Context, entityID string, ttl string, csr certs.CSR) (certs.Certificate, error) { - ret := _m.Called(ctx, entityID, ttl, csr) - - if len(ret) == 0 { - panic("no return value specified for SignCSR") - } - - var r0 certs.Certificate - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, certs.CSR) (certs.Certificate, error)); ok { - return rf(ctx, entityID, ttl, csr) - } - if rf, ok := ret.Get(0).(func(context.Context, string, string, certs.CSR) certs.Certificate); ok { - r0 = rf(ctx, entityID, ttl, csr) - } else { - r0 = ret.Get(0).(certs.Certificate) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, string, certs.CSR) error); ok { - r1 = rf(ctx, entityID, ttl, csr) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockService_SignCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignCSR' -type MockService_SignCSR_Call struct { - *mock.Call -} - -// SignCSR is a helper method to define mock.On call -// - ctx context.Context -// - entityID string -// - ttl string -// - csr certs.CSR -func (_e *MockService_Expecter) SignCSR(ctx interface{}, entityID interface{}, ttl interface{}, csr interface{}) *MockService_SignCSR_Call { - return &MockService_SignCSR_Call{Call: _e.mock.On("SignCSR", ctx, entityID, ttl, csr)} -} - -func (_c *MockService_SignCSR_Call) Run(run func(ctx context.Context, entityID string, ttl string, csr certs.CSR)) *MockService_SignCSR_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(certs.CSR)) - }) - return _c -} - -func (_c *MockService_SignCSR_Call) Return(_a0 certs.Certificate, _a1 error) *MockService_SignCSR_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockService_SignCSR_Call) RunAndReturn(run func(context.Context, string, string, certs.CSR) (certs.Certificate, error)) *MockService_SignCSR_Call { - _c.Call.Return(run) - return _c -} - // ViewCert provides a mock function with given fields: ctx, serialNumber func (_m *MockService) ViewCert(ctx context.Context, serialNumber string) (certs.Certificate, error) { ret := _m.Called(ctx, serialNumber) diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index ad87e77..210102d 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -367,6 +367,67 @@ func (_c *MockSDK_IssueCert_Call) RunAndReturn(run func(string, string, []string return _c } +// IssueFromCSR provides a mock function with given fields: entityID, ttl, csr, privKey +func (_m *MockSDK) IssueFromCSR(entityID string, ttl string, csr string, privKey string) (sdk.Certificate, errors.SDKError) { + ret := _m.Called(entityID, ttl, csr, privKey) + + if len(ret) == 0 { + panic("no return value specified for IssueFromCSR") + } + + var r0 sdk.Certificate + var r1 errors.SDKError + if rf, ok := ret.Get(0).(func(string, string, string, string) (sdk.Certificate, errors.SDKError)); ok { + return rf(entityID, ttl, csr, privKey) + } + if rf, ok := ret.Get(0).(func(string, string, string, string) sdk.Certificate); ok { + r0 = rf(entityID, ttl, csr, privKey) + } else { + r0 = ret.Get(0).(sdk.Certificate) + } + + if rf, ok := ret.Get(1).(func(string, string, string, string) errors.SDKError); ok { + r1 = rf(entityID, ttl, csr, privKey) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + + return r0, r1 +} + +// MockSDK_IssueFromCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IssueFromCSR' +type MockSDK_IssueFromCSR_Call struct { + *mock.Call +} + +// IssueFromCSR is a helper method to define mock.On call +// - entityID string +// - ttl string +// - csr string +// - privKey string +func (_e *MockSDK_Expecter) IssueFromCSR(entityID interface{}, ttl interface{}, csr interface{}, privKey interface{}) *MockSDK_IssueFromCSR_Call { + return &MockSDK_IssueFromCSR_Call{Call: _e.mock.On("IssueFromCSR", entityID, ttl, csr, privKey)} +} + +func (_c *MockSDK_IssueFromCSR_Call) Run(run func(entityID string, ttl string, csr string, privKey string)) *MockSDK_IssueFromCSR_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string), args[1].(string), args[2].(string), args[3].(string)) + }) + return _c +} + +func (_c *MockSDK_IssueFromCSR_Call) Return(_a0 sdk.Certificate, _a1 errors.SDKError) *MockSDK_IssueFromCSR_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSDK_IssueFromCSR_Call) RunAndReturn(run func(string, string, string, string) (sdk.Certificate, errors.SDKError)) *MockSDK_IssueFromCSR_Call { + _c.Call.Return(run) + return _c +} + // ListCerts provides a mock function with given fields: pm func (_m *MockSDK) ListCerts(pm sdk.PageMetadata) (sdk.CertificatePage, errors.SDKError) { ret := _m.Called(pm) @@ -638,67 +699,6 @@ func (_c *MockSDK_RevokeCert_Call) RunAndReturn(run func(string) errors.SDKError return _c } -// SignCSR provides a mock function with given fields: entityID, ttl, csr, privKey -func (_m *MockSDK) SignCSR(entityID string, ttl string, csr string, privKey string) (sdk.Certificate, errors.SDKError) { - ret := _m.Called(entityID, ttl, csr, privKey) - - if len(ret) == 0 { - panic("no return value specified for SignCSR") - } - - var r0 sdk.Certificate - var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(string, string, string, string) (sdk.Certificate, errors.SDKError)); ok { - return rf(entityID, ttl, csr, privKey) - } - if rf, ok := ret.Get(0).(func(string, string, string, string) sdk.Certificate); ok { - r0 = rf(entityID, ttl, csr, privKey) - } else { - r0 = ret.Get(0).(sdk.Certificate) - } - - if rf, ok := ret.Get(1).(func(string, string, string, string) errors.SDKError); ok { - r1 = rf(entityID, ttl, csr, privKey) - } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).(errors.SDKError) - } - } - - return r0, r1 -} - -// MockSDK_SignCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SignCSR' -type MockSDK_SignCSR_Call struct { - *mock.Call -} - -// SignCSR is a helper method to define mock.On call -// - entityID string -// - ttl string -// - csr string -// - privKey string -func (_e *MockSDK_Expecter) SignCSR(entityID interface{}, ttl interface{}, csr interface{}, privKey interface{}) *MockSDK_SignCSR_Call { - return &MockSDK_SignCSR_Call{Call: _e.mock.On("SignCSR", entityID, ttl, csr, privKey)} -} - -func (_c *MockSDK_SignCSR_Call) Run(run func(entityID string, ttl string, csr string, privKey string)) *MockSDK_SignCSR_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string), args[3].(string)) - }) - return _c -} - -func (_c *MockSDK_SignCSR_Call) Return(_a0 sdk.Certificate, _a1 errors.SDKError) *MockSDK_SignCSR_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockSDK_SignCSR_Call) RunAndReturn(run func(string, string, string, string) (sdk.Certificate, errors.SDKError)) *MockSDK_SignCSR_Call { - _c.Call.Return(run) - return _c -} - // ViewCA provides a mock function with given fields: token func (_m *MockSDK) ViewCA(token string) (sdk.Certificate, errors.SDKError) { ret := _m.Called(token) diff --git a/sdk/sdk.go b/sdk/sdk.go index 82e66f8..5f371aa 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -272,12 +272,12 @@ type SDK interface { // fmt.Println(response) CreateCSR(pm PageMetadata, privKey string) (CSR, errors.SDKError) - // SignCSR processes a pending CSR and either signs or rejects it + // IssueFromCSR issues certificate from provided CSR // // example: - // certs, err := sdk.SignCSR( "entityID", "ttl", "csrFile", "privKey") + // certs, err := sdk.IssueFromCSR( "entityID", "ttl", "csrFile", "privKey") // fmt.Println(err) - SignCSR(entityID, ttl string, csr, privKey string) (Certificate, errors.SDKError) + IssueFromCSR(entityID, ttl string, csr, privKey string) (Certificate, errors.SDKError) } func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (Certificate, errors.SDKError) { @@ -591,7 +591,7 @@ func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey string) (CSR, errors.SDKErro return csr, nil } -func (sdk mgSDK) SignCSR(entityID, ttl string, csr, privKey string) (Certificate, errors.SDKError) { +func (sdk mgSDK) IssueFromCSR(entityID, ttl string, csr, privKey string) (Certificate, errors.SDKError) { pm := PageMetadata{ TTL: ttl, } diff --git a/service.go b/service.go index f847e5c..62266d6 100644 --- a/service.go +++ b/service.go @@ -498,7 +498,7 @@ func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, privKey a return csr, nil } -func (s *service) SignCSR(ctx context.Context, entityID, ttl string, csr CSR) (Certificate, error) { +func (s *service) IssueFromCSR(ctx context.Context, entityID, ttl string, csr CSR) (Certificate, error) { block, _ := pem.Decode(csr.CSR) if block == nil { return Certificate{}, errors.New("failed to parse CSR PEM") diff --git a/tracing/certs.go b/tracing/certs.go index b9fd6ce..fc34651 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -107,8 +107,8 @@ func (tm *tracingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMe return tm.svc.CreateCSR(ctx, metadata, privKey) } -func (tm *tracingMiddleware) SignCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (certs.Certificate, error) { - ctx, span := tm.tracer.Start(ctx, "sign_csr") +func (tm *tracingMiddleware) IssueFromCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (certs.Certificate, error) { + ctx, span := tm.tracer.Start(ctx, "issue_from_csr") defer span.End() - return tm.svc.SignCSR(ctx, entityID, ttl, csr) + return tm.svc.IssueFromCSR(ctx, entityID, ttl, csr) } From f21503f51654d44817639486b8b6fee5f667cee9 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 2 Dec 2024 15:39:50 +0300 Subject: [PATCH 12/22] move create csr to cli tool Signed-off-by: nyagamunene --- api/http/endpoint.go | 18 ---------- api/http/transport.go | 45 ------------------------- api/logging.go | 12 ------- api/metrics.go | 8 ----- certs.go | 7 ++-- cli/certs.go | 77 +++++++++++++++++++++++++++++++++++++++++-- cli/utils.go | 3 +- service.go | 70 ++------------------------------------- tracing/certs.go | 6 ---- 9 files changed, 81 insertions(+), 165 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index 709a4e7..ae1a30c 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -310,24 +310,6 @@ func viewCAEndpoint(svc certs.Service) endpoint.Endpoint { } } -func createCSREndpoint(svc certs.Service) endpoint.Endpoint { - return func(ctx context.Context, request interface{}) (response interface{}, err error) { - req := request.(createCSRReq) - if err := req.validate(); err != nil { - return createCSRRes{}, err - } - - csr, err := svc.CreateCSR(ctx, req.Metadata, req.privKey) - if err != nil { - return createCSRRes{}, err - } - - return createCSRRes{ - CSR: string(csr.CSR), - }, nil - } -} - func issueFromCSREndpoint(svc certs.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { req := request.(IssueFromCSRReq) diff --git a/api/http/transport.go b/api/http/transport.go index a589cb9..112f3da 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -7,10 +7,8 @@ import ( "archive/zip" "bytes" "context" - "crypto/x509" "encoding/asn1" "encoding/json" - "encoding/pem" "fmt" "io" "log/slog" @@ -143,12 +141,6 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http opts..., ), "download_ca").ServeHTTP) r.Route("/csrs", func(r chi.Router) { - r.Post("/", otelhttp.NewHandler(kithttp.NewServer( - createCSREndpoint(svc), - decodeCreateCSR, - EncodeResponse, - opts..., - ), "create_csr").ServeHTTP) r.Post("/{entityID}", otelhttp.NewHandler(kithttp.NewServer( issueFromCSREndpoint(svc), decodeIssueFromCSR, @@ -280,43 +272,6 @@ func decodeListCerts(_ context.Context, r *http.Request) (interface{}, error) { return req, nil } -func decodeCreateCSR(_ context.Context, r *http.Request) (interface{}, error) { - req := createCSRReq{} - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - return nil, err - } - - block, _ := pem.Decode([]byte(req.PrivateKey)) - if block == nil { - return nil, errors.Wrap(ErrInvalidRequest, errors.New("invalid PEM format")) - } - - var ( - privateKey any - err error - ) - - switch block.Type { - case "RSA PRIVATE KEY": - privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) - case "EC PRIVATE KEY": - privateKey, err = x509.ParseECPrivateKey(block.Bytes) - case "PRIVATE KEY", "PKCS8 PRIVATE KEY": - privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) - case "ED25519 PRIVATE KEY": - privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) - default: - err = errors.New("unsupported private key type") - } - - if err != nil { - return nil, errors.Wrap(ErrInvalidRequest, err) - } - req.privKey = privateKey - - return req, nil -} - func decodeIssueFromCSR(_ context.Context, r *http.Request) (interface{}, error) { t, err := readStringQuery(r, ttl, "") if err != nil { diff --git a/api/logging.go b/api/logging.go index 00ca13f..45b7f32 100644 --- a/api/logging.go +++ b/api/logging.go @@ -181,18 +181,6 @@ func (lm *loggingMiddleware) GetChainCA(ctx context.Context, token string) (cert return lm.svc.GetChainCA(ctx, token) } -func (lm *loggingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey any) (csr certs.CSR, err error) { - defer func(begin time.Time) { - message := fmt.Sprintf("Method create_csr took %s to complete", time.Since(begin)) - if err != nil { - lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) - return - } - lm.logger.Info(message) - }(time.Now()) - return lm.svc.CreateCSR(ctx, metadata, privKey) -} - func (lm *loggingMiddleware) IssueFromCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (c certs.Certificate, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method issue_from_csr took %s to complete", time.Since(begin)) diff --git a/api/metrics.go b/api/metrics.go index 06053a5..fb3e70e 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -136,14 +136,6 @@ func (mm *metricsMiddleware) GetChainCA(ctx context.Context, token string) (cert return mm.svc.GetChainCA(ctx, token) } -func (mm *metricsMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey any) (certs.CSR, error) { - defer func(begin time.Time) { - mm.counter.With("method", "create_csr").Add(1) - mm.latency.With("method", "create_csr").Observe(time.Since(begin).Seconds()) - }(time.Now()) - return mm.svc.CreateCSR(ctx, metadata, privKey) -} - func (mm *metricsMiddleware) IssueFromCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (certs.Certificate, error) { defer func(begin time.Time) { mm.counter.With("method", "issue_from_csr").Add(1) diff --git a/certs.go b/certs.go index b530f46..f32eac9 100644 --- a/certs.go +++ b/certs.go @@ -99,8 +99,8 @@ type CSRMetadata struct { } type CSR struct { - CSR []byte `json:"csr"` - PrivateKey []byte `json:"private_key"` + CSR []byte `json:"csr,omitempty"` + PrivateKey []byte `json:"private_key,omitempty"` } type CSRPage struct { @@ -175,9 +175,6 @@ type Service interface { // RemoveCert deletes a cert for a provided entityID. RemoveCert(ctx context.Context, entityId string) error - // CreateCSR creates a new Certificate Signing Request. - CreateCSR(ctx context.Context, metadata CSRMetadata, privKey any) (CSR, error) - // IssueFromCSR creates a certificate from a given CSR. IssueFromCSR(ctx context.Context, entityID, ttl string, csr CSR) (Certificate, error) } diff --git a/cli/certs.go b/cli/certs.go index 9955444..1ffca49 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -4,9 +4,20 @@ package cli import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" "encoding/json" + "encoding/pem" + "net" "os" + "github.com/absmach/certs" + "github.com/absmach/certs/errors" ctxsdk "github.com/absmach/certs/sdk" "github.com/spf13/cobra" ) @@ -14,6 +25,8 @@ import ( // Keep SDK handle in global var. var sdk ctxsdk.SDK +var ErrCreateEntity = errors.New("failed to create entity") + func SetSDK(s ctxsdk.SDK) { sdk = s } @@ -248,7 +261,7 @@ var cmdCerts = []cobra.Command{ return } - var pm ctxsdk.PageMetadata + var pm certs.CSRMetadata if err := json.Unmarshal([]byte(args[0]), &pm); err != nil { logErrorCmd(*cmd, err) return @@ -260,7 +273,7 @@ var cmdCerts = []cobra.Command{ return } - csr, err := sdk.CreateCSR(pm, string(data)) + csr, err := CreateCSR(pm, data) if err != nil { logErrorCmd(*cmd, err) return @@ -354,3 +367,63 @@ func NewCertsCmd() *cobra.Command { return &cmd } + +func CreateCSR(metadata certs.CSRMetadata, privKey any) (certs.CSR, errors.SDKError) { + template := &x509.CertificateRequest{ + Subject: pkix.Name{ + CommonName: metadata.CommonName, + Organization: metadata.Organization, + OrganizationalUnit: metadata.OrganizationalUnit, + Country: metadata.Country, + Province: metadata.Province, + Locality: metadata.Locality, + StreetAddress: metadata.StreetAddress, + PostalCode: metadata.PostalCode, + }, + EmailAddresses: metadata.EmailAddresses, + DNSNames: metadata.DNSNames, + } + + for _, ip := range metadata.IPAddresses { + parsedIP := net.ParseIP(ip) + if parsedIP != nil { + template.IPAddresses = append(template.IPAddresses, parsedIP) + } + } + + var signer crypto.Signer + var err error + + switch key := privKey.(type) { + case *rsa.PrivateKey: + signer = key + case *ecdsa.PrivateKey: + signer = key + case ed25519.PrivateKey: + signer = key + case []byte: + parsedKey, err := certs.ExtractPrivateKey(key) + if err != nil { + return certs.CSR{}, errors.NewSDKError(errors.Wrap(ErrCreateEntity, err)) + } + return CreateCSR(metadata, parsedKey) + default: + return certs.CSR{}, errors.NewSDKError(errors.Wrap(ErrCreateEntity, errors.New("unsupported private key type"))) + } + + csrBytes, err := x509.CreateCertificateRequest(rand.Reader, template, signer) + if err != nil { + return certs.CSR{}, errors.NewSDKError(errors.Wrap(ErrCreateEntity, err)) + } + + csrPEM := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE REQUEST", + Bytes: csrBytes, + }) + + csr := certs.CSR{ + CSR: csrPEM, + } + + return csr, nil +} diff --git a/cli/utils.go b/cli/utils.go index 58e8b14..c5a921e 100644 --- a/cli/utils.go +++ b/cli/utils.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" + "github.com/absmach/certs" ctxsdk "github.com/absmach/certs/sdk" "github.com/fatih/color" "github.com/hokaccha/go-prettyjson" @@ -98,7 +99,7 @@ func logSaveCAFiles(cmd cobra.Command, certBundle ctxsdk.CertificateBundle) { fmt.Fprintf(cmd.OutOrStdout(), "\nAll certificate files have been saved successfully.\n") } -func logSaveCSRFiles(cmd cobra.Command, csr ctxsdk.CSR) { +func logSaveCSRFiles(cmd cobra.Command, csr certs.CSR) { files := map[string][]byte{ "file.csr": []byte(csr.CSR), } diff --git a/service.go b/service.go index 62266d6..8474dbe 100644 --- a/service.go +++ b/service.go @@ -15,7 +15,6 @@ import ( "encoding/asn1" "encoding/pem" "math/big" - "net" "time" "github.com/absmach/certs/errors" @@ -433,71 +432,6 @@ func (s *service) GetChainCA(ctx context.Context, token string) (Certificate, er return s.getConcatCAs(ctx) } -func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, privKey any) (CSR, error) { - template := &x509.CertificateRequest{ - Subject: pkix.Name{ - CommonName: metadata.CommonName, - Organization: metadata.Organization, - OrganizationalUnit: metadata.OrganizationalUnit, - Country: metadata.Country, - Province: metadata.Province, - Locality: metadata.Locality, - StreetAddress: metadata.StreetAddress, - PostalCode: metadata.PostalCode, - }, - EmailAddresses: metadata.EmailAddresses, - DNSNames: metadata.DNSNames, - } - - for _, ip := range metadata.IPAddresses { - parsedIP := net.ParseIP(ip) - if parsedIP != nil { - template.IPAddresses = append(template.IPAddresses, parsedIP) - } - } - csrBytes, err := x509.CreateCertificateRequest(rand.Reader, template, privKey) - if err != nil { - return CSR{}, errors.Wrap(ErrCreateEntity, err) - } - - csrPEM := pem.EncodeToMemory(&pem.Block{ - Type: "CERTIFICATE REQUEST", - Bytes: csrBytes, - }) - - var privKeyBytes []byte - var privKeyType string - switch key := privKey.(type) { - case *rsa.PrivateKey: - privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) - privKeyType = RSAPrivateKey - case *ecdsa.PrivateKey: - privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) - privKeyType = ECPrivateKey - case ed25519.PrivateKey: - privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) - privKeyType = PrivateKey - default: - return CSR{}, errors.Wrap(ErrCreateEntity, errors.New("unsupported private key type")) - } - - if err != nil { - return CSR{}, errors.Wrap(ErrCreateEntity, err) - } - - privKeyPEM := pem.EncodeToMemory(&pem.Block{ - Type: privKeyType, - Bytes: privKeyBytes, - }) - - csr := CSR{ - CSR: csrPEM, - PrivateKey: privKeyPEM, - } - - return csr, nil -} - func (s *service) IssueFromCSR(ctx context.Context, entityID, ttl string, csr CSR) (Certificate, error) { block, _ := pem.Decode(csr.CSR) if block == nil { @@ -513,7 +447,7 @@ func (s *service) IssueFromCSR(ctx context.Context, entityID, ttl string, csr CS return Certificate{}, errors.Wrap(ErrMalformedEntity, err) } - privKey, err := extractPrivateKey(csr.PrivateKey) + privKey, err := ExtractPrivateKey(csr.PrivateKey) if err != nil { return Certificate{}, errors.Wrap(ErrMalformedEntity, err) } @@ -858,7 +792,7 @@ func (s *service) loadCACerts(ctx context.Context) error { return nil } -func extractPrivateKey(pemKey []byte) (any, error) { +func ExtractPrivateKey(pemKey []byte) (any, error) { block, _ := pem.Decode(pemKey) if block == nil { return nil, errors.New("failed to parse private key PEM") diff --git a/tracing/certs.go b/tracing/certs.go index fc34651..dd43064 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -101,12 +101,6 @@ func (tm *tracingMiddleware) GetChainCA(ctx context.Context, token string) (cert return tm.svc.GetChainCA(ctx, token) } -func (tm *tracingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey any) (certs.CSR, error) { - ctx, span := tm.tracer.Start(ctx, "create_csr") - defer span.End() - return tm.svc.CreateCSR(ctx, metadata, privKey) -} - func (tm *tracingMiddleware) IssueFromCSR(ctx context.Context, entityID, ttl string, csr certs.CSR) (certs.Certificate, error) { ctx, span := tm.tracer.Start(ctx, "issue_from_csr") defer span.End() From 2421ea4b415c69f8c4c7844f5d082cee9db5844b Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 2 Dec 2024 15:46:39 +0300 Subject: [PATCH 13/22] fix failing linter Signed-off-by: nyagamunene --- api/http/requests.go | 17 ------------- api/http/responses.go | 16 ------------ mocks/service.go | 58 ------------------------------------------- 3 files changed, 91 deletions(-) diff --git a/api/http/requests.go b/api/http/requests.go index a6c5b31..948b6e6 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -88,23 +88,6 @@ func (req ocspReq) validate() error { return nil } -type createCSRReq struct { - Metadata certs.CSRMetadata `json:"metadata"` - PrivateKey string `json:"private_key"` - privKey any -} - -func (req createCSRReq) validate() error { - if req.Metadata.CommonName == "" { - return errors.Wrap(certs.ErrMalformedEntity, ErrMissingCN) - } - - if len(req.PrivateKey) == 0 { - return errors.Wrap(certs.ErrMalformedEntity, ErrMissingPrivKey) - } - return nil -} - type IssueFromCSRReq struct { entityID string ttl string diff --git a/api/http/responses.go b/api/http/responses.go index ae401fc..a21db40 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -202,22 +202,6 @@ type fileDownloadRes struct { ContentType string } -type createCSRRes struct { - CSR string `json:"csr"` -} - -func (res createCSRRes) Code() int { - return http.StatusCreated -} - -func (res createCSRRes) Headers() map[string]string { - return map[string]string{} -} - -func (res createCSRRes) Empty() bool { - return false -} - type issueFromCSRRes struct { SerialNumber string `json:"serial_number"` Certificate string `json:"certificate,omitempty"` diff --git a/mocks/service.go b/mocks/service.go index e805c6e..61c8d66 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -28,64 +28,6 @@ func (_m *MockService) EXPECT() *MockService_Expecter { return &MockService_Expecter{mock: &_m.Mock} } -// CreateCSR provides a mock function with given fields: ctx, metadata, privKey -func (_m *MockService) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey interface{}) (certs.CSR, error) { - ret := _m.Called(ctx, metadata, privKey) - - if len(ret) == 0 { - panic("no return value specified for CreateCSR") - } - - var r0 certs.CSR - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, interface{}) (certs.CSR, error)); ok { - return rf(ctx, metadata, privKey) - } - if rf, ok := ret.Get(0).(func(context.Context, certs.CSRMetadata, interface{}) certs.CSR); ok { - r0 = rf(ctx, metadata, privKey) - } else { - r0 = ret.Get(0).(certs.CSR) - } - - if rf, ok := ret.Get(1).(func(context.Context, certs.CSRMetadata, interface{}) error); ok { - r1 = rf(ctx, metadata, privKey) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockService_CreateCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCSR' -type MockService_CreateCSR_Call struct { - *mock.Call -} - -// CreateCSR is a helper method to define mock.On call -// - ctx context.Context -// - metadata certs.CSRMetadata -// - privKey interface{} -func (_e *MockService_Expecter) CreateCSR(ctx interface{}, metadata interface{}, privKey interface{}) *MockService_CreateCSR_Call { - return &MockService_CreateCSR_Call{Call: _e.mock.On("CreateCSR", ctx, metadata, privKey)} -} - -func (_c *MockService_CreateCSR_Call) Run(run func(ctx context.Context, metadata certs.CSRMetadata, privKey interface{})) *MockService_CreateCSR_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(certs.CSRMetadata), args[2].(interface{})) - }) - return _c -} - -func (_c *MockService_CreateCSR_Call) Return(_a0 certs.CSR, _a1 error) *MockService_CreateCSR_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockService_CreateCSR_Call) RunAndReturn(run func(context.Context, certs.CSRMetadata, interface{}) (certs.CSR, error)) *MockService_CreateCSR_Call { - _c.Call.Return(run) - return _c -} - // GenerateCRL provides a mock function with given fields: ctx, caType func (_m *MockService) GenerateCRL(ctx context.Context, caType certs.CertType) ([]byte, error) { ret := _m.Called(ctx, caType) From 2dbf0459887f0f4d2943ad3c1c8a63cd29dc0f3e Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 2 Dec 2024 19:41:40 +0300 Subject: [PATCH 14/22] refactor issue cert method Signed-off-by: nyagamunene --- api/http/endpoint.go | 2 +- api/http/requests.go | 7 ++- cli/certs.go | 44 +++++++++++++---- sdk/sdk.go | 75 ++--------------------------- service.go | 110 ++++++++++++++++++++----------------------- 5 files changed, 91 insertions(+), 147 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index ae1a30c..72620ad 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -317,7 +317,7 @@ func issueFromCSREndpoint(svc certs.Service) endpoint.Endpoint { return issueFromCSRRes{}, err } - cert, err := svc.IssueFromCSR(ctx, req.entityID, req.ttl, certs.CSR{CSR: []byte(req.CSR), PrivateKey: []byte(req.PrivateKey)}) + cert, err := svc.IssueFromCSR(ctx, req.entityID, req.ttl, certs.CSR{CSR: []byte(req.CSR)}) if err != nil { return issueFromCSRRes{}, err } diff --git a/api/http/requests.go b/api/http/requests.go index 948b6e6..e540706 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -89,10 +89,9 @@ func (req ocspReq) validate() error { } type IssueFromCSRReq struct { - entityID string - ttl string - CSR string `json:"csr"` - PrivateKey string `json:"private_key"` + entityID string + ttl string + CSR string `json:"csr"` } func (req IssueFromCSRReq) validate() error { diff --git a/cli/certs.go b/cli/certs.go index 1ffca49..faedad6 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -283,11 +283,11 @@ var cmdCerts = []cobra.Command{ }, }, { - Use: "issue-csr ", + Use: "issue-csr ", Short: "Issue from CSR", Long: `issues a certificate for a given csr.`, Run: func(cmd *cobra.Command, args []string) { - if len(args) != 4 { + if len(args) != 3 { logUsageCmd(*cmd, cmd.Use) return } @@ -298,13 +298,7 @@ var cmdCerts = []cobra.Command{ return } - privData, err := os.ReadFile(args[3]) - if err != nil { - logErrorCmd(*cmd, err) - return - } - - cert, err := sdk.IssueFromCSR(args[0], args[1], string(csrData), string(privData)) + cert, err := sdk.IssueFromCSR(args[0], args[1], string(csrData)) if err != nil { logErrorCmd(*cmd, err) return @@ -402,7 +396,7 @@ func CreateCSR(metadata certs.CSRMetadata, privKey any) (certs.CSR, errors.SDKEr case ed25519.PrivateKey: signer = key case []byte: - parsedKey, err := certs.ExtractPrivateKey(key) + parsedKey, err := extractPrivateKey(key) if err != nil { return certs.CSR{}, errors.NewSDKError(errors.Wrap(ErrCreateEntity, err)) } @@ -427,3 +421,33 @@ func CreateCSR(metadata certs.CSRMetadata, privKey any) (certs.CSR, errors.SDKEr return csr, nil } + +func extractPrivateKey(pemKey []byte) (any, error) { + block, _ := pem.Decode(pemKey) + if block == nil { + return nil, errors.New("failed to parse private key PEM") + } + + var ( + privateKey any + err error + ) + + switch block.Type { + case certs.RSAPrivateKey: + privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) + case certs.ECPrivateKey: + privateKey, err = x509.ParseECPrivateKey(block.Bytes) + case certs.PrivateKey, "PKCS8 PRIVATE KEY": + privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) + case "ED25519 PRIVATE KEY": + privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) + default: + err = errors.New("unsupported private key type") + } + if err != nil { + return nil, errors.New("failed to parse key") + } + + return privateKey, nil +} \ No newline at end of file diff --git a/sdk/sdk.go b/sdk/sdk.go index 5f371aa..003698e 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -264,20 +264,12 @@ type SDK interface { // fmt.Println(response) GetCAToken() (Token, errors.SDKError) - // CreateCSR creates a new Certificate Signing Request - // - // example: - // pm = sdk.CSRMetadata{CommonName: "common_name"} - // response, _ := sdk.CreateCSR(pm, "privKey") - // fmt.Println(response) - CreateCSR(pm PageMetadata, privKey string) (CSR, errors.SDKError) - // IssueFromCSR issues certificate from provided CSR // // example: - // certs, err := sdk.IssueFromCSR( "entityID", "ttl", "csrFile", "privKey") + // certs, err := sdk.IssueFromCSR( "entityID", "ttl", "csrFile") // fmt.Println(err) - IssueFromCSR(entityID, ttl string, csr, privKey string) (Certificate, errors.SDKError) + IssueFromCSR(entityID, ttl string, csr string) (Certificate, errors.SDKError) } func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (Certificate, errors.SDKError) { @@ -546,59 +538,13 @@ func (sdk mgSDK) GetCAToken() (Token, errors.SDKError) { return tk, nil } -func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey string) (CSR, errors.SDKError) { - r := csrReq{ - Metadata: meta{ - CommonName: pm.CommonName, - Organization: pm.Organization, - OrganizationalUnit: pm.OrganizationalUnit, - Country: pm.Country, - Province: pm.Province, - Locality: pm.Locality, - StreetAddress: pm.StreetAddress, - PostalCode: pm.PostalCode, - DNSNames: pm.DNSNames, - IPAddresses: pm.IPAddresses, - EmailAddresses: pm.EmailAddresses, - }, - PrivateKey: privKey, - } - d, err := json.Marshal(r) - if err != nil { - return CSR{}, errors.NewSDKError(err) - } - url := fmt.Sprintf("%s/%s/%s", sdk.certsURL, certsEndpoint, csrEndpoint) - _, body, sdkerr := sdk.processRequest(http.MethodPost, url, d, nil, http.StatusOK) - if sdkerr != nil { - return CSR{}, errors.NewSDKError(err) - } - - zipReader, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) - if err != nil { - return CSR{}, errors.NewSDKError(err) - } - - var csr CSR - for _, file := range zipReader.File { - fileContent, err := readZipFile(file) - if err != nil { - return CSR{}, errors.NewSDKError(err) - } - - csr.CSR = fileContent - } - - return csr, nil -} - -func (sdk mgSDK) IssueFromCSR(entityID, ttl string, csr, privKey string) (Certificate, errors.SDKError) { +func (sdk mgSDK) IssueFromCSR(entityID, ttl string, csr string) (Certificate, errors.SDKError) { pm := PageMetadata{ TTL: ttl, } r := csrReq{ CSR: csr, - PrivateKey: privKey, } d, err := json.Marshal(r) @@ -732,21 +678,6 @@ type certReq struct { } type csrReq struct { - Metadata meta `json:"metadata,omitempty"` - PrivateKey string `json:"private_key,omitempty"` CSR string `json:"csr,omitempty"` } -type meta struct { - CommonName string `json:"common_name"` - Organization []string `json:"organization"` - OrganizationalUnit []string `json:"organizational_unit"` - Country []string `json:"country"` - Province []string `json:"province"` - Locality []string `json:"locality"` - StreetAddress []string `json:"street_address"` - PostalCode []string `json:"postal_code"` - DNSNames []string `json:"dns_names"` - IPAddresses []string `json:"ip_addresses"` - EmailAddresses []string `json:"email_addresses"` -} diff --git a/service.go b/service.go index 8474dbe..2dff11d 100644 --- a/service.go +++ b/service.go @@ -28,7 +28,7 @@ const ( PrivateKeyBytes = 2048 RootCAValidityPeriod = time.Hour * 24 * 365 // 365 days IntermediateCAVAlidityPeriod = time.Hour * 24 * 90 // 90 days - certValidityPeriod = time.Hour * 24 * 90 // 30 days + certValidityPeriod = time.Hour * 24 * 30 // 30 days rCertExpiryThreshold = time.Hour * 24 * 30 // 30 days iCertExpiryThreshold = time.Hour * 24 * 10 // 10 days downloadTokenExpiry = time.Minute * 5 @@ -86,6 +86,11 @@ func NewService(ctx context.Context, repo Repository, config *Config) (Service, return &svc, nil } +// issueCert generates and issues a certificate for a given backendID. +// It uses the RSA algorithm to generate a private key, and then creates a certificate +// using the provided template and the generated private key. +// The certificate is then stored in the repository using the CreateCert method. +// If the root CA is not found, it returns an error. // issueCert generates and issues a certificate for a given backendID. // It uses the RSA algorithm to generate a private key, and then creates a certificate // using the provided template and the generated private key. @@ -93,16 +98,35 @@ func NewService(ctx context.Context, repo Repository, config *Config) (Service, // If the root CA is not found, it returns an error. func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key ...any) (Certificate, error) { var privKey any + var pubKey crypto.PublicKey var err error + if len(key) == 0 { pKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) - privKey = pKey if err != nil { return Certificate{}, err } + privKey = pKey + pubKey = pKey.Public() } else { - privKey = key[0] + switch k := key[0].(type) { + case *rsa.PrivateKey: + privKey = k + pubKey = k.Public() + case *ecdsa.PrivateKey: + privKey = k + pubKey = k.Public() + case ed25519.PrivateKey: + privKey = k + pubKey = k.Public() + case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: + pubKey = k + privKey = nil + default: + return Certificate{}, errors.Wrap(ErrCreateEntity, errors.New("unsupported key type")) + } } + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { return Certificate{}, err @@ -135,43 +159,44 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ DNSNames: append(s.intermediateCA.Certificate.DNSNames, ipAddrs...), } - var pubKey crypto.PublicKey var privKeyBytes []byte var privKeyType string - switch key := privKey.(type) { - case *rsa.PrivateKey: - pubKey = key.Public() - privKeyBytes = x509.MarshalPKCS1PrivateKey(key) - privKeyType = RSAPrivateKey - case *ecdsa.PrivateKey: - pubKey = key.Public() - privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) - privKeyType = ECPrivateKey - case ed25519.PrivateKey: - pubKey = key.Public() - privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) - privKeyType = PrivateKey - default: - return Certificate{}, errors.Wrap(ErrCreateEntity, errors.New("unsupported private key type")) - } + if privKey != nil { + switch key := privKey.(type) { + case *rsa.PrivateKey: + privKeyBytes = x509.MarshalPKCS1PrivateKey(key) + privKeyType = RSAPrivateKey + case *ecdsa.PrivateKey: + privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) + privKeyType = ECPrivateKey + case ed25519.PrivateKey: + privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key) + privKeyType = PrivateKey + } - if err != nil { - return Certificate{}, err + if err != nil { + return Certificate{}, err + } } certBytes, err := x509.CreateCertificate(rand.Reader, &template, s.intermediateCA.Certificate, pubKey, s.intermediateCA.PrivateKey) if err != nil { return Certificate{}, err } + dbCert := Certificate{ - Key: pem.EncodeToMemory(&pem.Block{Type: privKeyType, Bytes: privKeyBytes}), - Certificate: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}), SerialNumber: template.SerialNumber.String(), EntityID: entityID, ExpiryTime: template.NotAfter, Type: ClientCert, + Certificate: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}), + } + + if privKeyBytes != nil { + dbCert.Key = pem.EncodeToMemory(&pem.Block{Type: privKeyType, Bytes: privKeyBytes}) } + if err = s.repo.CreateCert(ctx, dbCert); err != nil { return Certificate{}, errors.Wrap(ErrCreateEntity, err) } @@ -447,11 +472,6 @@ func (s *service) IssueFromCSR(ctx context.Context, entityID, ttl string, csr CS return Certificate{}, errors.Wrap(ErrMalformedEntity, err) } - privKey, err := ExtractPrivateKey(csr.PrivateKey) - if err != nil { - return Certificate{}, errors.Wrap(ErrMalformedEntity, err) - } - cert, err := s.IssueCert(ctx, entityID, ttl, nil, SubjectOptions{ CommonName: parsedCSR.Subject.CommonName, Organization: parsedCSR.Subject.Organization, @@ -461,7 +481,7 @@ func (s *service) IssueFromCSR(ctx context.Context, entityID, ttl string, csr CS Locality: parsedCSR.Subject.Locality, StreetAddress: parsedCSR.Subject.StreetAddress, PostalCode: parsedCSR.Subject.PostalCode, - }, privKey) + }, parsedCSR.PublicKey) if err != nil { return Certificate{}, errors.Wrap(ErrCreateEntity, err) } @@ -791,33 +811,3 @@ func (s *service) loadCACerts(ctx context.Context) error { } return nil } - -func ExtractPrivateKey(pemKey []byte) (any, error) { - block, _ := pem.Decode(pemKey) - if block == nil { - return nil, errors.New("failed to parse private key PEM") - } - - var ( - privateKey any - err error - ) - - switch block.Type { - case RSAPrivateKey: - privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) - case ECPrivateKey: - privateKey, err = x509.ParseECPrivateKey(block.Bytes) - case PrivateKey, "PKCS8 PRIVATE KEY": - privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) - case "ED25519 PRIVATE KEY": - privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) - default: - err = errors.New("unsupported private key type") - } - if err != nil { - return nil, errors.New("failed to parse key") - } - - return privateKey, nil -} From 44e5e41672dc457c2aaed5e296bf7d0c5ab2c91a Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 2 Dec 2024 19:43:55 +0300 Subject: [PATCH 15/22] fix failing linter Signed-off-by: nyagamunene --- sdk/mocks/sdk.go | 88 ++++++++---------------------------------------- 1 file changed, 14 insertions(+), 74 deletions(-) diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index 210102d..7266cca 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -25,65 +25,6 @@ func (_m *MockSDK) EXPECT() *MockSDK_Expecter { return &MockSDK_Expecter{mock: &_m.Mock} } -// CreateCSR provides a mock function with given fields: pm, privKey -func (_m *MockSDK) CreateCSR(pm sdk.PageMetadata, privKey string) (sdk.CSR, errors.SDKError) { - ret := _m.Called(pm, privKey) - - if len(ret) == 0 { - panic("no return value specified for CreateCSR") - } - - var r0 sdk.CSR - var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) (sdk.CSR, errors.SDKError)); ok { - return rf(pm, privKey) - } - if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) sdk.CSR); ok { - r0 = rf(pm, privKey) - } else { - r0 = ret.Get(0).(sdk.CSR) - } - - if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string) errors.SDKError); ok { - r1 = rf(pm, privKey) - } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).(errors.SDKError) - } - } - - return r0, r1 -} - -// MockSDK_CreateCSR_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateCSR' -type MockSDK_CreateCSR_Call struct { - *mock.Call -} - -// CreateCSR is a helper method to define mock.On call -// - pm sdk.PageMetadata -// - privKey string -func (_e *MockSDK_Expecter) CreateCSR(pm interface{}, privKey interface{}) *MockSDK_CreateCSR_Call { - return &MockSDK_CreateCSR_Call{Call: _e.mock.On("CreateCSR", pm, privKey)} -} - -func (_c *MockSDK_CreateCSR_Call) Run(run func(pm sdk.PageMetadata, privKey string)) *MockSDK_CreateCSR_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.PageMetadata), args[1].(string)) - }) - return _c -} - -func (_c *MockSDK_CreateCSR_Call) Return(_a0 sdk.CSR, _a1 errors.SDKError) *MockSDK_CreateCSR_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockSDK_CreateCSR_Call) RunAndReturn(run func(sdk.PageMetadata, string) (sdk.CSR, errors.SDKError)) *MockSDK_CreateCSR_Call { - _c.Call.Return(run) - return _c -} - // DeleteCert provides a mock function with given fields: entityID func (_m *MockSDK) DeleteCert(entityID string) errors.SDKError { ret := _m.Called(entityID) @@ -367,9 +308,9 @@ func (_c *MockSDK_IssueCert_Call) RunAndReturn(run func(string, string, []string return _c } -// IssueFromCSR provides a mock function with given fields: entityID, ttl, csr, privKey -func (_m *MockSDK) IssueFromCSR(entityID string, ttl string, csr string, privKey string) (sdk.Certificate, errors.SDKError) { - ret := _m.Called(entityID, ttl, csr, privKey) +// IssueFromCSR provides a mock function with given fields: entityID, ttl, csr +func (_m *MockSDK) IssueFromCSR(entityID string, ttl string, csr string) (sdk.Certificate, errors.SDKError) { + ret := _m.Called(entityID, ttl, csr) if len(ret) == 0 { panic("no return value specified for IssueFromCSR") @@ -377,17 +318,17 @@ func (_m *MockSDK) IssueFromCSR(entityID string, ttl string, csr string, privKey var r0 sdk.Certificate var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(string, string, string, string) (sdk.Certificate, errors.SDKError)); ok { - return rf(entityID, ttl, csr, privKey) + if rf, ok := ret.Get(0).(func(string, string, string) (sdk.Certificate, errors.SDKError)); ok { + return rf(entityID, ttl, csr) } - if rf, ok := ret.Get(0).(func(string, string, string, string) sdk.Certificate); ok { - r0 = rf(entityID, ttl, csr, privKey) + if rf, ok := ret.Get(0).(func(string, string, string) sdk.Certificate); ok { + r0 = rf(entityID, ttl, csr) } else { r0 = ret.Get(0).(sdk.Certificate) } - if rf, ok := ret.Get(1).(func(string, string, string, string) errors.SDKError); ok { - r1 = rf(entityID, ttl, csr, privKey) + if rf, ok := ret.Get(1).(func(string, string, string) errors.SDKError); ok { + r1 = rf(entityID, ttl, csr) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -406,14 +347,13 @@ type MockSDK_IssueFromCSR_Call struct { // - entityID string // - ttl string // - csr string -// - privKey string -func (_e *MockSDK_Expecter) IssueFromCSR(entityID interface{}, ttl interface{}, csr interface{}, privKey interface{}) *MockSDK_IssueFromCSR_Call { - return &MockSDK_IssueFromCSR_Call{Call: _e.mock.On("IssueFromCSR", entityID, ttl, csr, privKey)} +func (_e *MockSDK_Expecter) IssueFromCSR(entityID interface{}, ttl interface{}, csr interface{}) *MockSDK_IssueFromCSR_Call { + return &MockSDK_IssueFromCSR_Call{Call: _e.mock.On("IssueFromCSR", entityID, ttl, csr)} } -func (_c *MockSDK_IssueFromCSR_Call) Run(run func(entityID string, ttl string, csr string, privKey string)) *MockSDK_IssueFromCSR_Call { +func (_c *MockSDK_IssueFromCSR_Call) Run(run func(entityID string, ttl string, csr string)) *MockSDK_IssueFromCSR_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(string), args[3].(string)) + run(args[0].(string), args[1].(string), args[2].(string)) }) return _c } @@ -423,7 +363,7 @@ func (_c *MockSDK_IssueFromCSR_Call) Return(_a0 sdk.Certificate, _a1 errors.SDKE return _c } -func (_c *MockSDK_IssueFromCSR_Call) RunAndReturn(run func(string, string, string, string) (sdk.Certificate, errors.SDKError)) *MockSDK_IssueFromCSR_Call { +func (_c *MockSDK_IssueFromCSR_Call) RunAndReturn(run func(string, string, string) (sdk.Certificate, errors.SDKError)) *MockSDK_IssueFromCSR_Call { _c.Call.Return(run) return _c } From 01eb0bb0483060e3e5efffc5f4ad652e197eb4b1 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 2 Dec 2024 19:51:27 +0300 Subject: [PATCH 16/22] fix failing linter Signed-off-by: nyagamunene --- cli/certs.go | 2 +- sdk/sdk.go | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/cli/certs.go b/cli/certs.go index faedad6..d6caed5 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -450,4 +450,4 @@ func extractPrivateKey(pemKey []byte) (any, error) { } return privateKey, nil -} \ No newline at end of file +} diff --git a/sdk/sdk.go b/sdk/sdk.go index 003698e..3ffe060 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -544,7 +544,7 @@ func (sdk mgSDK) IssueFromCSR(entityID, ttl string, csr string) (Certificate, er } r := csrReq{ - CSR: csr, + CSR: csr, } d, err := json.Marshal(r) @@ -678,6 +678,5 @@ type certReq struct { } type csrReq struct { - CSR string `json:"csr,omitempty"` + CSR string `json:"csr,omitempty"` } - From c412393650639660aead28173909b229b9d6a517 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Sun, 8 Dec 2024 21:02:09 +0300 Subject: [PATCH 17/22] address comments Signed-off-by: nyagamunene --- cli/certs.go | 22 +++++++++++----------- service.go | 2 ++ 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/cli/certs.go b/cli/certs.go index d6caed5..e416488 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -25,7 +25,11 @@ import ( // Keep SDK handle in global var. var sdk ctxsdk.SDK -var ErrCreateEntity = errors.New("failed to create entity") +var ( + ErrCreateEntity = errors.New("failed to create entity") + ErrPrivKeyType = errors.New("unsupported private key type") + ErrFailedParse = errors.New("failed to parse key") +) func SetSDK(s ctxsdk.SDK) { sdk = s @@ -389,10 +393,8 @@ func CreateCSR(metadata certs.CSRMetadata, privKey any) (certs.CSR, errors.SDKEr var err error switch key := privKey.(type) { - case *rsa.PrivateKey: - signer = key - case *ecdsa.PrivateKey: - signer = key + case *rsa.PrivateKey, *ecdsa.PrivateKey: + signer = key.(crypto.Signer) case ed25519.PrivateKey: signer = key case []byte: @@ -402,7 +404,7 @@ func CreateCSR(metadata certs.CSRMetadata, privKey any) (certs.CSR, errors.SDKEr } return CreateCSR(metadata, parsedKey) default: - return certs.CSR{}, errors.NewSDKError(errors.Wrap(ErrCreateEntity, errors.New("unsupported private key type"))) + return certs.CSR{}, errors.NewSDKError(errors.Wrap(ErrCreateEntity, ErrPrivKeyType)) } csrBytes, err := x509.CreateCertificateRequest(rand.Reader, template, signer) @@ -438,15 +440,13 @@ func extractPrivateKey(pemKey []byte) (any, error) { privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes) case certs.ECPrivateKey: privateKey, err = x509.ParseECPrivateKey(block.Bytes) - case certs.PrivateKey, "PKCS8 PRIVATE KEY": - privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) - case "ED25519 PRIVATE KEY": + case certs.PrivateKey, certs.PKCS8PrivateKey, certs.EDPrivateKey: privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) default: - err = errors.New("unsupported private key type") + err = ErrPrivKeyType } if err != nil { - return nil, errors.New("failed to parse key") + return nil, ErrFailedParse } return privateKey, nil diff --git a/service.go b/service.go index 2dff11d..5aacb0d 100644 --- a/service.go +++ b/service.go @@ -35,6 +35,8 @@ const ( PrivateKey = "PRIVATE KEY" RSAPrivateKey = "RSA PRIVATE KEY" ECPrivateKey = "EC PRIVATE KEY" + PKCS8PrivateKey = "PKCS8 PRIVATE KEY" + EDPrivateKey = "ED25519 PRIVATE KEY" ) var ( From 024409ca9968922675a31ba99be925ce2373a5bd Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 9 Dec 2024 12:16:55 +0300 Subject: [PATCH 18/22] change variable from any to crypto.PrivateKey Signed-off-by: nyagamunene --- api/logging.go | 3 ++- api/metrics.go | 3 ++- certs.go | 3 ++- mocks/service.go | 26 ++++++++++++++++---------- service.go | 4 ++-- tracing/certs.go | 3 ++- 6 files changed, 26 insertions(+), 16 deletions(-) diff --git a/api/logging.go b/api/logging.go index 45b7f32..f92125f 100644 --- a/api/logging.go +++ b/api/logging.go @@ -5,6 +5,7 @@ package api import ( "context" + "crypto" "crypto/x509" "fmt" "log/slog" @@ -85,7 +86,7 @@ func (lm *loggingMiddleware) RetrieveCAToken(ctx context.Context) (tokenString s return lm.svc.RetrieveCAToken(ctx) } -func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...any) (cert certs.Certificate, err error) { +func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...crypto.PrivateKey) (cert certs.Certificate, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin)) if err != nil { diff --git a/api/metrics.go b/api/metrics.go index fb3e70e..65cc582 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -5,6 +5,7 @@ package api import ( "context" + "crypto" "crypto/x509" "time" @@ -71,7 +72,7 @@ func (mm *metricsMiddleware) RetrieveCAToken(ctx context.Context) (string, error return mm.svc.RetrieveCAToken(ctx) } -func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...any) (certs.Certificate, error) { +func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...crypto.PrivateKey) (certs.Certificate, error) { defer func(begin time.Time) { mm.counter.With("method", "issue_certificate").Add(1) mm.latency.With("method", "issue_certificate").Observe(time.Since(begin).Seconds()) diff --git a/certs.go b/certs.go index f32eac9..ac1e586 100644 --- a/certs.go +++ b/certs.go @@ -5,6 +5,7 @@ package certs import ( "context" + "crypto" "crypto/rsa" "crypto/x509" "net" @@ -158,7 +159,7 @@ type Service interface { RetrieveCAToken(ctx context.Context) (string, error) // IssueCert issues a certificate from the database. - IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions, privKey ...any) (Certificate, error) + IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions, privKey ...crypto.PrivateKey) (Certificate, error) // OCSP retrieves the OCSP response for a certificate. OCSP(ctx context.Context, serialNumber string) (*Certificate, int, *x509.Certificate, error) diff --git a/mocks/service.go b/mocks/service.go index 61c8d66..d675dc7 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -10,6 +10,8 @@ import ( certs "github.com/absmach/certs" + crypto "crypto" + mock "github.com/stretchr/testify/mock" x509 "crypto/x509" @@ -202,10 +204,14 @@ func (_c *MockService_GetEntityID_Call) RunAndReturn(run func(context.Context, s } // IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option, privKey -func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...interface{}) (certs.Certificate, error) { +func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...crypto.PrivateKey) (certs.Certificate, error) { + _va := make([]interface{}, len(privKey)) + for _i := range privKey { + _va[_i] = privKey[_i] + } var _ca []interface{} _ca = append(_ca, ctx, entityID, ttl, ipAddrs, option) - _ca = append(_ca, privKey...) + _ca = append(_ca, _va...) ret := _m.Called(_ca...) if len(ret) == 0 { @@ -214,16 +220,16 @@ func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl strin var r0 certs.Certificate var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...interface{}) (certs.Certificate, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...crypto.PrivateKey) (certs.Certificate, error)); ok { return rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...interface{}) certs.Certificate); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...crypto.PrivateKey) certs.Certificate); ok { r0 = rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } else { r0 = ret.Get(0).(certs.Certificate) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions, ...interface{}) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions, ...crypto.PrivateKey) error); ok { r1 = rf(ctx, entityID, ttl, ipAddrs, option, privKey...) } else { r1 = ret.Error(1) @@ -243,18 +249,18 @@ type MockService_IssueCert_Call struct { // - ttl string // - ipAddrs []string // - option certs.SubjectOptions -// - privKey ...interface{} +// - privKey ...crypto.PrivateKey func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}, privKey ...interface{}) *MockService_IssueCert_Call { return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", append([]interface{}{ctx, entityID, ttl, ipAddrs, option}, privKey...)...)} } -func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...interface{})) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...crypto.PrivateKey)) *MockService_IssueCert_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]interface{}, len(args)-5) + variadicArgs := make([]crypto.PrivateKey, len(args)-5) for i, a := range args[5:] { if a != nil { - variadicArgs[i] = a.(interface{}) + variadicArgs[i] = a.(crypto.PrivateKey) } } run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string), args[4].(certs.SubjectOptions), variadicArgs...) @@ -267,7 +273,7 @@ func (_c *MockService_IssueCert_Call) Return(_a0 certs.Certificate, _a1 error) * return _c } -func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions, ...interface{}) (certs.Certificate, error)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions, ...crypto.PrivateKey) (certs.Certificate, error)) *MockService_IssueCert_Call { _c.Call.Return(run) return _c } diff --git a/service.go b/service.go index 5aacb0d..9a8ba02 100644 --- a/service.go +++ b/service.go @@ -98,8 +98,8 @@ func NewService(ctx context.Context, repo Repository, config *Config) (Service, // using the provided template and the generated private key. // The certificate is then stored in the repository using the CreateCert method. // If the root CA is not found, it returns an error. -func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key ...any) (Certificate, error) { - var privKey any +func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key ...crypto.PrivateKey) (Certificate, error) { + var privKey crypto.PrivateKey var pubKey crypto.PublicKey var err error diff --git a/tracing/certs.go b/tracing/certs.go index dd43064..6de2163 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -5,6 +5,7 @@ package tracing import ( "context" + "crypto" "crypto/x509" "github.com/absmach/certs" @@ -53,7 +54,7 @@ func (tm *tracingMiddleware) RetrieveCAToken(ctx context.Context) (string, error return tm.svc.RetrieveCAToken(ctx) } -func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...any) (certs.Certificate, error) { +func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...crypto.PrivateKey) (certs.Certificate, error) { ctx, span := tm.tracer.Start(ctx, "issue_cert") defer span.End() return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey...) From 98275e79366899ff959ed65e688f37e12116eaf9 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 9 Dec 2024 13:47:46 +0300 Subject: [PATCH 19/22] Refactor issueCerts Signed-off-by: nyagamunene --- api/http/endpoint.go | 2 +- api/logging.go | 4 ++-- api/metrics.go | 4 ++-- certs.go | 2 +- cli/certs.go | 16 +++++----------- service.go | 12 +++++++----- tracing/certs.go | 4 ++-- 7 files changed, 20 insertions(+), 24 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index 72620ad..c1fb611 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -106,7 +106,7 @@ func issueCertEndpoint(svc certs.Service) endpoint.Endpoint { return issueCertRes{}, err } - cert, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs, req.Options) + cert, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs, req.Options, nil) if err != nil { return issueCertRes{}, err } diff --git a/api/logging.go b/api/logging.go index f92125f..83eaf47 100644 --- a/api/logging.go +++ b/api/logging.go @@ -86,7 +86,7 @@ func (lm *loggingMiddleware) RetrieveCAToken(ctx context.Context) (tokenString s return lm.svc.RetrieveCAToken(ctx) } -func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...crypto.PrivateKey) (cert certs.Certificate, err error) { +func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey crypto.PrivateKey) (cert certs.Certificate, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin)) if err != nil { @@ -95,7 +95,7 @@ func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string } lm.logger.Info(message) }(time.Now()) - return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey...) + return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey) } func (lm *loggingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (cp certs.CertificatePage, err error) { diff --git a/api/metrics.go b/api/metrics.go index 65cc582..b50b16d 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -72,12 +72,12 @@ func (mm *metricsMiddleware) RetrieveCAToken(ctx context.Context) (string, error return mm.svc.RetrieveCAToken(ctx) } -func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...crypto.PrivateKey) (certs.Certificate, error) { +func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey crypto.PrivateKey) (certs.Certificate, error) { defer func(begin time.Time) { mm.counter.With("method", "issue_certificate").Add(1) mm.latency.With("method", "issue_certificate").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey...) + return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey) } func (mm *metricsMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { diff --git a/certs.go b/certs.go index ac1e586..327bd6b 100644 --- a/certs.go +++ b/certs.go @@ -159,7 +159,7 @@ type Service interface { RetrieveCAToken(ctx context.Context) (string, error) // IssueCert issues a certificate from the database. - IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions, privKey ...crypto.PrivateKey) (Certificate, error) + IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions, privKey crypto.PrivateKey) (Certificate, error) // OCSP retrieves the OCSP response for a certificate. OCSP(ctx context.Context, serialNumber string) (*Certificate, int, *x509.Certificate, error) diff --git a/cli/certs.go b/cli/certs.go index e416488..fd264df 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -25,12 +25,6 @@ import ( // Keep SDK handle in global var. var sdk ctxsdk.SDK -var ( - ErrCreateEntity = errors.New("failed to create entity") - ErrPrivKeyType = errors.New("unsupported private key type") - ErrFailedParse = errors.New("failed to parse key") -) - func SetSDK(s ctxsdk.SDK) { sdk = s } @@ -400,16 +394,16 @@ func CreateCSR(metadata certs.CSRMetadata, privKey any) (certs.CSR, errors.SDKEr case []byte: parsedKey, err := extractPrivateKey(key) if err != nil { - return certs.CSR{}, errors.NewSDKError(errors.Wrap(ErrCreateEntity, err)) + return certs.CSR{}, errors.NewSDKError(errors.Wrap(certs.ErrCreateEntity, err)) } return CreateCSR(metadata, parsedKey) default: - return certs.CSR{}, errors.NewSDKError(errors.Wrap(ErrCreateEntity, ErrPrivKeyType)) + return certs.CSR{}, errors.NewSDKError(errors.Wrap(certs.ErrCreateEntity, certs.ErrPrivKeyType)) } csrBytes, err := x509.CreateCertificateRequest(rand.Reader, template, signer) if err != nil { - return certs.CSR{}, errors.NewSDKError(errors.Wrap(ErrCreateEntity, err)) + return certs.CSR{}, errors.NewSDKError(errors.Wrap(certs.ErrCreateEntity, err)) } csrPEM := pem.EncodeToMemory(&pem.Block{ @@ -443,10 +437,10 @@ func extractPrivateKey(pemKey []byte) (any, error) { case certs.PrivateKey, certs.PKCS8PrivateKey, certs.EDPrivateKey: privateKey, err = x509.ParsePKCS8PrivateKey(block.Bytes) default: - err = ErrPrivKeyType + err = certs.ErrPrivKeyType } if err != nil { - return nil, ErrFailedParse + return nil, certs.ErrFailedParse } return privateKey, nil diff --git a/service.go b/service.go index 9a8ba02..4950a14 100644 --- a/service.go +++ b/service.go @@ -54,6 +54,8 @@ var ( ErrCertRevoked = errors.New("certificate has been revoked and cannot be renewed") ErrCertInvalidType = errors.New("invalid cert type") ErrInvalidLength = errors.New("invalid length of serial numbers") + ErrPrivKeyType = errors.New("unsupported private key type") + ErrFailedParse = errors.New("failed to parse key PEM") ) type service struct { @@ -98,12 +100,12 @@ func NewService(ctx context.Context, repo Repository, config *Config) (Service, // using the provided template and the generated private key. // The certificate is then stored in the repository using the CreateCert method. // If the root CA is not found, it returns an error. -func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key ...crypto.PrivateKey) (Certificate, error) { +func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key crypto.PrivateKey) (Certificate, error) { var privKey crypto.PrivateKey var pubKey crypto.PublicKey var err error - if len(key) == 0 { + if key == nil { pKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) if err != nil { return Certificate{}, err @@ -111,7 +113,7 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ privKey = pKey pubKey = pKey.Public() } else { - switch k := key[0].(type) { + switch k := key.(type) { case *rsa.PrivateKey: privKey = k pubKey = k.Public() @@ -772,7 +774,7 @@ func (s *service) loadCACerts(ctx context.Context) error { } rkey, _ := pem.Decode(c.Key) if rkey == nil { - return errors.New("failed to parse key PEM") + return ErrFailedParse } rootKey, err := x509.ParsePKCS1PrivateKey(rkey.Bytes) if err != nil { @@ -797,7 +799,7 @@ func (s *service) loadCACerts(ctx context.Context) error { } ikey, _ := pem.Decode(c.Key) if ikey == nil { - return errors.New("failed to parse key PEM") + return ErrFailedParse } interKey, err := x509.ParsePKCS1PrivateKey(ikey.Bytes) if err != nil { diff --git a/tracing/certs.go b/tracing/certs.go index 6de2163..5ba671c 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -54,10 +54,10 @@ func (tm *tracingMiddleware) RetrieveCAToken(ctx context.Context) (string, error return tm.svc.RetrieveCAToken(ctx) } -func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey ...crypto.PrivateKey) (certs.Certificate, error) { +func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey crypto.PrivateKey) (certs.Certificate, error) { ctx, span := tm.tracer.Start(ctx, "issue_cert") defer span.End() - return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey...) + return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey) } func (tm *tracingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { From 85dedd2ce2aaecd488722009b14a8721fc58a534 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 9 Dec 2024 13:50:29 +0300 Subject: [PATCH 20/22] fix failing linter Signed-off-by: nyagamunene --- certs_test.go | 2 +- mocks/service.go | 42 ++++++++++++++---------------------------- 2 files changed, 15 insertions(+), 29 deletions(-) diff --git a/certs_test.go b/certs_test.go index 28616e2..f9e9200 100644 --- a/certs_test.go +++ b/certs_test.go @@ -67,7 +67,7 @@ func TestIssueCert(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(tc.err) - _, err = svc.IssueCert(context.Background(), tc.backendId, tc.ttl, []string{}, certs.SubjectOptions{}) + _, err = svc.IssueCert(context.Background(), tc.backendId, tc.ttl, []string{}, certs.SubjectOptions{}, nil) require.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) repoCall1.Unset() }) diff --git a/mocks/service.go b/mocks/service.go index d675dc7..7460f4e 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -204,15 +204,8 @@ func (_c *MockService_GetEntityID_Call) RunAndReturn(run func(context.Context, s } // IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option, privKey -func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...crypto.PrivateKey) (certs.Certificate, error) { - _va := make([]interface{}, len(privKey)) - for _i := range privKey { - _va[_i] = privKey[_i] - } - var _ca []interface{} - _ca = append(_ca, ctx, entityID, ttl, ipAddrs, option) - _ca = append(_ca, _va...) - ret := _m.Called(_ca...) +func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey crypto.PrivateKey) (certs.Certificate, error) { + ret := _m.Called(ctx, entityID, ttl, ipAddrs, option, privKey) if len(ret) == 0 { panic("no return value specified for IssueCert") @@ -220,17 +213,17 @@ func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl strin var r0 certs.Certificate var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...crypto.PrivateKey) (certs.Certificate, error)); ok { - return rf(ctx, entityID, ttl, ipAddrs, option, privKey...) + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, crypto.PrivateKey) (certs.Certificate, error)); ok { + return rf(ctx, entityID, ttl, ipAddrs, option, privKey) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, ...crypto.PrivateKey) certs.Certificate); ok { - r0 = rf(ctx, entityID, ttl, ipAddrs, option, privKey...) + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, crypto.PrivateKey) certs.Certificate); ok { + r0 = rf(ctx, entityID, ttl, ipAddrs, option, privKey) } else { r0 = ret.Get(0).(certs.Certificate) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions, ...crypto.PrivateKey) error); ok { - r1 = rf(ctx, entityID, ttl, ipAddrs, option, privKey...) + if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions, crypto.PrivateKey) error); ok { + r1 = rf(ctx, entityID, ttl, ipAddrs, option, privKey) } else { r1 = ret.Error(1) } @@ -249,21 +242,14 @@ type MockService_IssueCert_Call struct { // - ttl string // - ipAddrs []string // - option certs.SubjectOptions -// - privKey ...crypto.PrivateKey -func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}, privKey ...interface{}) *MockService_IssueCert_Call { - return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", - append([]interface{}{ctx, entityID, ttl, ipAddrs, option}, privKey...)...)} +// - privKey crypto.PrivateKey +func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}, privKey interface{}) *MockService_IssueCert_Call { + return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", ctx, entityID, ttl, ipAddrs, option, privKey)} } -func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey ...crypto.PrivateKey)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey crypto.PrivateKey)) *MockService_IssueCert_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]crypto.PrivateKey, len(args)-5) - for i, a := range args[5:] { - if a != nil { - variadicArgs[i] = a.(crypto.PrivateKey) - } - } - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string), args[4].(certs.SubjectOptions), variadicArgs...) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string), args[4].(certs.SubjectOptions), args[5].(crypto.PrivateKey)) }) return _c } @@ -273,7 +259,7 @@ func (_c *MockService_IssueCert_Call) Return(_a0 certs.Certificate, _a1 error) * return _c } -func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions, ...crypto.PrivateKey) (certs.Certificate, error)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions, crypto.PrivateKey) (certs.Certificate, error)) *MockService_IssueCert_Call { _c.Call.Return(run) return _c } From 11972ae16c28204dccf49360852e0d90e7b43202 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 9 Dec 2024 16:29:45 +0300 Subject: [PATCH 21/22] refactor issueCerts and issueforcsr Signed-off-by: nyagamunene --- api/http/endpoint.go | 2 +- api/logging.go | 5 ++- api/metrics.go | 5 ++- certs.go | 3 +- certs_test.go | 2 +- mocks/service.go | 31 +++++++++---------- service.go | 73 +++++++++++++++++++++----------------------- tracing/certs.go | 5 ++- 8 files changed, 57 insertions(+), 69 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index c1fb611..72620ad 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -106,7 +106,7 @@ func issueCertEndpoint(svc certs.Service) endpoint.Endpoint { return issueCertRes{}, err } - cert, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs, req.Options, nil) + cert, err := svc.IssueCert(ctx, req.entityID, req.TTL, req.IpAddrs, req.Options) if err != nil { return issueCertRes{}, err } diff --git a/api/logging.go b/api/logging.go index 83eaf47..487a903 100644 --- a/api/logging.go +++ b/api/logging.go @@ -5,7 +5,6 @@ package api import ( "context" - "crypto" "crypto/x509" "fmt" "log/slog" @@ -86,7 +85,7 @@ func (lm *loggingMiddleware) RetrieveCAToken(ctx context.Context) (tokenString s return lm.svc.RetrieveCAToken(ctx) } -func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey crypto.PrivateKey) (cert certs.Certificate, err error) { +func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (cert certs.Certificate, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin)) if err != nil { @@ -95,7 +94,7 @@ func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string } lm.logger.Info(message) }(time.Now()) - return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey) + return lm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) } func (lm *loggingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (cp certs.CertificatePage, err error) { diff --git a/api/metrics.go b/api/metrics.go index b50b16d..fe2f4cd 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -5,7 +5,6 @@ package api import ( "context" - "crypto" "crypto/x509" "time" @@ -72,12 +71,12 @@ func (mm *metricsMiddleware) RetrieveCAToken(ctx context.Context) (string, error return mm.svc.RetrieveCAToken(ctx) } -func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey crypto.PrivateKey) (certs.Certificate, error) { +func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { defer func(begin time.Time) { mm.counter.With("method", "issue_certificate").Add(1) mm.latency.With("method", "issue_certificate").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey) + return mm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) } func (mm *metricsMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { diff --git a/certs.go b/certs.go index 327bd6b..3d2ea0e 100644 --- a/certs.go +++ b/certs.go @@ -5,7 +5,6 @@ package certs import ( "context" - "crypto" "crypto/rsa" "crypto/x509" "net" @@ -159,7 +158,7 @@ type Service interface { RetrieveCAToken(ctx context.Context) (string, error) // IssueCert issues a certificate from the database. - IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions, privKey crypto.PrivateKey) (Certificate, error) + IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions) (Certificate, error) // OCSP retrieves the OCSP response for a certificate. OCSP(ctx context.Context, serialNumber string) (*Certificate, int, *x509.Certificate, error) diff --git a/certs_test.go b/certs_test.go index f9e9200..28616e2 100644 --- a/certs_test.go +++ b/certs_test.go @@ -67,7 +67,7 @@ func TestIssueCert(t *testing.T) { t.Run(tc.desc, func(t *testing.T) { repoCall1 := cRepo.On("CreateCert", mock.Anything, mock.Anything).Return(tc.err) - _, err = svc.IssueCert(context.Background(), tc.backendId, tc.ttl, []string{}, certs.SubjectOptions{}, nil) + _, err = svc.IssueCert(context.Background(), tc.backendId, tc.ttl, []string{}, certs.SubjectOptions{}) require.True(t, errors.Contains(err, tc.err), "expected error %v, got %v", tc.err, err) repoCall1.Unset() }) diff --git a/mocks/service.go b/mocks/service.go index 7460f4e..3351649 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -10,8 +10,6 @@ import ( certs "github.com/absmach/certs" - crypto "crypto" - mock "github.com/stretchr/testify/mock" x509 "crypto/x509" @@ -203,9 +201,9 @@ func (_c *MockService_GetEntityID_Call) RunAndReturn(run func(context.Context, s return _c } -// IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option, privKey -func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey crypto.PrivateKey) (certs.Certificate, error) { - ret := _m.Called(ctx, entityID, ttl, ipAddrs, option, privKey) +// IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option +func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions) (certs.Certificate, error) { + ret := _m.Called(ctx, entityID, ttl, ipAddrs, option) if len(ret) == 0 { panic("no return value specified for IssueCert") @@ -213,17 +211,17 @@ func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl strin var r0 certs.Certificate var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, crypto.PrivateKey) (certs.Certificate, error)); ok { - return rf(ctx, entityID, ttl, ipAddrs, option, privKey) + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions) (certs.Certificate, error)); ok { + return rf(ctx, entityID, ttl, ipAddrs, option) } - if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions, crypto.PrivateKey) certs.Certificate); ok { - r0 = rf(ctx, entityID, ttl, ipAddrs, option, privKey) + if rf, ok := ret.Get(0).(func(context.Context, string, string, []string, certs.SubjectOptions) certs.Certificate); ok { + r0 = rf(ctx, entityID, ttl, ipAddrs, option) } else { r0 = ret.Get(0).(certs.Certificate) } - if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions, crypto.PrivateKey) error); ok { - r1 = rf(ctx, entityID, ttl, ipAddrs, option, privKey) + if rf, ok := ret.Get(1).(func(context.Context, string, string, []string, certs.SubjectOptions) error); ok { + r1 = rf(ctx, entityID, ttl, ipAddrs, option) } else { r1 = ret.Error(1) } @@ -242,14 +240,13 @@ type MockService_IssueCert_Call struct { // - ttl string // - ipAddrs []string // - option certs.SubjectOptions -// - privKey crypto.PrivateKey -func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}, privKey interface{}) *MockService_IssueCert_Call { - return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", ctx, entityID, ttl, ipAddrs, option, privKey)} +func (_e *MockService_Expecter) IssueCert(ctx interface{}, entityID interface{}, ttl interface{}, ipAddrs interface{}, option interface{}) *MockService_IssueCert_Call { + return &MockService_IssueCert_Call{Call: _e.mock.On("IssueCert", ctx, entityID, ttl, ipAddrs, option)} } -func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions, privKey crypto.PrivateKey)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) Run(run func(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions)) *MockService_IssueCert_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string), args[4].(certs.SubjectOptions), args[5].(crypto.PrivateKey)) + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].([]string), args[4].(certs.SubjectOptions)) }) return _c } @@ -259,7 +256,7 @@ func (_c *MockService_IssueCert_Call) Return(_a0 certs.Certificate, _a1 error) * return _c } -func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions, crypto.PrivateKey) (certs.Certificate, error)) *MockService_IssueCert_Call { +func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, string, string, []string, certs.SubjectOptions) (certs.Certificate, error)) *MockService_IssueCert_Call { _c.Call.Return(run) return _c } diff --git a/service.go b/service.go index 4950a14..1070d6c 100644 --- a/service.go +++ b/service.go @@ -55,6 +55,7 @@ var ( ErrCertInvalidType = errors.New("invalid cert type") ErrInvalidLength = errors.New("invalid length of serial numbers") ErrPrivKeyType = errors.New("unsupported private key type") + ErrPubKeyType = errors.New("unsupported public key type") ErrFailedParse = errors.New("failed to parse key PEM") ) @@ -95,49 +96,45 @@ func NewService(ctx context.Context, repo Repository, config *Config) (Service, // using the provided template and the generated private key. // The certificate is then stored in the repository using the CreateCert method. // If the root CA is not found, it returns an error. -// issueCert generates and issues a certificate for a given backendID. -// It uses the RSA algorithm to generate a private key, and then creates a certificate -// using the provided template and the generated private key. -// The certificate is then stored in the repository using the CreateCert method. -// If the root CA is not found, it returns an error. -func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key crypto.PrivateKey) (Certificate, error) { - var privKey crypto.PrivateKey - var pubKey crypto.PublicKey - var err error +func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions) (Certificate, error) { + pKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) + if err != nil { + return Certificate{}, err + } - if key == nil { - pKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) - if err != nil { - return Certificate{}, err - } - privKey = pKey - pubKey = pKey.Public() - } else { - switch k := key.(type) { - case *rsa.PrivateKey: - privKey = k - pubKey = k.Public() - case *ecdsa.PrivateKey: - privKey = k - pubKey = k.Public() - case ed25519.PrivateKey: - privKey = k - pubKey = k.Public() - case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: - pubKey = k - privKey = nil - default: - return Certificate{}, errors.Wrap(ErrCreateEntity, errors.New("unsupported key type")) - } + if s.intermediateCA.Certificate == nil || s.intermediateCA.PrivateKey == nil { + return Certificate{}, ErrIntermediateCANotFound } + cert, err := s.issue(ctx, entityID, ttl, ipAddrs, options, pKey.Public(), pKey) + if err != nil { + return Certificate{}, err + } + + return cert, nil +} + +func (s *service) issue(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, pubKey crypto.PublicKey, privKey crypto.PrivateKey) (Certificate, error) { serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { return Certificate{}, err } - if s.intermediateCA.Certificate == nil || s.intermediateCA.PrivateKey == nil { - return Certificate{}, ErrIntermediateCANotFound + subject := s.getSubject(options) + if privKey != nil { + switch privKey.(type) { + case *rsa.PrivateKey, *ecdsa.PrivateKey, *ed25519.PrivateKey: + break + default: + return Certificate{}, errors.Wrap(ErrCreateEntity, ErrPrivKeyType) + } + } + + switch pubKey.(type) { + case *rsa.PublicKey, *ecdsa.PublicKey, *ed25519.PublicKey: + break + default: + return Certificate{}, errors.Wrap(ErrCreateEntity, ErrPubKeyType) } // Parse the TTL if provided, otherwise use the default certValidityPeriod. @@ -149,8 +146,6 @@ func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs [ } } - subject := s.getSubject(options) - template := x509.Certificate{ SerialNumber: serialNumber, Subject: subject, @@ -476,7 +471,7 @@ func (s *service) IssueFromCSR(ctx context.Context, entityID, ttl string, csr CS return Certificate{}, errors.Wrap(ErrMalformedEntity, err) } - cert, err := s.IssueCert(ctx, entityID, ttl, nil, SubjectOptions{ + cert, err := s.issue(ctx, entityID, ttl, nil, SubjectOptions{ CommonName: parsedCSR.Subject.CommonName, Organization: parsedCSR.Subject.Organization, OrganizationalUnit: parsedCSR.Subject.OrganizationalUnit, @@ -485,7 +480,7 @@ func (s *service) IssueFromCSR(ctx context.Context, entityID, ttl string, csr CS Locality: parsedCSR.Subject.Locality, StreetAddress: parsedCSR.Subject.StreetAddress, PostalCode: parsedCSR.Subject.PostalCode, - }, parsedCSR.PublicKey) + }, parsedCSR.PublicKey, nil) if err != nil { return Certificate{}, errors.Wrap(ErrCreateEntity, err) } diff --git a/tracing/certs.go b/tracing/certs.go index 5ba671c..d766c67 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -5,7 +5,6 @@ package tracing import ( "context" - "crypto" "crypto/x509" "github.com/absmach/certs" @@ -54,10 +53,10 @@ func (tm *tracingMiddleware) RetrieveCAToken(ctx context.Context) (string, error return tm.svc.RetrieveCAToken(ctx) } -func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions, privKey crypto.PrivateKey) (certs.Certificate, error) { +func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { ctx, span := tm.tracer.Start(ctx, "issue_cert") defer span.End() - return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options, privKey) + return tm.svc.IssueCert(ctx, entityID, ttl, ipAddrs, options) } func (tm *tracingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadata) (certs.CertificatePage, error) { From 86ceb56a36d13744dbbe74ea50e0b8c6a0f44b11 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Mon, 9 Dec 2024 18:27:26 +0300 Subject: [PATCH 22/22] rename getsubject to subjectFromOpts Signed-off-by: nyagamunene --- service.go | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/service.go b/service.go index 1070d6c..59c35de 100644 --- a/service.go +++ b/service.go @@ -120,7 +120,7 @@ func (s *service) issue(ctx context.Context, entityID, ttl string, ipAddrs []str return Certificate{}, err } - subject := s.getSubject(options) + subject := subjectFromOpts(options) if privKey != nil { switch privKey.(type) { case *rsa.PrivateKey, *ecdsa.PrivateKey, *ed25519.PrivateKey: @@ -647,31 +647,31 @@ func (s *service) createIntermediateCA(ctx context.Context, rootCA *CA, config C return intermediateCA, nil } -func (s *service) getSubject(options SubjectOptions) pkix.Name { +func subjectFromOpts(opts SubjectOptions) pkix.Name { subject := pkix.Name{ - CommonName: options.CommonName, + CommonName: opts.CommonName, } - if len(options.Organization) > 0 { - subject.Organization = options.Organization + if len(opts.Organization) > 0 { + subject.Organization = opts.Organization } - if len(options.OrganizationalUnit) > 0 { - subject.OrganizationalUnit = options.OrganizationalUnit + if len(opts.OrganizationalUnit) > 0 { + subject.OrganizationalUnit = opts.OrganizationalUnit } - if len(options.Country) > 0 { - subject.Country = options.Country + if len(opts.Country) > 0 { + subject.Country = opts.Country } - if len(options.Province) > 0 { - subject.Province = options.Province + if len(opts.Province) > 0 { + subject.Province = opts.Province } - if len(options.Locality) > 0 { - subject.Locality = options.Locality + if len(opts.Locality) > 0 { + subject.Locality = opts.Locality } - if len(options.StreetAddress) > 0 { - subject.StreetAddress = options.StreetAddress + if len(opts.StreetAddress) > 0 { + subject.StreetAddress = opts.StreetAddress } - if len(options.PostalCode) > 0 { - subject.PostalCode = options.PostalCode + if len(opts.PostalCode) > 0 { + subject.PostalCode = opts.PostalCode } return subject