From 4e808a1f91610c1dcde8dce0c7084d72f012d2c0 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Thu, 28 Nov 2024 02:45:34 +0300 Subject: [PATCH] fix sdk url paths Signed-off-by: nyagamunene --- api/http/endpoint.go | 8 ++--- api/http/errors.go | 3 ++ api/http/requests.go | 13 ++++---- api/http/responses.go | 4 +-- api/http/transport.go | 44 +++++++++++++++++++++---- api/logging.go | 4 +-- api/metrics.go | 4 +-- certs.go | 75 ++++++++++++++++++++++++++++++++++++------- cli/certs.go | 12 +++++-- go.mod | 2 +- go.sum | 4 +-- mocks/service.go | 31 +++++++++--------- postgres/csr/csr.go | 73 ++++++++++++++++++++++++++--------------- postgres/csr/init.go | 2 +- sdk/certs_test.go | 12 +++---- sdk/mocks/sdk.go | 30 ++++++++--------- sdk/sdk.go | 21 ++++++------ service.go | 23 ++++++------- tracing/certs.go | 4 +-- 19 files changed, 244 insertions(+), 125 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index 29919a9..4907d4a 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -333,16 +333,16 @@ func signCSREndpoint(svc certs.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { req := request.(SignCSRReq) if err := req.validate(); err != nil { - return signCSRRes{processed: false}, err + return signCSRRes{signed: false}, err } err = svc.SignCSR(ctx, req.csrID, req.approve) if err != nil { - return signCSRRes{processed: false}, err + return signCSRRes{signed: false}, err } return signCSRRes{ - processed: true, + signed: true, }, nil } } @@ -354,7 +354,7 @@ func listCSRsEndpoint(svc certs.Service) endpoint.Endpoint { return listCSRsRes{}, err } - cp, err := svc.ListCSRs(ctx, req.entityID, req.status) + cp, err := svc.ListCSRs(ctx, req.pm) if err != nil { return listCSRsRes{}, err } diff --git a/api/http/errors.go b/api/http/errors.go index e4123dc..dcbfb14 100644 --- a/api/http/errors.go +++ b/api/http/errors.go @@ -32,4 +32,7 @@ var ( // ErrMissingCN indicates missing common name. ErrMissingCN = errors.New("missing common name") + + // ErrMissingStatus indicates missing status. + ErrMissingStatus = errors.New("missing status") ) diff --git a/api/http/requests.go b/api/http/requests.go index 2af4a96..2fbb7de 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -91,8 +91,9 @@ func (req ocspReq) validate() error { } type createCSRReq struct { - Metadata certs.CSRMetadata `json:"metadata"` - privKey *rsa.PrivateKey + Metadata certs.CSRMetadata `json:"metadata"` + PrivateKey []byte `json:"private_Key"` + privKey *rsa.PrivateKey } func (req createCSRReq) validate() error { @@ -111,17 +112,17 @@ func (req SignCSRReq) validate() error { if req.csrID == "" { return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) } + return nil } type listCSRsReq struct { - entityID string - status string + pm certs.PageMetadata } func (req listCSRsReq) validate() error { - if req.entityID == "" { - return errors.Wrap(certs.ErrMalformedEntity, ErrMissingEntityID) + if req.pm.Status.String() == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrMissingStatus) } return nil } diff --git a/api/http/responses.go b/api/http/responses.go index f743a83..d24709f 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -213,7 +213,7 @@ func (res createCSRRes) Code() int { return http.StatusCreated } - return http.StatusOK + return http.StatusNoContent } func (res createCSRRes) Headers() map[string]string { @@ -225,7 +225,7 @@ func (res createCSRRes) Empty() bool { } type signCSRRes struct { - processed bool + signed bool } func (res signCSRRes) Code() int { diff --git a/api/http/transport.go b/api/http/transport.go index 52622a8..8293eee 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -7,8 +7,10 @@ import ( "archive/zip" "bytes" "context" + "crypto/x509" "encoding/asn1" "encoding/json" + "encoding/pem" "fmt" "io" "log/slog" @@ -18,7 +20,7 @@ import ( "github.com/absmach/certs" "github.com/absmach/certs/errors" - "github.com/go-chi/chi" + "github.com/go-chi/chi/v5" kithttp "github.com/go-kit/kit/transport/http" "github.com/prometheus/client_golang/prometheus/promhttp" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" @@ -140,7 +142,7 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http opts..., ), "download_ca").ServeHTTP) r.Route("/csr", func(r chi.Router) { - r.Post("/", otelhttp.NewHandler(kithttp.NewServer( + r.Post("/{entityID}", otelhttp.NewHandler(kithttp.NewServer( createCSREndpoint(svc), decodeCreateCSR, EncodeResponse, @@ -291,10 +293,22 @@ func decodeListCerts(_ context.Context, r *http.Request) (interface{}, error) { func decodeCreateCSR(_ context.Context, r *http.Request) (interface{}, error) { req := createCSRReq{} + req.Metadata.EntityID = chi.URLParam(r, "entityID") if err := json.NewDecoder(r.Body).Decode(&req); err != nil { return nil, err } + if len(req.PrivateKey) > 0 { + block, _ := pem.Decode(req.PrivateKey) + if block != nil { + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, errors.Wrap(ErrInvalidRequest, err) + } + req.privKey = privateKey + } + } + return req, nil } @@ -321,7 +335,17 @@ func decodeRetrieveCSR(_ context.Context, r *http.Request) (interface{}, error) } func decodeListCSR(_ context.Context, r *http.Request) (interface{}, error) { - s, err := readStringQuery(r, status, "all") + o, err := readNumQuery(r, offsetKey, defOffset) + if err != nil { + return nil, err + } + + l, err := readNumQuery(r, limitKey, defLimit) + if err != nil { + return nil, err + } + + s, err := readStringQuery(r, status, "") if err != nil { return nil, err } @@ -330,11 +354,19 @@ func decodeListCSR(_ context.Context, r *http.Request) (interface{}, error) { return nil, err } - req := listCSRsReq{ - entityID: e, - status: s, + stat, err := certs.ParseCSRStatus(strings.ToLower(s)) + if err != nil { + return nil, err } + req := listCSRsReq{ + pm: certs.PageMetadata{ + Offset: o, + Limit: l, + EntityID: e, + Status: stat, + }, + } return req, nil } diff --git a/api/logging.go b/api/logging.go index 905ed28..893f521 100644 --- a/api/logging.go +++ b/api/logging.go @@ -206,7 +206,7 @@ func (lm *loggingMiddleware) SignCSR(ctx context.Context, csrID string, approve return lm.svc.SignCSR(ctx, csrID, approve) } -func (lm *loggingMiddleware) ListCSRs(ctx context.Context, entityID string, status string) (cp certs.CSRPage, err error) { +func (lm *loggingMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (cp certs.CSRPage, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method list_csrs took %s to complete", time.Since(begin)) if err != nil { @@ -215,7 +215,7 @@ func (lm *loggingMiddleware) ListCSRs(ctx context.Context, entityID string, stat } lm.logger.Info(message) }(time.Now()) - return lm.svc.ListCSRs(ctx, entityID, status) + return lm.svc.ListCSRs(ctx, pm) } func (lm *loggingMiddleware) RetrieveCSR(ctx context.Context, csrID string) (csr certs.CSR, err error) { diff --git a/api/metrics.go b/api/metrics.go index 9747e10..8610458 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -153,12 +153,12 @@ func (mm *metricsMiddleware) SignCSR(ctx context.Context, csrID string, approve return mm.svc.SignCSR(ctx, csrID, approve) } -func (mm *metricsMiddleware) ListCSRs(ctx context.Context, entityID string, status string) (certs.CSRPage, error) { +func (mm *metricsMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { defer func(begin time.Time) { mm.counter.With("method", "list_csrs").Add(1) mm.latency.With("method", "list_csrs").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.ListCSRs(ctx, entityID, status) + return mm.svc.ListCSRs(ctx, pm) } func (mm *metricsMiddleware) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) { diff --git a/certs.go b/certs.go index 59ae29a..f17ad50 100644 --- a/certs.go +++ b/certs.go @@ -7,6 +7,7 @@ import ( "context" "crypto/rsa" "crypto/x509" + "encoding/json" "net" "time" @@ -54,6 +55,56 @@ func CertTypeFromString(s string) (CertType, error) { } } +type CSRStatus int + +const ( + Pending CSRStatus = iota + Signed + Rejected + All +) + +const ( + pending = "pending" + signed = "signed" + rejected = "rejected" + all = "all" +) + +func (c CSRStatus) String() string { + switch c { + case Pending: + return pending + case Signed: + return signed + case Rejected: + return rejected + case All: + return all + default: + return Unknown + } +} + +func ParseCSRStatus(s string) (CSRStatus, error) { + switch s { + case pending: + return Pending, nil + case signed: + return Signed, nil + case rejected: + return Rejected, nil + case all: + return All, nil + default: + return -1, errors.New("unknown CSR status") + } +} + +func (c CSRStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(c.String()) +} + type CA struct { Type CertType Certificate *x509.Certificate @@ -78,16 +129,16 @@ type CertificatePage struct { } type PageMetadata struct { - Total uint64 `json:"total,omitempty" db:"total"` - Offset uint64 `json:"offset,omitempty" db:"offset"` - Limit uint64 `json:"limit,omitempty" db:"limit"` - EntityID string `json:"entity_id,omitempty" db:"entity_id"` - Status string `json:"status,omitempty" db:"status"` + Total uint64 `json:"total" db:"total"` + Offset uint64 `json:"offset,omitempty" db:"offset"` + Limit uint64 `json:"limit" db:"limit"` + EntityID string `json:"entity_id,omitempty" db:"entity_id"` + Status CSRStatus `json:"status,omitempty" db:"status"` } type CSRMetadata struct { + EntityID string CommonName string `json:"common_name"` - EntityID string `json:"entity_id"` Organization []string `json:"organization"` OrganizationalUnit []string `json:"organizational_unit"` Country []string `json:"country"` @@ -102,18 +153,18 @@ type CSRMetadata struct { type CSR struct { ID string `json:"id" db:"id"` - CSR []byte `json:"csr" db:"csr"` - PrivateKey []byte `json:"private_key" db:"private_key"` + CSR []byte `json:"csr,omitempty" db:"csr"` + PrivateKey []byte `json:"private_key,omitempty" db:"private_key"` EntityID string `json:"entity_id" db:"entity_id"` - Status string `json:"status" db:"status"` + Status CSRStatus `json:"status" db:"status"` SubmittedAt time.Time `json:"submitted_at" db:"submitted_at"` - ProcessedAt time.Time `json:"processed_at" db:"processed_at"` + ProcessedAt time.Time `json:"processed_at,omitempty" db:"processed_at"` SerialNumber string `json:"serial_number" db:"serial_number"` } type CSRPage struct { PageMetadata - CSRs []CSR + CSRs []CSR `json:"csrs,omitempty"` } type SubjectOptions struct { @@ -190,7 +241,7 @@ type Service interface { SignCSR(ctx context.Context, csrID string, approve bool) error // ListCSRs returns a list of CSRs based on filter criteria - ListCSRs(ctx context.Context, entityID string, status string) (CSRPage, error) + ListCSRs(ctx context.Context, pm PageMetadata) (CSRPage, error) // RetrieveCSR retrieves a specific CSR by ID RetrieveCSR(ctx context.Context, csrID string) (CSR, error) diff --git a/cli/certs.go b/cli/certs.go index 4159236..2cf131b 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -255,10 +255,11 @@ var cmdCerts = []cobra.Command{ logErrorCmd(*cmd, err) return } + var csr ctxsdk.CSR var err error if len(args) == 1 { - csr, err = sdk.CreateCSR(pm, "") + csr, err = sdk.CreateCSR(pm, []byte{}) if err != nil { logErrorCmd(*cmd, err) return @@ -266,7 +267,14 @@ var cmdCerts = []cobra.Command{ logJSONCmd(*cmd, csr) return } - csr, err = sdk.CreateCSR(pm, args[1]) + + data, err := os.ReadFile(args[1]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + + csr, err = sdk.CreateCSR(pm, data) if err != nil { logErrorCmd(*cmd, err) return diff --git a/go.mod b/go.mod index a579d5c..478cc3b 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,7 @@ go 1.23.0 require ( github.com/caarlos0/env/v10 v10.0.0 github.com/fatih/color v1.18.0 - github.com/go-chi/chi v4.1.2+incompatible + github.com/go-chi/chi/v5 v5.1.0 github.com/go-kit/kit v0.13.0 github.com/gofrs/uuid v4.4.0+incompatible github.com/golang-jwt/jwt v3.2.2+incompatible diff --git a/go.sum b/go.sum index 393756d..8edef24 100644 --- a/go.sum +++ b/go.sum @@ -39,8 +39,8 @@ github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/go-chi/chi v4.1.2+incompatible h1:fGFk2Gmi/YKXk0OmGfBh0WgmN3XB8lVnEyNz34tQRec= -github.com/go-chi/chi v4.1.2+incompatible/go.mod h1:eB3wogJHnLi3x/kFX2A+IbTBlXxmMeXJVKy9tTv1XzQ= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-gorp/gorp/v3 v3.1.0 h1:ItKF/Vbuj31dmV4jxA1qblpSwkl9g1typ24xoe70IGs= github.com/go-gorp/gorp/v3 v3.1.0/go.mod h1:dLEjIyyRNiXvNZ8PSmzpt1GsWAUK8kjVhEpjH8TixEw= github.com/go-kit/kit v0.13.0 h1:OoneCcHKHQ03LfBpoQCUfCluwd2Vt3ohz+kvbJneZAU= diff --git a/mocks/service.go b/mocks/service.go index fd28c58..b7f2c1f 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -351,9 +351,9 @@ func (_c *MockService_IssueCert_Call) RunAndReturn(run func(context.Context, str return _c } -// ListCSRs provides a mock function with given fields: ctx, entityID, status -func (_m *MockService) ListCSRs(ctx context.Context, entityID string, status string) (certs.CSRPage, error) { - ret := _m.Called(ctx, entityID, status) +// ListCSRs provides a mock function with given fields: ctx, pm +func (_m *MockService) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { + ret := _m.Called(ctx, pm) if len(ret) == 0 { panic("no return value specified for ListCSRs") @@ -361,17 +361,17 @@ func (_m *MockService) ListCSRs(ctx context.Context, entityID string, status str var r0 certs.CSRPage var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, string) (certs.CSRPage, error)); ok { - return rf(ctx, entityID, status) + if rf, ok := ret.Get(0).(func(context.Context, certs.PageMetadata) (certs.CSRPage, error)); ok { + return rf(ctx, pm) } - if rf, ok := ret.Get(0).(func(context.Context, string, string) certs.CSRPage); ok { - r0 = rf(ctx, entityID, status) + if rf, ok := ret.Get(0).(func(context.Context, certs.PageMetadata) certs.CSRPage); ok { + r0 = rf(ctx, pm) } else { r0 = ret.Get(0).(certs.CSRPage) } - if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { - r1 = rf(ctx, entityID, status) + if rf, ok := ret.Get(1).(func(context.Context, certs.PageMetadata) error); ok { + r1 = rf(ctx, pm) } else { r1 = ret.Error(1) } @@ -386,15 +386,14 @@ type MockService_ListCSRs_Call struct { // ListCSRs is a helper method to define mock.On call // - ctx context.Context -// - entityID string -// - status string -func (_e *MockService_Expecter) ListCSRs(ctx interface{}, entityID interface{}, status interface{}) *MockService_ListCSRs_Call { - return &MockService_ListCSRs_Call{Call: _e.mock.On("ListCSRs", ctx, entityID, status)} +// - pm certs.PageMetadata +func (_e *MockService_Expecter) ListCSRs(ctx interface{}, pm interface{}) *MockService_ListCSRs_Call { + return &MockService_ListCSRs_Call{Call: _e.mock.On("ListCSRs", ctx, pm)} } -func (_c *MockService_ListCSRs_Call) Run(run func(ctx context.Context, entityID string, status string)) *MockService_ListCSRs_Call { +func (_c *MockService_ListCSRs_Call) Run(run func(ctx context.Context, pm certs.PageMetadata)) *MockService_ListCSRs_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(string), args[2].(string)) + run(args[0].(context.Context), args[1].(certs.PageMetadata)) }) return _c } @@ -404,7 +403,7 @@ func (_c *MockService_ListCSRs_Call) Return(_a0 certs.CSRPage, _a1 error) *MockS return _c } -func (_c *MockService_ListCSRs_Call) RunAndReturn(run func(context.Context, string, string) (certs.CSRPage, error)) *MockService_ListCSRs_Call { +func (_c *MockService_ListCSRs_Call) RunAndReturn(run func(context.Context, certs.PageMetadata) (certs.CSRPage, error)) *MockService_ListCSRs_Call { _c.Call.Return(run) return _c } diff --git a/postgres/csr/csr.go b/postgres/csr/csr.go index 58194ae..54d6c24 100644 --- a/postgres/csr/csr.go +++ b/postgres/csr/csr.go @@ -7,6 +7,8 @@ import ( "context" "database/sql" "fmt" + "log" + "strings" "github.com/absmach/certs" "github.com/absmach/certs/errors" @@ -52,9 +54,9 @@ func (repo CSRRepo) CreateCSR(ctx context.Context, csr certs.CSR) error { return nil } -func (repo CSRRepo) UpdateCSR(ctx context.Context, cert certs.CSR) error { - q := `UPDATE csr SET certificate = :certificate, key = :key, revoked = :revoked, expiry_time = :expiry_time WHERE serial_number = :serial_number` - res, err := repo.db.NamedExecContext(ctx, q, cert) +func (repo CSRRepo) UpdateCSR(ctx context.Context, csr certs.CSR) error { + q := `UPDATE csr SET serial_number = :serial_number, status = :status, private_key = :private_key, submitted_at = :submitted_at, processed_at = :processed_at WHERE id = :id` + res, err := repo.db.NamedExecContext(ctx, q, csr) if err != nil { return handleError(certs.ErrUpdateEntity, err) } @@ -68,8 +70,8 @@ func (repo CSRRepo) UpdateCSR(ctx context.Context, cert certs.CSR) error { return nil } -func (repo CSRRepo) RetrieveCSR(ctx context.Context,id string) (certs.CSR, error) { - q := `SELECT serial_number, certificate, key, entity_id, revoked, expiry_time FROM csr WHERE id = $1` +func (repo CSRRepo) RetrieveCSR(ctx context.Context, id string) (certs.CSR, error) { + q := `SELECT id, serial_number, csr, private_key, entity_id, status, submitted_at, processed_at FROM csr WHERE id = $1` var csr certs.CSR if err := repo.db.QueryRowxContext(ctx, q, id).StructScan(&csr); err != nil { if err == sql.ErrNoRows { @@ -81,44 +83,65 @@ func (repo CSRRepo) RetrieveCSR(ctx context.Context,id string) (certs.CSR, error } func (repo CSRRepo) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { - q := `SELECT serial_number, status, submitted_at, processed_at, entity_id FROM csr %s LIMIT :limit OFFSET :offset` - var condition string + var query []string + params := map[string]interface{}{ + "limit": pm.Limit, + "offset": pm.Offset, + } if pm.EntityID != "" { - condition = `WHERE entity_id = :entity_id` - } else { - condition = `` + query = append(query, `c.entity_id = :entity_id`) + params["entity_id"] = pm.EntityID + } + if pm.Status != certs.All { + query = append(query, `c.status = :status`) + params["status"] = pm.Status } - q = fmt.Sprintf(q, condition) - var csrs []certs.CSR - params := map[string]interface{}{ - "limit": pm.Limit, - "offset": pm.Offset, - "entity_id": pm.EntityID, + var str string + if len(query) > 0 { + str = fmt.Sprintf(`WHERE %s`, strings.Join(query, ` AND `)) } - rows, err := repo.db.NamedQueryContext(ctx, q, params) + + q := fmt.Sprintf(` + SELECT + c.id, + c.serial_number, + c.submitted_at, + c.processed_at, + c.entity_id + FROM csr c %s LIMIT :limit OFFSET :offset;`, str) + + log.Printf("Query: %s", q) + log.Printf("Parameters: %+v", pm) + rows, err := repo.db.NamedQueryContext(ctx, q, pm) if err != nil { return certs.CSRPage{}, handleError(certs.ErrViewEntity, err) } defer rows.Close() - + log.Printf("row : %+v", rows) + var csrs []certs.CSR for rows.Next() { - csr := &certs.CSR{} - if err := rows.StructScan(csr); err != nil { + csr := certs.CSR{} + if err := rows.StructScan(&csr); err != nil { + log.Printf("StructScan error: %v", err) return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) } - - csrs = append(csrs, *csr) + log.Printf("Scanned CSR: %+v", csr) + csrs = append(csrs, csr) } - q = fmt.Sprintf(`SELECT COUNT(*) FROM csr %s LIMIT :limit OFFSET :offset`, condition) - pm.Total, err = repo.total(ctx, q, params) + if len(csrs) == 0 { + log.Println("No CSRs found matching the query") + } + + cq := fmt.Sprintf(`SELECT COUNT(*) FROM csr c %s;`, str) + pm.Total, err = repo.total(ctx, cq, pm) if err != nil { return certs.CSRPage{}, errors.Wrap(certs.ErrViewEntity, err) } return certs.CSRPage{ PageMetadata: pm, - CSRs: csrs, + CSRs: csrs, }, nil } diff --git a/postgres/csr/init.go b/postgres/csr/init.go index b31c055..b95067c 100644 --- a/postgres/csr/init.go +++ b/postgres/csr/init.go @@ -20,7 +20,7 @@ func Migration() *migrate.MemoryMigrationSource { csr TEXT, private_key TEXT, entity_id VARCHAR(36), - status TEXT, + status TEXT CHECK (status IN ('pending', 'signed', 'rejected')), submitted_at TIMESTAMP, processed_at TIMESTAMP )`, diff --git a/sdk/certs_test.go b/sdk/certs_test.go index b905a1a..54f2f14 100644 --- a/sdk/certs_test.go +++ b/sdk/certs_test.go @@ -649,12 +649,12 @@ func TestDownloadCACert(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - svcCall := svc.On("GetSigningCA", mock.Anything, tc.token).Return(tc.svcresp, tc.svcerr) + svcCall := svc.On("GetChainCA", mock.Anything, tc.token).Return(tc.svcresp, tc.svcerr) _, err := ctsdk.DownloadCA(tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { - ok := svcCall.Parent.AssertCalled(t, "GetSigningCA", mock.Anything, tc.token) + ok := svcCall.Parent.AssertCalled(t, "GetChainCA", mock.Anything, tc.token) assert.True(t, ok) } svcCall.Unset() @@ -709,12 +709,12 @@ func TestViewCA(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - svcCall := svc.On("GetSigningCA", mock.Anything, tc.token).Return(tc.svcresp, tc.svcerr) + svcCall := svc.On("GetChainCA", mock.Anything, tc.token).Return(tc.svcresp, tc.svcerr) c, err := ctsdk.ViewCA(tc.token) assert.Equal(t, tc.err, err) if tc.err == nil { - ok := svcCall.Parent.AssertCalled(t, "GetSigningCA", mock.Anything, tc.token) + ok := svcCall.Parent.AssertCalled(t, "GetChainCA", mock.Anything, tc.token) assert.True(t, ok) } assert.Equal(t, tc.sdkCert.Certificate, c.Certificate, fmt.Sprintf("expected: %v, got: %v", tc.sdkCert.Certificate, c.Certificate)) @@ -765,13 +765,13 @@ func TestGetCAToken(t *testing.T) { for _, tc := range cases { t.Run(tc.desc, func(t *testing.T) { - svcCall := svc.On("RetrieveCertDownloadToken", mock.Anything).Return(tc.svcresp, tc.svcerr) + svcCall := svc.On("RetrieveCAToken", mock.Anything).Return(tc.svcresp, tc.svcerr) resp, err := ctsdk.GetCAToken() assert.Equal(t, tc.err, err) if tc.err == nil { assert.Equal(t, tc.svcresp, resp.Token) - ok := svcCall.Parent.AssertCalled(t, "RetrieveCertDownloadToken", mock.Anything) + ok := svcCall.Parent.AssertCalled(t, "RetrieveCAToken", mock.Anything) assert.True(t, ok) } svcCall.Unset() diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index 5ccd73e..23be435 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -25,9 +25,9 @@ func (_m *MockSDK) EXPECT() *MockSDK_Expecter { return &MockSDK_Expecter{mock: &_m.Mock} } -// CreateCSR provides a mock function with given fields: pm, privKeyPath -func (_m *MockSDK) CreateCSR(pm sdk.PageMetadata, privKeyPath string) (sdk.CSR, errors.SDKError) { - ret := _m.Called(pm, privKeyPath) +// CreateCSR provides a mock function with given fields: pm, privKey +func (_m *MockSDK) CreateCSR(pm sdk.PageMetadata, privKey []byte) (sdk.CSR, errors.SDKError) { + ret := _m.Called(pm, privKey) if len(ret) == 0 { panic("no return value specified for CreateCSR") @@ -35,17 +35,17 @@ func (_m *MockSDK) CreateCSR(pm sdk.PageMetadata, privKeyPath string) (sdk.CSR, var r0 sdk.CSR var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) (sdk.CSR, errors.SDKError)); ok { - return rf(pm, privKeyPath) + if rf, ok := ret.Get(0).(func(sdk.PageMetadata, []byte) (sdk.CSR, errors.SDKError)); ok { + return rf(pm, privKey) } - if rf, ok := ret.Get(0).(func(sdk.PageMetadata, string) sdk.CSR); ok { - r0 = rf(pm, privKeyPath) + if rf, ok := ret.Get(0).(func(sdk.PageMetadata, []byte) sdk.CSR); ok { + r0 = rf(pm, privKey) } else { r0 = ret.Get(0).(sdk.CSR) } - if rf, ok := ret.Get(1).(func(sdk.PageMetadata, string) errors.SDKError); ok { - r1 = rf(pm, privKeyPath) + if rf, ok := ret.Get(1).(func(sdk.PageMetadata, []byte) errors.SDKError); ok { + r1 = rf(pm, privKey) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -62,14 +62,14 @@ type MockSDK_CreateCSR_Call struct { // CreateCSR is a helper method to define mock.On call // - pm sdk.PageMetadata -// - privKeyPath string -func (_e *MockSDK_Expecter) CreateCSR(pm interface{}, privKeyPath interface{}) *MockSDK_CreateCSR_Call { - return &MockSDK_CreateCSR_Call{Call: _e.mock.On("CreateCSR", pm, privKeyPath)} +// - privKey []byte +func (_e *MockSDK_Expecter) CreateCSR(pm interface{}, privKey interface{}) *MockSDK_CreateCSR_Call { + return &MockSDK_CreateCSR_Call{Call: _e.mock.On("CreateCSR", pm, privKey)} } -func (_c *MockSDK_CreateCSR_Call) Run(run func(pm sdk.PageMetadata, privKeyPath string)) *MockSDK_CreateCSR_Call { +func (_c *MockSDK_CreateCSR_Call) Run(run func(pm sdk.PageMetadata, privKey []byte)) *MockSDK_CreateCSR_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(sdk.PageMetadata), args[1].(string)) + run(args[0].(sdk.PageMetadata), args[1].([]byte)) }) return _c } @@ -79,7 +79,7 @@ func (_c *MockSDK_CreateCSR_Call) Return(_a0 sdk.CSR, _a1 errors.SDKError) *Mock return _c } -func (_c *MockSDK_CreateCSR_Call) RunAndReturn(run func(sdk.PageMetadata, string) (sdk.CSR, errors.SDKError)) *MockSDK_CreateCSR_Call { +func (_c *MockSDK_CreateCSR_Call) RunAndReturn(run func(sdk.PageMetadata, []byte) (sdk.CSR, errors.SDKError)) *MockSDK_CreateCSR_Call { _c.Call.Return(run) return _c } diff --git a/sdk/sdk.go b/sdk/sdk.go index 5af513a..86d0228 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -279,10 +279,10 @@ type SDK interface { // CreateCSR creates a new Certificate Signing Request // // example: - // pm = sdk.CSRMetadata{CommonName: "common_name", "entity_id" } - // reponse, _ := sdk.CreateCSR(pm, "privKeyPath") + // pm = sdk.CSRMetadata{CommonName: "common_name", EntityID: "entity_id" } + // reponse, _ := sdk.CreateCSR(pm, []bytes("privKey")) // fmt.Println(response) - CreateCSR(pm PageMetadata, privKeyPath string) (CSR, errors.SDKError) + CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKError) // SignCSR processes a pending CSR and either signs or rejects it // @@ -570,7 +570,7 @@ func (sdk mgSDK) GetCAToken() (Token, errors.SDKError) { return tk, nil } -func (sdk mgSDK) CreateCSR(pm PageMetadata, privKeyPath string) (CSR, errors.SDKError) { +func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKError) { r := csrReq{ Organization: pm.Organization, OrganizationalUnit: pm.OrganizationalUnit, @@ -582,12 +582,13 @@ func (sdk mgSDK) CreateCSR(pm PageMetadata, privKeyPath string) (CSR, errors.SDK DNSNames: pm.DNSNames, IPAddresses: pm.IPAddresses, EmailAddresses: pm.EmailAddresses, + PrivateKey: privKey, } d, err := json.Marshal(r) if err != nil { return CSR{}, errors.NewSDKError(err) } - url := fmt.Sprintf("%s/%s", sdk.certsURL, csrEndpoint) + url := fmt.Sprintf("%s/%s/%s/%s", sdk.certsURL, certsEndpoint, csrEndpoint, pm.EntityID) _, body, sdkerr := sdk.processRequest(http.MethodPost, url, d, nil, http.StatusOK) if sdkerr != nil { return CSR{}, sdkerr @@ -604,7 +605,7 @@ func (sdk mgSDK) SignCSR(csrID string, sign bool) errors.SDKError { pm := PageMetadata{ Sign: sign, } - url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s", certsEndpoint, csrID), pm) + url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s/%s", certsEndpoint, csrEndpoint, csrID), pm) if err != nil { return errors.NewSDKError(err) } @@ -617,11 +618,10 @@ func (sdk mgSDK) SignCSR(csrID string, sign bool) errors.SDKError { } func (sdk mgSDK) ListCSRs(pm PageMetadata) (CSRPage, errors.SDKError) { - url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/list", csrEndpoint), pm) + url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/%s/list", certsEndpoint, csrEndpoint), pm) if err != nil { return CSRPage{}, errors.NewSDKError(err) } - _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusOK) if sdkerr != nil { return CSRPage{}, sdkerr @@ -635,9 +635,9 @@ func (sdk mgSDK) ListCSRs(pm PageMetadata) (CSRPage, errors.SDKError) { } func (sdk mgSDK) RetrieveCSR(csrID string) (CSR, errors.SDKError) { - url := fmt.Sprintf("%s/%s/%s", sdk.certsURL, csrEndpoint, csrID) + url := fmt.Sprintf("%s/%s/%s/%s", sdk.certsURL, certsEndpoint, csrEndpoint, csrID) - _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusOK) + _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusCreated) if sdkerr != nil { return CSR{}, sdkerr } @@ -771,4 +771,5 @@ type csrReq struct { DNSNames []string `json:"dns_names"` IPAddresses []string `json:"ip_addresses"` EmailAddresses []string `json:"email_addresses"` + PrivateKey []byte `json:"private_key"` } diff --git a/service.go b/service.go index da2da54..3867230 100644 --- a/service.go +++ b/service.go @@ -92,16 +92,16 @@ func NewService(ctx context.Context, repo Repository, csrRepo CSRRepository, con // The certificate is then stored in the repository using the CreateCert method. // If the root CA is not found, it returns an error. func (s *service) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options SubjectOptions, key ...*rsa.PrivateKey) (Certificate, error) { + var privKey rsa.PrivateKey var err error - privKey := rsa.PrivateKey{} if len(key) == 0 { pKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) privKey = *pKey if err != nil { return Certificate{}, err - } else { - privKey = *key[0] } + } else { + privKey = *key[0] } serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { @@ -468,7 +468,7 @@ func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, entityID CSR: csrPEM, PrivateKey: privKeyPEM, EntityID: entityID, - Status: "pending", + Status: Pending, SubmittedAt: time.Now(), } @@ -486,7 +486,7 @@ func (s *service) SignCSR(ctx context.Context, csrID string, approve bool) error } if !approve { - csr.Status = "rejected" + csr.Status = Rejected csr.ProcessedAt = time.Now() return s.csrRepo.UpdateCSR(ctx, csr) } @@ -524,19 +524,20 @@ func (s *service) SignCSR(ctx context.Context, csrID string, approve bool) error return errors.Wrap(ErrCreateEntity, err) } - csr.Status = "approved" + csr.Status = Signed csr.ProcessedAt = time.Now() csr.SerialNumber = cert.SerialNumber return s.csrRepo.UpdateCSR(ctx, csr) } -func (s *service) ListCSRs(ctx context.Context, entityID string, status string) (CSRPage, error) { - pm := PageMetadata{ - EntityID: entityID, - Status: status, +func (s *service) ListCSRs(ctx context.Context, pm PageMetadata) (CSRPage, error) { + cp, err := s.csrRepo.ListCSRs(ctx, pm) + if err != nil { + return CSRPage{}, errors.Wrap(ErrViewEntity, err) } - return s.csrRepo.ListCSRs(ctx, pm) + + return cp, nil } func (s *service) RetrieveCSR(ctx context.Context, csrID string) (CSR, error) { diff --git a/tracing/certs.go b/tracing/certs.go index 71ee502..ec9c699 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -114,10 +114,10 @@ func (tm *tracingMiddleware) SignCSR(ctx context.Context, csrID string, approve return tm.svc.SignCSR(ctx, csrID, approve) } -func (tm *tracingMiddleware) ListCSRs(ctx context.Context, entityID string, status string) (certs.CSRPage, error) { +func (tm *tracingMiddleware) ListCSRs(ctx context.Context, pm certs.PageMetadata) (certs.CSRPage, error) { ctx, span := tm.tracer.Start(ctx, "list_csrs") defer span.End() - return tm.svc.ListCSRs(ctx, entityID, status) + return tm.svc.ListCSRs(ctx, pm) } func (tm *tracingMiddleware) RetrieveCSR(ctx context.Context, csrID string) (certs.CSR, error) {