From 6cadaa3dfaab208b080b8c22aea24d882db9fb32 Mon Sep 17 00:00:00 2001 From: nyagamunene Date: Wed, 16 Oct 2024 18:35:49 +0300 Subject: [PATCH] Get chain of certs Signed-off-by: nyagamunene --- api/http/endpoint.go | 4 +-- api/logging.go | 6 ++-- api/metrics.go | 8 ++--- certs.go | 4 +-- cli/utils.go | 6 ++-- mocks/service.go | 80 ++++++++++++++++++++++---------------------- service.go | 28 +++++++++++++--- tracing/certs.go | 6 ++-- 8 files changed, 80 insertions(+), 62 deletions(-) diff --git a/api/http/endpoint.go b/api/http/endpoint.go index c5b9810..9c6a7fa 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -262,7 +262,7 @@ func downloadCAEndpoint(svc certs.Service) endpoint.Endpoint { return fileDownloadRes{}, err } - cert, err := svc.GetSigningCA(ctx, req.token) + cert, err := svc.GetChainCA(ctx, req.token) if err != nil { return fileDownloadRes{}, err } @@ -283,7 +283,7 @@ func viewCAEndpoint(svc certs.Service) endpoint.Endpoint { return viewCertRes{}, err } - cert, err := svc.GetSigningCA(ctx, req.token) + cert, err := svc.GetChainCA(ctx, req.token) if err != nil { return viewCertRes{}, err } diff --git a/api/logging.go b/api/logging.go index 4a6c88a..9a42d21 100644 --- a/api/logging.go +++ b/api/logging.go @@ -157,14 +157,14 @@ func (lm *loggingMiddleware) GenerateCRL(ctx context.Context, caType certs.CertT return lm.svc.GenerateCRL(ctx, caType) } -func (lm *loggingMiddleware) GetSigningCA(ctx context.Context, token string) (cert certs.Certificate, err error) { +func (lm *loggingMiddleware) GetChainCA(ctx context.Context, token string) (cert certs.Certificate, err error) { defer func(begin time.Time) { - message := fmt.Sprintf("Method get_signing_ca took %s to complete", time.Since(begin)) + message := fmt.Sprintf("Method get_chain_ca took %s to complete", time.Since(begin)) if err != nil { lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) return } lm.logger.Info(message) }(time.Now()) - return lm.svc.GetSigningCA(ctx, token) + return lm.svc.GetChainCA(ctx, token) } diff --git a/api/metrics.go b/api/metrics.go index f1ce5c8..379e34c 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -120,10 +120,10 @@ func (mm *metricsMiddleware) GenerateCRL(ctx context.Context, caType certs.CertT return mm.svc.GenerateCRL(ctx, caType) } -func (mm *metricsMiddleware) GetSigningCA(ctx context.Context, token string) (certs.Certificate, error) { +func (mm *metricsMiddleware) GetChainCA(ctx context.Context, token string) (certs.Certificate, error) { defer func(begin time.Time) { - mm.counter.With("method", "get_signing_ca").Add(1) - mm.latency.With("method", "get_signing_ca").Observe(time.Since(begin).Seconds()) + mm.counter.With("method", "get_chain_ca").Add(1) + mm.latency.With("method", "get_chain_ca").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.GetSigningCA(ctx, token) + return mm.svc.GetChainCA(ctx, token) } diff --git a/certs.go b/certs.go index 92629f8..2d3c9a2 100644 --- a/certs.go +++ b/certs.go @@ -68,8 +68,8 @@ type Service interface { // GenerateCRL creates cert revocation list. GenerateCRL(ctx context.Context, caType CertType) ([]byte, error) - // Retrieves the signing CA. - GetSigningCA(ctx context.Context, token string) (Certificate, error) + // GetChainCA retrieves the chain of CA i.e. root and intermediate cert concat together. + GetChainCA(ctx context.Context, token string) (Certificate, error) } type Repository interface { diff --git a/cli/utils.go b/cli/utils.go index 6ff37ce..d5e413e 100644 --- a/cli/utils.go +++ b/cli/utils.go @@ -6,7 +6,6 @@ package cli import ( "encoding/json" "fmt" - "io/fs" "os" "path/filepath" @@ -16,7 +15,7 @@ import ( "github.com/spf13/cobra" ) -const fileMode = fs.FileMode(600) +const fileMode = 0o644 var ( // Limit query parameter. @@ -106,8 +105,7 @@ func saveToFile(filename string, content []byte) error { } filePath := filepath.Join(cwd, filename) - err = os.WriteFile(filePath, content, fileMode) - if err != nil { + if err := os.WriteFile(filePath, content, fileMode); err != nil { return fmt.Errorf("failed to write file %s: %w", filename, err) } diff --git a/mocks/service.go b/mocks/service.go index cf6d732..eba5610 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -87,27 +87,27 @@ func (_c *MockService_GenerateCRL_Call) RunAndReturn(run func(context.Context, c return _c } -// GetEntityID provides a mock function with given fields: ctx, serialNumber -func (_m *MockService) GetEntityID(ctx context.Context, serialNumber string) (string, error) { - ret := _m.Called(ctx, serialNumber) +// GetChainCA provides a mock function with given fields: ctx, token +func (_m *MockService) GetChainCA(ctx context.Context, token string) (certs.Certificate, error) { + ret := _m.Called(ctx, token) if len(ret) == 0 { - panic("no return value specified for GetEntityID") + panic("no return value specified for GetChainCA") } - var r0 string + var r0 certs.Certificate var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (string, error)); ok { - return rf(ctx, serialNumber) + if rf, ok := ret.Get(0).(func(context.Context, string) (certs.Certificate, error)); ok { + return rf(ctx, token) } - if rf, ok := ret.Get(0).(func(context.Context, string) string); ok { - r0 = rf(ctx, serialNumber) + if rf, ok := ret.Get(0).(func(context.Context, string) certs.Certificate); ok { + r0 = rf(ctx, token) } else { - r0 = ret.Get(0).(string) + r0 = ret.Get(0).(certs.Certificate) } if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, serialNumber) + r1 = rf(ctx, token) } else { r1 = ret.Error(1) } @@ -115,56 +115,56 @@ func (_m *MockService) GetEntityID(ctx context.Context, serialNumber string) (st return r0, r1 } -// MockService_GetEntityID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetEntityID' -type MockService_GetEntityID_Call struct { +// MockService_GetChainCA_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetChainCA' +type MockService_GetChainCA_Call struct { *mock.Call } -// GetEntityID is a helper method to define mock.On call +// GetChainCA is a helper method to define mock.On call // - ctx context.Context -// - serialNumber string -func (_e *MockService_Expecter) GetEntityID(ctx interface{}, serialNumber interface{}) *MockService_GetEntityID_Call { - return &MockService_GetEntityID_Call{Call: _e.mock.On("GetEntityID", ctx, serialNumber)} +// - token string +func (_e *MockService_Expecter) GetChainCA(ctx interface{}, token interface{}) *MockService_GetChainCA_Call { + return &MockService_GetChainCA_Call{Call: _e.mock.On("GetChainCA", ctx, token)} } -func (_c *MockService_GetEntityID_Call) Run(run func(ctx context.Context, serialNumber string)) *MockService_GetEntityID_Call { +func (_c *MockService_GetChainCA_Call) Run(run func(ctx context.Context, token string)) *MockService_GetChainCA_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(string)) }) return _c } -func (_c *MockService_GetEntityID_Call) Return(_a0 string, _a1 error) *MockService_GetEntityID_Call { +func (_c *MockService_GetChainCA_Call) Return(_a0 certs.Certificate, _a1 error) *MockService_GetChainCA_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockService_GetEntityID_Call) RunAndReturn(run func(context.Context, string) (string, error)) *MockService_GetEntityID_Call { +func (_c *MockService_GetChainCA_Call) RunAndReturn(run func(context.Context, string) (certs.Certificate, error)) *MockService_GetChainCA_Call { _c.Call.Return(run) return _c } -// GetSigningCA provides a mock function with given fields: ctx, token -func (_m *MockService) GetSigningCA(ctx context.Context, token string) (certs.Certificate, error) { - ret := _m.Called(ctx, token) +// GetEntityID provides a mock function with given fields: ctx, serialNumber +func (_m *MockService) GetEntityID(ctx context.Context, serialNumber string) (string, error) { + ret := _m.Called(ctx, serialNumber) if len(ret) == 0 { - panic("no return value specified for GetSigningCA") + panic("no return value specified for GetEntityID") } - var r0 certs.Certificate + var r0 string var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (certs.Certificate, error)); ok { - return rf(ctx, token) + if rf, ok := ret.Get(0).(func(context.Context, string) (string, error)); ok { + return rf(ctx, serialNumber) } - if rf, ok := ret.Get(0).(func(context.Context, string) certs.Certificate); ok { - r0 = rf(ctx, token) + if rf, ok := ret.Get(0).(func(context.Context, string) string); ok { + r0 = rf(ctx, serialNumber) } else { - r0 = ret.Get(0).(certs.Certificate) + r0 = ret.Get(0).(string) } if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, token) + r1 = rf(ctx, serialNumber) } else { r1 = ret.Error(1) } @@ -172,31 +172,31 @@ func (_m *MockService) GetSigningCA(ctx context.Context, token string) (certs.Ce return r0, r1 } -// MockService_GetSigningCA_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSigningCA' -type MockService_GetSigningCA_Call struct { +// MockService_GetEntityID_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetEntityID' +type MockService_GetEntityID_Call struct { *mock.Call } -// GetSigningCA is a helper method to define mock.On call +// GetEntityID is a helper method to define mock.On call // - ctx context.Context -// - token string -func (_e *MockService_Expecter) GetSigningCA(ctx interface{}, token interface{}) *MockService_GetSigningCA_Call { - return &MockService_GetSigningCA_Call{Call: _e.mock.On("GetSigningCA", ctx, token)} +// - serialNumber string +func (_e *MockService_Expecter) GetEntityID(ctx interface{}, serialNumber interface{}) *MockService_GetEntityID_Call { + return &MockService_GetEntityID_Call{Call: _e.mock.On("GetEntityID", ctx, serialNumber)} } -func (_c *MockService_GetSigningCA_Call) Run(run func(ctx context.Context, token string)) *MockService_GetSigningCA_Call { +func (_c *MockService_GetEntityID_Call) Run(run func(ctx context.Context, serialNumber string)) *MockService_GetEntityID_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(string)) }) return _c } -func (_c *MockService_GetSigningCA_Call) Return(_a0 certs.Certificate, _a1 error) *MockService_GetSigningCA_Call { +func (_c *MockService_GetEntityID_Call) Return(_a0 string, _a1 error) *MockService_GetEntityID_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockService_GetSigningCA_Call) RunAndReturn(run func(context.Context, string) (certs.Certificate, error)) *MockService_GetSigningCA_Call { +func (_c *MockService_GetEntityID_Call) RunAndReturn(run func(context.Context, string) (string, error)) *MockService_GetEntityID_Call { _c.Call.Return(run) return _c } diff --git a/service.go b/service.go index 45ad4d3..13a882b 100644 --- a/service.go +++ b/service.go @@ -254,7 +254,12 @@ func (s *service) RetrieveCert(ctx context.Context, token, serialNumber string) if err != nil { return Certificate{}, []byte{}, errors.Wrap(ErrViewEntity, err) } - return cert, pem.EncodeToMemory(&pem.Block{Bytes: s.intermediateCA.Certificate.Raw, Type: "CERTIFICATE"}), nil + concat, err := s.getConcatCAs(ctx) + if err != nil { + return Certificate{}, []byte{}, errors.Wrap(ErrViewEntity, err) + } + + return cert, concat.Certificate, nil } func (s *service) ListCerts(ctx context.Context, pm PageMetadata) (CertificatePage, error) { @@ -450,18 +455,33 @@ func (s *service) GenerateCRL(ctx context.Context, caType CertType) ([]byte, err return pemBytes, nil } -func (s *service) GetSigningCA(ctx context.Context, token string) (Certificate, error) { +func (s *service) GetChainCA(ctx context.Context, token string) (Certificate, error) { if _, err := jwt.ParseWithClaims(token, &jwt.StandardClaims{Issuer: Organization, Subject: "certs"}, func(token *jwt.Token) (interface{}, error) { return []byte(s.intermediateCA.SerialNumber), nil }); err != nil { return Certificate{}, errors.Wrap(err, ErrMalformedEntity) } - cert, err := s.repo.RetrieveCert(ctx, s.intermediateCA.SerialNumber) + return s.getConcatCAs(ctx) +} + +func (s *service) getConcatCAs(ctx context.Context) (Certificate, error) { + intermediateCert, err := s.repo.RetrieveCert(ctx, s.intermediateCA.SerialNumber) if err != nil { return Certificate{}, errors.Wrap(ErrViewEntity, err) } - return cert, nil + + rootCert, err := s.repo.RetrieveCert(ctx, s.rootCA.SerialNumber) + if err != nil { + return Certificate{}, errors.Wrap(ErrViewEntity, err) + } + + concat := string(intermediateCert.Certificate) + string(rootCert.Certificate) + return Certificate{ + Certificate: []byte(concat), + Key: intermediateCert.Key, + ExpiryTime: intermediateCert.ExpiryTime, + }, nil } func (s *service) generateRootCA(ctx context.Context, config Config) (*CA, error) { diff --git a/tracing/certs.go b/tracing/certs.go index 787cb52..7c3ea97 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -89,8 +89,8 @@ func (tm *tracingMiddleware) GenerateCRL(ctx context.Context, caType certs.CertT return tm.svc.GenerateCRL(ctx, caType) } -func (tm *tracingMiddleware) GetSigningCA(ctx context.Context, token string) (certs.Certificate, error) { - ctx, span := tm.tracer.Start(ctx, "get_signing_ca") +func (tm *tracingMiddleware) GetChainCA(ctx context.Context, token string) (certs.Certificate, error) { + ctx, span := tm.tracer.Start(ctx, "get_chain_ca") defer span.End() - return tm.svc.GetSigningCA(ctx, token) + return tm.svc.GetChainCA(ctx, token) }