From bdaaa5fa60b6fae82261476cec5536c1d21bff2c Mon Sep 17 00:00:00 2001 From: Steve Munene Date: Mon, 18 Nov 2024 13:13:43 +0300 Subject: [PATCH] NOISSUE - Add delete certs feature (#46) * Add delete feature Signed-off-by: nyagamunene * Remove plural words in methods Signed-off-by: nyagamunene --------- Signed-off-by: nyagamunene --- api/http/endpoint.go | 19 +++++++++++++-- api/http/errors.go | 3 +++ api/http/requests.go | 11 +++++++++ api/http/responses.go | 24 ++++++++++++++++-- api/http/transport.go | 13 ++++++++++ api/logging.go | 12 +++++++++ api/metrics.go | 8 ++++++ certs.go | 6 +++++ cli/certs.go | 17 +++++++++++++ cli/certs_test.go | 57 +++++++++++++++++++++++++++++++++++++++++++ mocks/repository.go | 47 +++++++++++++++++++++++++++++++++++ mocks/service.go | 47 +++++++++++++++++++++++++++++++++++ postgres/certs.go | 15 ++++++++++++ sdk/certs_test.go | 56 +++++++++++++++++++++++++++++++++++++++++- sdk/mocks/sdk.go | 48 ++++++++++++++++++++++++++++++++++++ sdk/sdk.go | 15 +++++++++++- service.go | 4 +++ tracing/certs.go | 6 +++++ 18 files changed, 402 insertions(+), 6 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index 9c6a7fa..2f43f13 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -36,17 +36,32 @@ func revokeCertEndpoint(svc certs.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { req := request.(viewReq) if err := req.validate(); err != nil { - return revokeCertRes{}, err + return revokeCertRes{revoked: false}, err } if err = svc.RevokeCert(ctx, req.id); err != nil { - return revokeCertRes{}, err + return revokeCertRes{revoked: false}, err } return revokeCertRes{revoked: true}, nil } } +func deleteCertEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(deleteReq) + if err := req.validate(); err != nil { + return deleteCertRes{deleted: false}, err + } + + if err = svc.RemoveCert(ctx, req.entityID); err != nil { + return deleteCertRes{deleted: false}, err + } + + return deleteCertRes{deleted: true}, nil + } +} + func requestCertDownloadTokenEndpoint(svc certs.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { req := request.(viewReq) diff --git a/api/http/errors.go b/api/http/errors.go index e4123dc..bdcff3d 100644 --- a/api/http/errors.go +++ b/api/http/errors.go @@ -32,4 +32,7 @@ var ( // ErrMissingCN indicates missing common name. ErrMissingCN = errors.New("missing common name") + + // ErrEmptyEntityID indicates that the entity id is empty. + ErrEmptyEntityID = errors.New("missing entity id") ) diff --git a/api/http/requests.go b/api/http/requests.go index 6e77f47..4031b30 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -32,6 +32,17 @@ func (req viewReq) validate() error { return nil } +type deleteReq struct { + entityID string +} + +func (req deleteReq) validate() error { + if req.entityID == "" { + return errors.Wrap(certs.ErrMalformedEntity, ErrEmptyEntityID) + } + return nil +} + type crlReq struct { certtype certs.CertType } diff --git a/api/http/responses.go b/api/http/responses.go index 42b706b..f53ce6c 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -45,10 +45,10 @@ type revokeCertRes struct { func (res revokeCertRes) Code() int { if res.revoked { - return http.StatusOK + return http.StatusNoContent } - return http.StatusBadRequest + return http.StatusUnprocessableEntity } func (res revokeCertRes) Headers() map[string]string { @@ -59,6 +59,26 @@ func (res revokeCertRes) Empty() bool { return true } +type deleteCertRes struct { + deleted bool +} + +func (res deleteCertRes) Code() int { + if res.deleted { + return http.StatusNoContent + } + + return http.StatusUnprocessableEntity +} + +func (res deleteCertRes) Headers() map[string]string { + return map[string]string{} +} + +func (res deleteCertRes) Empty() bool { + return true +} + type requestCertDownloadTokenRes struct { Token string `json:"token"` } diff --git a/api/http/transport.go b/api/http/transport.go index ac03fd5..44aef4e 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -64,6 +64,12 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http EncodeResponse, opts..., ), "revoke_cert").ServeHTTP) + r.Delete("/{entityID}/delete", otelhttp.NewHandler(kithttp.NewServer( + deleteCertEndpoint(svc), + decodeDelete, + EncodeResponse, + opts..., + ), "delete_cert").ServeHTTP) r.Get("/{id}/download/token", otelhttp.NewHandler(kithttp.NewServer( requestCertDownloadTokenEndpoint(svc), decodeView, @@ -133,6 +139,13 @@ func decodeView(_ context.Context, r *http.Request) (interface{}, error) { return req, nil } +func decodeDelete(_ context.Context, r *http.Request) (interface{}, error) { + req := deleteReq{ + entityID: chi.URLParam(r, "entityID"), + } + return req, nil +} + func decodeCRL(_ context.Context, r *http.Request) (interface{}, error) { certType, err := readNumQuery(r, "", defType) if err != nil { diff --git a/api/logging.go b/api/logging.go index 9a42d21..731784d 100644 --- a/api/logging.go +++ b/api/logging.go @@ -109,6 +109,18 @@ func (lm *loggingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadat return lm.svc.ListCerts(ctx, pm) } +func (lm *loggingMiddleware) RemoveCert(ctx context.Context, entityId string) (err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method remove_cert took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.RemoveCert(ctx, entityId) +} + func (lm *loggingMiddleware) ViewCert(ctx context.Context, serialNumber string) (cert certs.Certificate, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method view_cert for serial number %s took %s to complete", serialNumber, time.Since(begin)) diff --git a/api/metrics.go b/api/metrics.go index 379e34c..a5506c3 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -87,6 +87,14 @@ func (mm *metricsMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadat return mm.svc.ListCerts(ctx, pm) } +func (mm *metricsMiddleware) RemoveCert(ctx context.Context, entityId string) error { + defer func(begin time.Time) { + mm.counter.With("method", "remove_certificate").Add(1) + mm.latency.With("method", "remove_certificate").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.RemoveCert(ctx, entityId) +} + func (mm *metricsMiddleware) ViewCert(ctx context.Context, serialNumber string) (certs.Certificate, error) { defer func(begin time.Time) { mm.counter.With("method", "view_certificate").Add(1) diff --git a/certs.go b/certs.go index 2d3c9a2..7a88bee 100644 --- a/certs.go +++ b/certs.go @@ -70,6 +70,9 @@ type Service interface { // GetChainCA retrieves the chain of CA i.e. root and intermediate cert concat together. GetChainCA(ctx context.Context, token string) (Certificate, error) + + // RemoveCert deletes a cert for a provided entityID. + RemoveCert(ctx context.Context, entityId string) error } type Repository interface { @@ -90,4 +93,7 @@ type Repository interface { // ListRevokedCerts retrieves revoked lists from database. ListRevokedCerts(ctx context.Context) ([]Certificate, error) + + // RemoveCert deletes cert from database. + RemoveCert(ctx context.Context, entityId string) error } diff --git a/cli/certs.go b/cli/certs.go index 66203fb..7f9b1b9 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -70,6 +70,23 @@ var cmdCerts = []cobra.Command{ logOKCmd(*cmd) }, }, + { + Use: "delete ", + Short: "Delete certificate", + Long: `Deletes certificates for a given entity id.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 1 { + logUsageCmd(*cmd, cmd.Use) + return + } + err := sdk.DeleteCert(args[0]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logOKCmd(*cmd) + }, + }, { Use: "renew ", Short: "Renew certificate", diff --git a/cli/certs_test.go b/cli/certs_test.go index d2b1c86..9cd4444 100644 --- a/cli/certs_test.go +++ b/cli/certs_test.go @@ -22,6 +22,7 @@ import ( const ( revokeCmd = "revoke" + deleteCmd = "delete" issueCmd = "issue" renewCmd = "renew" listCmd = "get" @@ -175,6 +176,62 @@ func TestRevokeCertCmd(t *testing.T) { } } +func TestDeleteCertCmd(t *testing.T) { + sdkMock := new(sdkmocks.MockSDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logType outputLog + }{ + { + desc: "delete certs successfully", + args: []string{ + id, + }, + logType: okLog, + }, + { + desc: "delete certs with invalid args", + args: []string{ + id, + extraArg, + }, + logType: usageLog, + }, + { + desc: "delete certs failed", + args: []string{ + id, + }, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity)), + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("DeleteCerts", mock.Anything).Return(tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{deleteCmd}, tc.args...)...) + switch tc.logType { + case okLog: + assert.True(t, strings.Contains(out, "ok"), fmt.Sprintf("%s unexpected response: expected success message, got: %v", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + } + sdkCall.Unset() + }) + } +} + func TestRenewCertCmd(t *testing.T) { sdkMock := new(sdkmocks.MockSDK) cli.SetSDK(sdkMock) diff --git a/mocks/repository.go b/mocks/repository.go index 2ef8365..40a6640 100644 --- a/mocks/repository.go +++ b/mocks/repository.go @@ -261,6 +261,53 @@ func (_c *MockRepository_ListRevokedCerts_Call) RunAndReturn(run func(context.Co return _c } +// RemoveCert provides a mock function with given fields: ctx, entityId +func (_m *MockRepository) RemoveCert(ctx context.Context, entityId string) error { + ret := _m.Called(ctx, entityId) + + if len(ret) == 0 { + panic("no return value specified for RemoveCert") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, entityId) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockRepository_RemoveCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCert' +type MockRepository_RemoveCert_Call struct { + *mock.Call +} + +// RemoveCert is a helper method to define mock.On call +// - ctx context.Context +// - entityId string +func (_e *MockRepository_Expecter) RemoveCert(ctx interface{}, entityId interface{}) *MockRepository_RemoveCert_Call { + return &MockRepository_RemoveCert_Call{Call: _e.mock.On("RemoveCert", ctx, entityId)} +} + +func (_c *MockRepository_RemoveCert_Call) Run(run func(ctx context.Context, entityId string)) *MockRepository_RemoveCert_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockRepository_RemoveCert_Call) Return(_a0 error) *MockRepository_RemoveCert_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockRepository_RemoveCert_Call) RunAndReturn(run func(context.Context, string) error) *MockRepository_RemoveCert_Call { + _c.Call.Return(run) + return _c +} + // RetrieveCert provides a mock function with given fields: ctx, serialNumber func (_m *MockRepository) RetrieveCert(ctx context.Context, serialNumber string) (certs.Certificate, error) { ret := _m.Called(ctx, serialNumber) diff --git a/mocks/service.go b/mocks/service.go index eba5610..aa91db8 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -393,6 +393,53 @@ func (_c *MockService_OCSP_Call) RunAndReturn(run func(context.Context, string) return _c } +// RemoveCert provides a mock function with given fields: ctx, entityId +func (_m *MockService) RemoveCert(ctx context.Context, entityId string) error { + ret := _m.Called(ctx, entityId) + + if len(ret) == 0 { + panic("no return value specified for RemoveCert") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, entityId) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// MockService_RemoveCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RemoveCert' +type MockService_RemoveCert_Call struct { + *mock.Call +} + +// RemoveCert is a helper method to define mock.On call +// - ctx context.Context +// - entityId string +func (_e *MockService_Expecter) RemoveCert(ctx interface{}, entityId interface{}) *MockService_RemoveCert_Call { + return &MockService_RemoveCert_Call{Call: _e.mock.On("RemoveCert", ctx, entityId)} +} + +func (_c *MockService_RemoveCert_Call) Run(run func(ctx context.Context, entityId string)) *MockService_RemoveCert_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockService_RemoveCert_Call) Return(_a0 error) *MockService_RemoveCert_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockService_RemoveCert_Call) RunAndReturn(run func(context.Context, string) error) *MockService_RemoveCert_Call { + _c.Call.Return(run) + return _c +} + // RenewCert provides a mock function with given fields: ctx, serialNumber func (_m *MockService) RenewCert(ctx context.Context, serialNumber string) error { ret := _m.Called(ctx, serialNumber) diff --git a/postgres/certs.go b/postgres/certs.go index cab1b81..9e2161f 100644 --- a/postgres/certs.go +++ b/postgres/certs.go @@ -199,6 +199,21 @@ func (repo certsRepo) ListRevokedCerts(ctx context.Context) ([]certs.Certificate return revokedCerts, nil } +func (repo certsRepo) RemoveCert(ctx context.Context, backendId string) error { + q := `DELETE FROM certs WHERE entity_id = $1` + + result, err := repo.db.ExecContext(ctx, q, backendId) + if err != nil { + return errors.Wrap(certs.ErrViewEntity, err) + } + + if rows, _ := result.RowsAffected(); rows == 0 { + return certs.ErrNotFound + } + + return nil +} + func (repo certsRepo) total(ctx context.Context, query string, params interface{}) (uint64, error) { rows, err := repo.db.NamedQueryContext(ctx, query, params) if err != nil { diff --git a/sdk/certs_test.go b/sdk/certs_test.go index da97a49..b905a1a 100644 --- a/sdk/certs_test.go +++ b/sdk/certs_test.go @@ -23,7 +23,7 @@ const ( instanceID = "5de9b29a-feb9-11ed-be56-0242ac120002" contentType = "application/senml+json" serialNum = "8e7a30c-bc9f-22de-ae67-1342bc139507" - id = "c333e6f1-59bb-4c39-9e13-3a2766af8ba5" + id = "c333e6f-59bb-4c39-9e13-3a2766af8ba5" ttl = "10h" commonName = "test" token = "token" @@ -203,6 +203,60 @@ func TestRevokeCert(t *testing.T) { } } +func TestDeleteCert(t *testing.T) { + ts, svc := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: contentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + cases := []struct { + desc string + entityID string + svcresp string + svcerr error + err errors.SDKError + }{ + { + desc: "DeleteCert success", + entityID: id, + svcerr: nil, + err: nil, + }, + { + desc: "DeleteCert failure", + entityID: id, + svcerr: certs.ErrUpdateEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity), + }, + { + desc: "DeleteCert with empty entity id", + entityID: "", + svcerr: certs.ErrMalformedEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrMalformedEntity, http.StatusBadRequest), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := svc.On("RemoveCert", mock.Anything, tc.entityID).Return(tc.svcerr) + + err := ctsdk.DeleteCert(tc.entityID) + assert.Equal(t, tc.err, err) + if tc.desc != "DeleteCert with empty entity id" { + ok := svcCall.Parent.AssertCalled(t, "RemoveCert", mock.Anything, tc.entityID) + assert.True(t, ok) + } + svcCall.Unset() + }) + } +} + func TestRenewCert(t *testing.T) { ts, svc := setupCerts() defer ts.Close() diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index ba93512..fd994e6 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -27,6 +27,54 @@ func (_m *MockSDK) EXPECT() *MockSDK_Expecter { return &MockSDK_Expecter{mock: &_m.Mock} } +// DeleteCert provides a mock function with given fields: entityID +func (_m *MockSDK) DeleteCert(entityID string) errors.SDKError { + ret := _m.Called(entityID) + + if len(ret) == 0 { + panic("no return value specified for DeleteCert") + } + + var r0 errors.SDKError + if rf, ok := ret.Get(0).(func(string) errors.SDKError); ok { + r0 = rf(entityID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(errors.SDKError) + } + } + + return r0 +} + +// MockSDK_DeleteCert_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteCert' +type MockSDK_DeleteCert_Call struct { + *mock.Call +} + +// DeleteCert is a helper method to define mock.On call +// - entityID string +func (_e *MockSDK_Expecter) DeleteCert(entityID interface{}) *MockSDK_DeleteCert_Call { + return &MockSDK_DeleteCert_Call{Call: _e.mock.On("DeleteCert", entityID)} +} + +func (_c *MockSDK_DeleteCert_Call) Run(run func(entityID string)) *MockSDK_DeleteCert_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockSDK_DeleteCert_Call) Return(_a0 errors.SDKError) *MockSDK_DeleteCert_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *MockSDK_DeleteCert_Call) RunAndReturn(run func(string) errors.SDKError) *MockSDK_DeleteCert_Call { + _c.Call.Return(run) + return _c +} + // DownloadCA provides a mock function with given fields: token func (_m *MockSDK) DownloadCA(token string) (sdk.CertificateBundle, errors.SDKError) { ret := _m.Called(token) diff --git a/sdk/sdk.go b/sdk/sdk.go index 290cfac..9614ab7 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -141,6 +141,13 @@ type SDK interface { // fmt.Println(page) ListCerts(pm PageMetadata) (CertificatePage, errors.SDKError) + // DeleteCert deletes certificates for a given entityID. + // + // example: + // err := sdk.DeleteCert("entityID") + // fmt.Println(err) + DeleteCert(entityID string) errors.SDKError + // ViewCert retrieves a certificate record from the database. // // example: @@ -265,7 +272,7 @@ func (sdk mgSDK) ViewCert(serialNumber string) (Certificate, errors.SDKError) { func (sdk mgSDK) RevokeCert(serialNumber string) errors.SDKError { url := fmt.Sprintf("%s/%s/%s/revoke", sdk.certsURL, certsEndpoint, serialNumber) - _, _, sdkerr := sdk.processRequest(http.MethodPatch, url, nil, nil, http.StatusOK) + _, _, sdkerr := sdk.processRequest(http.MethodPatch, url, nil, nil, http.StatusNoContent) return sdkerr } @@ -291,6 +298,12 @@ func (sdk mgSDK) ListCerts(pm PageMetadata) (CertificatePage, errors.SDKError) { return cp, nil } +func (sdk mgSDK) DeleteCert(entityID string) errors.SDKError { + url := fmt.Sprintf("%s/%s/%s/delete", sdk.certsURL, certsEndpoint, entityID) + _, _, sdkerr := sdk.processRequest(http.MethodDelete, url, nil, nil, http.StatusNoContent) + return sdkerr +} + func (sdk mgSDK) RetrieveCertDownloadToken(serialNumber string) (Token, errors.SDKError) { url := fmt.Sprintf("%s/%s/%s/download/token", sdk.certsURL, certsEndpoint, serialNumber) _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusOK) diff --git a/service.go b/service.go index 13a882b..c1dc667 100644 --- a/service.go +++ b/service.go @@ -271,6 +271,10 @@ func (s *service) ListCerts(ctx context.Context, pm PageMetadata) (CertificatePa return certPg, nil } +func (s *service) RemoveCert(ctx context.Context, entityId string) error { + return s.repo.RemoveCert(ctx, entityId) +} + func (s *service) ViewCert(ctx context.Context, serialNumber string) (Certificate, error) { cert, err := s.repo.RetrieveCert(ctx, serialNumber) if err != nil { diff --git a/tracing/certs.go b/tracing/certs.go index 7c3ea97..efac814 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -65,6 +65,12 @@ func (tm *tracingMiddleware) ListCerts(ctx context.Context, pm certs.PageMetadat return tm.svc.ListCerts(ctx, pm) } +func (tm *tracingMiddleware) RemoveCert(ctx context.Context, entityId string) (err error) { + ctx, span := tm.tracer.Start(ctx, "remove_cert") + defer span.End() + return tm.svc.RemoveCert(ctx, entityId) +} + func (s *tracingMiddleware) ViewCert(ctx context.Context, serialNumber string) (certs.Certificate, error) { ctx, span := s.tracer.Start(ctx, "view_cert") defer span.End()