diff --git a/api/http/endpoint.go b/api/http/endpoint.go index bcbf3ab..c5b9810 100644 --- a/api/http/endpoint.go +++ b/api/http/endpoint.go @@ -67,7 +67,7 @@ func downloadCertEndpoint(svc certs.Service) endpoint.Endpoint { return func(ctx context.Context, request interface{}) (response interface{}, err error) { req := request.(downloadReq) if err := req.validate(); err != nil { - return downloadCertRes{}, err + return fileDownloadRes{}, err } cert, ca, err := svc.RetrieveCert(ctx, req.token, req.id) if err != nil { @@ -243,3 +243,54 @@ func generateCRLEndpoint(svc certs.Service) endpoint.Endpoint { }, nil } } + +func getDownloadCATokenEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + token, err := svc.RetrieveCAToken(ctx) + if err != nil { + return requestCertDownloadTokenRes{}, err + } + + return requestCertDownloadTokenRes{Token: token}, nil + } +} + +func downloadCAEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(downloadReq) + if err := req.validate(); err != nil { + return fileDownloadRes{}, err + } + + cert, err := svc.GetSigningCA(ctx, req.token) + if err != nil { + return fileDownloadRes{}, err + } + + return fileDownloadRes{ + Certificate: cert.Certificate, + PrivateKey: cert.Key, + Filename: "ca.zip", + ContentType: "application/zip", + }, nil + } +} + +func viewCAEndpoint(svc certs.Service) endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (response interface{}, err error) { + req := request.(downloadReq) + if err := req.validate(); err != nil { + return viewCertRes{}, err + } + + cert, err := svc.GetSigningCA(ctx, req.token) + if err != nil { + return viewCertRes{}, err + } + + return viewCertRes{ + Certificate: string(cert.Certificate), + Key: string(cert.Key), + }, nil + } +} diff --git a/api/http/requests.go b/api/http/requests.go index f11f497..6e77f47 100644 --- a/api/http/requests.go +++ b/api/http/requests.go @@ -15,9 +15,6 @@ type downloadReq struct { } func (req downloadReq) validate() error { - if req.id == "" { - return errors.Wrap(certs.ErrMalformedEntity, ErrEmptySerialNo) - } if req.token == "" { return errors.Wrap(certs.ErrMalformedEntity, ErrEmptyToken) } diff --git a/api/http/responses.go b/api/http/responses.go index ed333bd..42b706b 100644 --- a/api/http/responses.go +++ b/api/http/responses.go @@ -75,24 +75,6 @@ func (res requestCertDownloadTokenRes) Empty() bool { return false } -type downloadCertRes struct { - Certificate []byte `json:"certificate"` - PrivateKey []byte `json:"private_key"` - CA []byte `json:"ca"` -} - -func (res downloadCertRes) Code() int { - return http.StatusOK -} - -func (res downloadCertRes) Headers() map[string]string { - return map[string]string{} -} - -func (res downloadCertRes) Empty() bool { - return false -} - type issueCertRes struct { SerialNumber string `json:"serial_number"` Certificate string `json:"certificate,omitempty"` @@ -138,12 +120,12 @@ func (res listCertsRes) Empty() bool { } type viewCertRes struct { - SerialNumber string `json:"serial_number"` + SerialNumber string `json:"serial_number,omitempty"` Certificate string `json:"certificate,omitempty"` - Key string `json:"key,omitempty"` - Revoked bool `json:"revoked"` - ExpiryTime time.Time `json:"expiry_time"` - EntityID string `json:"entity_id"` + Key string `json:"key,omitempty,omitempty"` + Revoked bool `json:"revoked,omitempty"` + ExpiryTime time.Time `json:"expiry_time,omitempty"` + EntityID string `json:"entity_id,omitempty"` } func (res viewCertRes) Code() int { diff --git a/api/http/transport.go b/api/http/transport.go index b8bc352..ac03fd5 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -100,6 +100,24 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http EncodeResponse, opts..., ), "generate_crl").ServeHTTP) + r.Get("/get-ca/token", otelhttp.NewHandler(kithttp.NewServer( + getDownloadCATokenEndpoint(svc), + decodeView, + EncodeResponse, + opts..., + ), "get_ca_token").ServeHTTP) + r.Get("/view-ca", otelhttp.NewHandler(kithttp.NewServer( + viewCAEndpoint(svc), + decodeDownloadCA, + EncodeResponse, + opts..., + ), "view_ca").ServeHTTP) + r.Get("/download-ca", otelhttp.NewHandler(kithttp.NewServer( + downloadCAEndpoint(svc), + decodeDownloadCA, + encodeCADownloadResponse, + opts..., + ), "download_ca").ServeHTTP) }) r.Get("/health", certs.Health("certs", instanceID)) @@ -139,6 +157,18 @@ func decodeDownloadCerts(_ context.Context, r *http.Request) (interface{}, error return req, nil } +func decodeDownloadCA(_ context.Context, r *http.Request) (interface{}, error) { + token, err := readStringQuery(r, token, "") + if err != nil { + return nil, err + } + req := downloadReq{ + token: token, + } + + return req, nil +} + func decodeOCSPRequest(_ context.Context, r *http.Request) (interface{}, error) { body, err := io.ReadAll(r.Body) if err != nil { @@ -280,6 +310,41 @@ func encodeFileDownloadResponse(_ context.Context, w http.ResponseWriter, respon return err } +func encodeCADownloadResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { + resp := response.(fileDownloadRes) + var buffer bytes.Buffer + zw := zip.NewWriter(&buffer) + + f, err := zw.Create("ca.crt") + if err != nil { + return err + } + + if _, err = f.Write(resp.Certificate); err != nil { + return err + } + + f, err = zw.Create("ca.key") + if err != nil { + return err + } + + if _, err = f.Write(resp.PrivateKey); err != nil { + return err + } + + if err := zw.Close(); err != nil { + return err + } + + w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s", resp.Filename)) + w.Header().Set("Content-Type", resp.ContentType) + + _, err = w.Write(buffer.Bytes()) + + return err +} + // loggingErrorEncoder is a go-kit error encoder logging decorator. func loggingErrorEncoder(logger *slog.Logger, enc kithttp.ErrorEncoder) kithttp.ErrorEncoder { return func(ctx context.Context, err error, w http.ResponseWriter) { diff --git a/api/logging.go b/api/logging.go index 407f139..4a6c88a 100644 --- a/api/logging.go +++ b/api/logging.go @@ -63,7 +63,7 @@ func (lm *loggingMiddleware) RevokeCert(ctx context.Context, serialNumber string func (lm *loggingMiddleware) RetrieveCertDownloadToken(ctx context.Context, serialNumber string) (tokenString string, err error) { defer func(begin time.Time) { - message := fmt.Sprintf("Method get_cert_download_token for cert %s took %s to complete", serialNumber, time.Since(begin)) + message := fmt.Sprintf("Method get_cert_download_token for cert took %s to complete", time.Since(begin)) if err != nil { lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) return @@ -73,6 +73,18 @@ func (lm *loggingMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri return lm.svc.RetrieveCertDownloadToken(ctx, serialNumber) } +func (lm *loggingMiddleware) RetrieveCAToken(ctx context.Context) (tokenString string, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method get_cert_download_token for cert took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.RetrieveCAToken(ctx) +} + func (lm *loggingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (cert certs.Certificate, err error) { defer func(begin time.Time) { message := fmt.Sprintf("Method issue_cert for took %s to complete", time.Since(begin)) @@ -144,3 +156,15 @@ func (lm *loggingMiddleware) GenerateCRL(ctx context.Context, caType certs.CertT }(time.Now()) return lm.svc.GenerateCRL(ctx, caType) } + +func (lm *loggingMiddleware) GetSigningCA(ctx context.Context, token string) (cert certs.Certificate, err error) { + defer func(begin time.Time) { + message := fmt.Sprintf("Method get_signing_ca took %s to complete", time.Since(begin)) + if err != nil { + lm.logger.Warn(fmt.Sprintf("%s with error: %s.", message, err)) + return + } + lm.logger.Info(message) + }(time.Now()) + return lm.svc.GetSigningCA(ctx, token) +} diff --git a/api/metrics.go b/api/metrics.go index 663bbac..f1ce5c8 100644 --- a/api/metrics.go +++ b/api/metrics.go @@ -58,9 +58,19 @@ func (mm *metricsMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri mm.counter.With("method", "get_certificate_download_token").Add(1) mm.latency.With("method", "get_certificate_download_token").Observe(time.Since(begin).Seconds()) }(time.Now()) + return mm.svc.RetrieveCertDownloadToken(ctx, serialNumber) } +func (mm *metricsMiddleware) RetrieveCAToken(ctx context.Context) (string, error) { + defer func(begin time.Time) { + mm.counter.With("method", "get_CA_token").Add(1) + mm.latency.With("method", "get_CA_token").Observe(time.Since(begin).Seconds()) + }(time.Now()) + + return mm.svc.RetrieveCAToken(ctx) +} + func (mm *metricsMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { defer func(begin time.Time) { mm.counter.With("method", "issue_certificate").Add(1) @@ -82,6 +92,7 @@ func (mm *metricsMiddleware) ViewCert(ctx context.Context, serialNumber string) mm.counter.With("method", "view_certificate").Add(1) mm.latency.With("method", "view_certificate").Observe(time.Since(begin).Seconds()) }(time.Now()) + return mm.svc.ViewCert(ctx, serialNumber) } @@ -108,3 +119,11 @@ func (mm *metricsMiddleware) GenerateCRL(ctx context.Context, caType certs.CertT }(time.Now()) return mm.svc.GenerateCRL(ctx, caType) } + +func (mm *metricsMiddleware) GetSigningCA(ctx context.Context, token string) (certs.Certificate, error) { + defer func(begin time.Time) { + mm.counter.With("method", "get_signing_ca").Add(1) + mm.latency.With("method", "get_signing_ca").Observe(time.Since(begin).Seconds()) + }(time.Now()) + return mm.svc.GetSigningCA(ctx, token) +} diff --git a/certs.go b/certs.go index 4adb46b..92629f8 100644 --- a/certs.go +++ b/certs.go @@ -40,7 +40,7 @@ type Service interface { RevokeCert(ctx context.Context, serialNumber string) error // RetrieveCert retrieves a certificate record from the database. - RetrieveCert(ctx context.Context, token string, serialNumber string) (Certificate, []byte, error) + RetrieveCert(ctx context.Context, token, serialNumber string) (Certificate, []byte, error) // ViewCert retrieves a certificate record from the database. ViewCert(ctx context.Context, serialNumber string) (Certificate, error) @@ -48,9 +48,14 @@ type Service interface { // ListCerts retrieves the certificates from the database while applying filters. ListCerts(ctx context.Context, pm PageMetadata) (CertificatePage, error) - // RetrieveCertDownloadToken retrieves a certificate download token. + // RetrieveCertDownloadToken generates a certificate download token. + // The token is needed to download the client certificate. RetrieveCertDownloadToken(ctx context.Context, serialNumber string) (string, error) + // RetrieveCAToken generates a CA download and view token. + // The token is needed to view and download the CA certificate. + RetrieveCAToken(ctx context.Context) (string, error) + // IssueCert issues a certificate from the database. IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, option SubjectOptions) (Certificate, error) @@ -60,8 +65,11 @@ type Service interface { // GetEntityID retrieves the entity ID for a certificate. GetEntityID(ctx context.Context, serialNumber string) (string, error) - // GenerateCRL creates + // GenerateCRL creates cert revocation list. GenerateCRL(ctx context.Context, caType CertType) ([]byte, error) + + // Retrieves the signing CA. + GetSigningCA(ctx context.Context, token string) (Certificate, error) } type Repository interface { diff --git a/cli/certs.go b/cli/certs.go index 9a16e2f..66203fb 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -139,7 +139,7 @@ var cmdCerts = []cobra.Command{ }, }, { - Use: "view ", + Use: "view ", Short: "View certificate", Long: `Views a certificate for a given serial number.`, Run: func(cmd *cobra.Command, args []string) { @@ -155,6 +155,57 @@ var cmdCerts = []cobra.Command{ logJSONCmd(*cmd, cert) }, }, + { + Use: "view-ca ", + Short: "View-ca certificate", + Long: `Views ca certificate key with a given token.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 1 { + logUsageCmd(*cmd, cmd.Use) + return + } + cert, err := sdk.ViewCA(args[0]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, cert) + }, + }, + { + Use: "download-ca ", + Short: "Download signing CA", + Long: `Download intermediate cert and ca with a given token.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 1 { + logUsageCmd(*cmd, cmd.Use) + return + } + bundle, err := sdk.DownloadCA(args[0]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + logSaveCAFiles(*cmd, bundle) + }, + }, + { + Use: "token-ca", + Short: "Get CA token", + Long: `Gets a download token for CA.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 0 { + logUsageCmd(*cmd, cmd.Use) + return + } + token, err := sdk.GetCAToken() + if err != nil { + logErrorCmd(*cmd, err) + return + } + logJSONCmd(*cmd, token) + }, + }, } // NewCertsCmd returns certificate command. diff --git a/cli/certs_test.go b/cli/certs_test.go index c81ca2f..d2b1c86 100644 --- a/cli/certs_test.go +++ b/cli/certs_test.go @@ -21,13 +21,16 @@ import ( ) const ( - revokeCmd = "revoke" - issueCmd = "issue" - renewCmd = "renew" - listCmd = "get" - tokenCmd = "token" - downloadCmd = "download" - all = "all" + revokeCmd = "revoke" + issueCmd = "issue" + renewCmd = "renew" + listCmd = "get" + tokenCmd = "token" + downloadCmd = "download" + all = "all" + downloadCACmd = "download-ca" + CATokenCmd = "token-ca" + viewCACmd = "view-ca" ) var ( @@ -35,6 +38,7 @@ var ( id = "5b4c9ee3-e719-4a0a-9ee5-354932c5e6a4" commonName = "test-name" extraArg = "extra-arg" + token = "token" ) func TestIssueCertCmd(t *testing.T) { @@ -395,7 +399,6 @@ func TestDownloadCertCmd(t *testing.T) { certCmd := cli.NewCertsCmd() rootCmd := setFlags(certCmd) - token := "token" cases := []struct { desc string args []string @@ -463,6 +466,200 @@ func TestDownloadCertCmd(t *testing.T) { } } +func TestGetCATokenCmd(t *testing.T) { + sdkMock := new(sdkmocks.MockSDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + tk := "ca1121f5-d66a-44c9-bf3c-d267498a0f3d" + + var token sdk.Token + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logType outputLog + token sdk.Token + }{ + { + desc: "get CA token successfully", + args: []string{}, + logType: entityLog, + token: sdk.Token{Token: tk}, + }, + { + desc: "get CA token with invalid args", + args: []string{ + extraArg, + }, + logType: usageLog, + }, + { + desc: "get CA token failed", + args: []string{}, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrGetToken, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrGetToken, http.StatusUnprocessableEntity)), + logType: errLog, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("GetCAToken").Return(tc.token, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{CATokenCmd}, tc.args...)...) + + switch tc.logType { + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + case entityLog: + err := json.Unmarshal([]byte(out), &token) + if err != nil { + t.Fatalf("Failed to unmarshal JSON: %v", err) + } + assert.Equal(t, tc.token, token, fmt.Sprintf("%v unexpected response, expected: %v, got: %v", tc.desc, tc.token, token)) + } + sdkCall.Unset() + }) + } +} + +func TestDownloadCACmd(t *testing.T) { + sdkMock := new(sdkmocks.MockSDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logMessage string + logType outputLog + certBundle sdk.CertificateBundle + }{ + { + desc: "download CA successfully", + args: []string{ + token, + }, + logType: entityLog, + certBundle: sdk.CertificateBundle{ + Certificate: []byte("certificate"), + PrivateKey: []byte("privatekey"), + }, + logMessage: "Saved ca.pem\nSaved cert.pem\nSaved key.pem\n\nAll certificate files have been saved successfully.\n", + }, + { + desc: "download CA with invalid args", + args: []string{ + token, + extraArg, + }, + logType: usageLog, + }, + { + desc: "download cert failed", + args: []string{ + token, + }, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity)), + logType: errLog, + certBundle: sdk.CertificateBundle{}, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + defer func() { + cleanupFiles(t, []string{"ca.key", "ca.crt"}) + }() + sdkCall := sdkMock.On("DownloadCA", mock.Anything, mock.Anything).Return(tc.certBundle, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{downloadCACmd}, tc.args...)...) + switch tc.logType { + case entityLog: + assert.True(t, strings.Contains(out, "Saved ca.crt"), fmt.Sprintf("%s invalid output: %s", tc.desc, out)) + assert.True(t, strings.Contains(out, "Saved ca.key"), fmt.Sprintf("%s invalid output: %s", tc.desc, out)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + } + sdkCall.Unset() + }) + } +} + +func TestViewCACmd(t *testing.T) { + sdkMock := new(sdkmocks.MockSDK) + cli.SetSDK(sdkMock) + certCmd := cli.NewCertsCmd() + rootCmd := setFlags(certCmd) + + var cert sdk.Certificate + cases := []struct { + desc string + args []string + sdkErr errors.SDKError + errLogMessage string + logType outputLog + cert sdk.Certificate + }{ + { + desc: "view cert successfully", + args: []string{ + token, + }, + logType: entityLog, + cert: sdk.Certificate{ + Certificate: "certificate", + Key: "privatekey", + }, + }, + { + desc: "view cert with invalid args", + args: []string{ + token, + extraArg, + }, + logType: usageLog, + }, + { + desc: "view cert failed", + args: []string{ + token, + }, + sdkErr: errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity), + errLogMessage: fmt.Sprintf("\nerror: %s\n\n", errors.NewSDKErrorWithStatus(certs.ErrUpdateEntity, http.StatusUnprocessableEntity)), + logType: errLog, + cert: sdk.Certificate{}, + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + sdkCall := sdkMock.On("ViewCA", mock.Anything).Return(tc.cert, tc.sdkErr) + out := executeCommand(t, rootCmd, append([]string{viewCACmd}, tc.args...)...) + switch tc.logType { + case entityLog: + err := json.Unmarshal([]byte(out), &cert) + assert.Nil(t, err) + assert.Equal(t, tc.cert, cert, fmt.Sprintf("%s unexpected response: expected: %v, got: %v", tc.desc, tc.cert, cert)) + case usageLog: + assert.False(t, strings.Contains(out, rootCmd.Use), fmt.Sprintf("%s invalid usage: %s", tc.desc, out)) + case errLog: + assert.Equal(t, tc.errLogMessage, out, fmt.Sprintf("%s unexpected error response: expected %s got errLogMessage:%s", tc.desc, tc.errLogMessage, out)) + } + sdkCall.Unset() + }) + } +} + func cleanupFiles(t *testing.T, filenames []string) { for _, filename := range filenames { err := os.Remove(filename) diff --git a/cli/utils.go b/cli/utils.go index c7e922f..6ff37ce 100644 --- a/cli/utils.go +++ b/cli/utils.go @@ -82,6 +82,23 @@ func logSaveCertFiles(cmd cobra.Command, certBundle ctxsdk.CertificateBundle) { fmt.Fprintf(cmd.OutOrStdout(), "\nAll certificate files have been saved successfully.\n") } +func logSaveCAFiles(cmd cobra.Command, certBundle ctxsdk.CertificateBundle) { + files := map[string][]byte{ + "ca.crt": certBundle.Certificate, + "ca.key": certBundle.PrivateKey, + } + + for filename, content := range files { + err := saveToFile(filename, content) + if err != nil { + logErrorCmd(cmd, err) + return + } + fmt.Fprintf(cmd.OutOrStdout(), "Saved %s\n", filename) + } + fmt.Fprintf(cmd.OutOrStdout(), "\nAll certificate files have been saved successfully.\n") +} + func saveToFile(filename string, content []byte) error { cwd, err := os.Getwd() if err != nil { diff --git a/mocks/service.go b/mocks/service.go index 00bf73e..cf6d732 100644 --- a/mocks/service.go +++ b/mocks/service.go @@ -144,6 +144,63 @@ func (_c *MockService_GetEntityID_Call) RunAndReturn(run func(context.Context, s return _c } +// GetSigningCA provides a mock function with given fields: ctx, token +func (_m *MockService) GetSigningCA(ctx context.Context, token string) (certs.Certificate, error) { + ret := _m.Called(ctx, token) + + if len(ret) == 0 { + panic("no return value specified for GetSigningCA") + } + + var r0 certs.Certificate + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (certs.Certificate, error)); ok { + return rf(ctx, token) + } + if rf, ok := ret.Get(0).(func(context.Context, string) certs.Certificate); ok { + r0 = rf(ctx, token) + } else { + r0 = ret.Get(0).(certs.Certificate) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, token) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockService_GetSigningCA_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSigningCA' +type MockService_GetSigningCA_Call struct { + *mock.Call +} + +// GetSigningCA is a helper method to define mock.On call +// - ctx context.Context +// - token string +func (_e *MockService_Expecter) GetSigningCA(ctx interface{}, token interface{}) *MockService_GetSigningCA_Call { + return &MockService_GetSigningCA_Call{Call: _e.mock.On("GetSigningCA", ctx, token)} +} + +func (_c *MockService_GetSigningCA_Call) Run(run func(ctx context.Context, token string)) *MockService_GetSigningCA_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockService_GetSigningCA_Call) Return(_a0 certs.Certificate, _a1 error) *MockService_GetSigningCA_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockService_GetSigningCA_Call) RunAndReturn(run func(context.Context, string) (certs.Certificate, error)) *MockService_GetSigningCA_Call { + _c.Call.Return(run) + return _c +} + // IssueCert provides a mock function with given fields: ctx, entityID, ttl, ipAddrs, option func (_m *MockService) IssueCert(ctx context.Context, entityID string, ttl string, ipAddrs []string, option certs.SubjectOptions) (certs.Certificate, error) { ret := _m.Called(ctx, entityID, ttl, ipAddrs, option) @@ -383,6 +440,62 @@ func (_c *MockService_RenewCert_Call) RunAndReturn(run func(context.Context, str return _c } +// RetrieveCAToken provides a mock function with given fields: ctx +func (_m *MockService) RetrieveCAToken(ctx context.Context) (string, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for RetrieveCAToken") + } + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (string, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) string); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockService_RetrieveCAToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RetrieveCAToken' +type MockService_RetrieveCAToken_Call struct { + *mock.Call +} + +// RetrieveCAToken is a helper method to define mock.On call +// - ctx context.Context +func (_e *MockService_Expecter) RetrieveCAToken(ctx interface{}) *MockService_RetrieveCAToken_Call { + return &MockService_RetrieveCAToken_Call{Call: _e.mock.On("RetrieveCAToken", ctx)} +} + +func (_c *MockService_RetrieveCAToken_Call) Run(run func(ctx context.Context)) *MockService_RetrieveCAToken_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *MockService_RetrieveCAToken_Call) Return(_a0 string, _a1 error) *MockService_RetrieveCAToken_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockService_RetrieveCAToken_Call) RunAndReturn(run func(context.Context) (string, error)) *MockService_RetrieveCAToken_Call { + _c.Call.Return(run) + return _c +} + // RetrieveCert provides a mock function with given fields: ctx, token, serialNumber func (_m *MockService) RetrieveCert(ctx context.Context, token string, serialNumber string) (certs.Certificate, []byte, error) { ret := _m.Called(ctx, token, serialNumber) diff --git a/sdk/certs_test.go b/sdk/certs_test.go index b2a23db..da97a49 100644 --- a/sdk/certs_test.go +++ b/sdk/certs_test.go @@ -26,6 +26,7 @@ const ( id = "c333e6f1-59bb-4c39-9e13-3a2766af8ba5" ttl = "10h" commonName = "test" + token = "token" ) func setupCerts() (*httptest.Server, *mocks.MockService) { @@ -539,3 +540,187 @@ func TestViewCert(t *testing.T) { }) } } + +func TestDownloadCACert(t *testing.T) { + ts, svc := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: contentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + cert := sdk.Certificate{ + SerialNumber: serialNum, + } + + cases := []struct { + desc string + token string + svcresp certs.Certificate + svcerr error + err errors.SDKError + sdkCert sdk.Certificate + }{ + { + desc: "Download CA successfully", + token: token, + svcresp: certs.Certificate{ + SerialNumber: serialNum, + Certificate: []byte("cert"), + Key: []byte("key"), + }, + sdkCert: cert, + svcerr: nil, + err: nil, + }, + { + desc: "Download CA failure", + token: token, + svcresp: certs.Certificate{}, + svcerr: certs.ErrViewEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrViewEntity, http.StatusUnprocessableEntity), + }, + { + desc: "Download CA with empty token", + token: "", + svcresp: certs.Certificate{}, + svcerr: certs.ErrMalformedEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrMalformedEntity, http.StatusBadRequest), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := svc.On("GetSigningCA", mock.Anything, tc.token).Return(tc.svcresp, tc.svcerr) + + _, err := ctsdk.DownloadCA(tc.token) + assert.Equal(t, tc.err, err) + if tc.err == nil { + ok := svcCall.Parent.AssertCalled(t, "GetSigningCA", mock.Anything, tc.token) + assert.True(t, ok) + } + svcCall.Unset() + }) + } +} + +func TestViewCA(t *testing.T) { + ts, svc := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: contentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + cert := sdk.Certificate{ + SerialNumber: serialNum, + Certificate: "cert", + Key: "Key", + } + + cases := []struct { + desc string + token string + svcresp certs.Certificate + svcerr error + err errors.SDKError + sdkCert sdk.Certificate + }{ + { + desc: "ViewCA success", + token: token, + svcresp: certs.Certificate{ + Certificate: []byte("cert"), + }, + sdkCert: cert, + svcerr: nil, + err: nil, + }, + { + desc: "ViewCA failure", + token: token, + svcresp: certs.Certificate{}, + svcerr: certs.ErrViewEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrViewEntity, http.StatusUnprocessableEntity), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := svc.On("GetSigningCA", mock.Anything, tc.token).Return(tc.svcresp, tc.svcerr) + + c, err := ctsdk.ViewCA(tc.token) + assert.Equal(t, tc.err, err) + if tc.err == nil { + ok := svcCall.Parent.AssertCalled(t, "GetSigningCA", mock.Anything, tc.token) + assert.True(t, ok) + } + assert.Equal(t, tc.sdkCert.Certificate, c.Certificate, fmt.Sprintf("expected: %v, got: %v", tc.sdkCert.Certificate, c.Certificate)) + svcCall.Unset() + }) + } +} + +func TestGetCAToken(t *testing.T) { + ts, svc := setupCerts() + defer ts.Close() + + sdkConfig := sdk.Config{ + CertsURL: ts.URL, + MsgContentType: contentType, + TLSVerification: false, + } + + ctsdk := sdk.NewSDK(sdkConfig) + + token := "valid token" + + cases := []struct { + desc string + svcresp string + svcerr error + err errors.SDKError + }{ + { + desc: "RetrieveCertDownloadToken success", + svcresp: token, + svcerr: nil, + err: nil, + }, + { + desc: "RetrieveCertDownloadToken failure", + svcresp: "", + svcerr: certs.ErrGetToken, + err: errors.NewSDKErrorWithStatus(certs.ErrGetToken, http.StatusUnprocessableEntity), + }, + { + desc: "RetrieveCertDownloadToken with empty serial", + svcresp: "", + svcerr: certs.ErrMalformedEntity, + err: errors.NewSDKErrorWithStatus(certs.ErrMalformedEntity, http.StatusBadRequest), + }, + } + + for _, tc := range cases { + t.Run(tc.desc, func(t *testing.T) { + svcCall := svc.On("RetrieveCertDownloadToken", mock.Anything).Return(tc.svcresp, tc.svcerr) + + resp, err := ctsdk.GetCAToken() + assert.Equal(t, tc.err, err) + if tc.err == nil { + assert.Equal(t, tc.svcresp, resp.Token) + ok := svcCall.Parent.AssertCalled(t, "RetrieveCertDownloadToken", mock.Anything) + assert.True(t, ok) + } + svcCall.Unset() + }) + } +} diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index b76a655..ba93512 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -27,6 +27,64 @@ func (_m *MockSDK) EXPECT() *MockSDK_Expecter { return &MockSDK_Expecter{mock: &_m.Mock} } +// DownloadCA provides a mock function with given fields: token +func (_m *MockSDK) DownloadCA(token string) (sdk.CertificateBundle, errors.SDKError) { + ret := _m.Called(token) + + if len(ret) == 0 { + panic("no return value specified for DownloadCA") + } + + var r0 sdk.CertificateBundle + var r1 errors.SDKError + if rf, ok := ret.Get(0).(func(string) (sdk.CertificateBundle, errors.SDKError)); ok { + return rf(token) + } + if rf, ok := ret.Get(0).(func(string) sdk.CertificateBundle); ok { + r0 = rf(token) + } else { + r0 = ret.Get(0).(sdk.CertificateBundle) + } + + if rf, ok := ret.Get(1).(func(string) errors.SDKError); ok { + r1 = rf(token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + + return r0, r1 +} + +// MockSDK_DownloadCA_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DownloadCA' +type MockSDK_DownloadCA_Call struct { + *mock.Call +} + +// DownloadCA is a helper method to define mock.On call +// - token string +func (_e *MockSDK_Expecter) DownloadCA(token interface{}) *MockSDK_DownloadCA_Call { + return &MockSDK_DownloadCA_Call{Call: _e.mock.On("DownloadCA", token)} +} + +func (_c *MockSDK_DownloadCA_Call) Run(run func(token string)) *MockSDK_DownloadCA_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockSDK_DownloadCA_Call) Return(_a0 sdk.CertificateBundle, _a1 errors.SDKError) *MockSDK_DownloadCA_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSDK_DownloadCA_Call) RunAndReturn(run func(string) (sdk.CertificateBundle, errors.SDKError)) *MockSDK_DownloadCA_Call { + _c.Call.Return(run) + return _c +} + // DownloadCert provides a mock function with given fields: token, serialNumber func (_m *MockSDK) DownloadCert(token string, serialNumber string) (sdk.CertificateBundle, errors.SDKError) { ret := _m.Called(token, serialNumber) @@ -86,6 +144,63 @@ func (_c *MockSDK_DownloadCert_Call) RunAndReturn(run func(string, string) (sdk. return _c } +// GetCAToken provides a mock function with given fields: +func (_m *MockSDK) GetCAToken() (sdk.Token, errors.SDKError) { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for GetCAToken") + } + + var r0 sdk.Token + var r1 errors.SDKError + if rf, ok := ret.Get(0).(func() (sdk.Token, errors.SDKError)); ok { + return rf() + } + if rf, ok := ret.Get(0).(func() sdk.Token); ok { + r0 = rf() + } else { + r0 = ret.Get(0).(sdk.Token) + } + + if rf, ok := ret.Get(1).(func() errors.SDKError); ok { + r1 = rf() + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + + return r0, r1 +} + +// MockSDK_GetCAToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCAToken' +type MockSDK_GetCAToken_Call struct { + *mock.Call +} + +// GetCAToken is a helper method to define mock.On call +func (_e *MockSDK_Expecter) GetCAToken() *MockSDK_GetCAToken_Call { + return &MockSDK_GetCAToken_Call{Call: _e.mock.On("GetCAToken")} +} + +func (_c *MockSDK_GetCAToken_Call) Run(run func()) *MockSDK_GetCAToken_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSDK_GetCAToken_Call) Return(_a0 sdk.Token, _a1 errors.SDKError) *MockSDK_GetCAToken_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSDK_GetCAToken_Call) RunAndReturn(run func() (sdk.Token, errors.SDKError)) *MockSDK_GetCAToken_Call { + _c.Call.Return(run) + return _c +} + // IssueCert provides a mock function with given fields: entityID, ttl, ipAddrs, opts func (_m *MockSDK) IssueCert(entityID string, ttl string, ipAddrs []string, opts sdk.Options) (sdk.Certificate, errors.SDKError) { ret := _m.Called(entityID, ttl, ipAddrs, opts) @@ -419,6 +534,64 @@ func (_c *MockSDK_RevokeCert_Call) RunAndReturn(run func(string) errors.SDKError return _c } +// ViewCA provides a mock function with given fields: token +func (_m *MockSDK) ViewCA(token string) (sdk.Certificate, errors.SDKError) { + ret := _m.Called(token) + + if len(ret) == 0 { + panic("no return value specified for ViewCA") + } + + var r0 sdk.Certificate + var r1 errors.SDKError + if rf, ok := ret.Get(0).(func(string) (sdk.Certificate, errors.SDKError)); ok { + return rf(token) + } + if rf, ok := ret.Get(0).(func(string) sdk.Certificate); ok { + r0 = rf(token) + } else { + r0 = ret.Get(0).(sdk.Certificate) + } + + if rf, ok := ret.Get(1).(func(string) errors.SDKError); ok { + r1 = rf(token) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(errors.SDKError) + } + } + + return r0, r1 +} + +// MockSDK_ViewCA_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ViewCA' +type MockSDK_ViewCA_Call struct { + *mock.Call +} + +// ViewCA is a helper method to define mock.On call +// - token string +func (_e *MockSDK_Expecter) ViewCA(token interface{}) *MockSDK_ViewCA_Call { + return &MockSDK_ViewCA_Call{Call: _e.mock.On("ViewCA", token)} +} + +func (_c *MockSDK_ViewCA_Call) Run(run func(token string)) *MockSDK_ViewCA_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(string)) + }) + return _c +} + +func (_c *MockSDK_ViewCA_Call) Return(_a0 sdk.Certificate, _a1 errors.SDKError) *MockSDK_ViewCA_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockSDK_ViewCA_Call) RunAndReturn(run func(string) (sdk.Certificate, errors.SDKError)) *MockSDK_ViewCA_Call { + _c.Call.Return(run) + return _c +} + // ViewCert provides a mock function with given fields: serialNumber func (_m *MockSDK) ViewCert(serialNumber string) (sdk.Certificate, errors.SDKError) { ret := _m.Called(serialNumber) diff --git a/sdk/sdk.go b/sdk/sdk.go index 0416b34..290cfac 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -65,12 +65,12 @@ type Token struct { } type Certificate struct { - SerialNumber string `json:"serial_number"` + SerialNumber string `json:"serial_number,omitempty"` Certificate string `json:"certificate,omitempty"` Key string `json:"key,omitempty"` - Revoked bool `json:"revoked"` - ExpiryTime time.Time `json:"expiry_time"` - EntityID string `json:"entity_id"` + Revoked bool `json:"revoked,omitempty"` + ExpiryTime time.Time `json:"expiry_time,omitempty"` + EntityID string `json:"entity_id,omitempty"` DownloadUrl string `json:"-"` } @@ -161,6 +161,27 @@ type SDK interface { // response, _ := sdk.OCSP("serialNumber") // fmt.Println(response) OCSP(serialNumber string) (*ocsp.Response, errors.SDKError) + + // ViewCA views the signing certificate + // + // example: + // response, _ := sdk.ViewCA(token) + // fmt.Println(response) + ViewCA(token string) (Certificate, errors.SDKError) + + // DownloadCA downloads the signing certificate + // + // example: + // response, _ := sdk.DownloadCA(token) + // fmt.Println(response) + DownloadCA(token string) (CertificateBundle, errors.SDKError) + + // GetCAToken get token for viewing and downloading CA + // + // example: + // response, _ := sdk.GetCAToken() + // fmt.Println(response) + GetCAToken() (Token, errors.SDKError) } func (sdk mgSDK) IssueCert(entityID, ttl string, ipAddrs []string, opts Options) (Certificate, errors.SDKError) { @@ -298,6 +319,76 @@ func (sdk mgSDK) OCSP(serialNumber string) (*ocsp.Response, errors.SDKError) { return ocspResp, nil } +func (sdk mgSDK) ViewCA(token string) (Certificate, errors.SDKError) { + pm := PageMetadata{ + Token: token, + } + url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/view-ca", certsEndpoint), pm) + if err != nil { + return Certificate{}, errors.NewSDKError(err) + } + + _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusOK) + if sdkerr != nil { + return Certificate{}, sdkerr + } + + var cert Certificate + if err := json.Unmarshal(body, &cert); err != nil { + return Certificate{}, errors.NewSDKError(err) + } + return cert, nil +} + +func (sdk mgSDK) DownloadCA(token string) (CertificateBundle, errors.SDKError) { + pm := PageMetadata{ + Token: token, + } + url, err := sdk.withQueryParams(sdk.certsURL, fmt.Sprintf("%s/download-ca", certsEndpoint), pm) + if err != nil { + return CertificateBundle{}, errors.NewSDKError(err) + } + _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusOK) + if sdkerr != nil { + return CertificateBundle{}, sdkerr + } + + zipReader, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) + if err != nil { + return CertificateBundle{}, errors.NewSDKError(err) + } + + var bundle CertificateBundle + for _, file := range zipReader.File { + fileContent, err := readZipFile(file) + if err != nil { + return CertificateBundle{}, errors.NewSDKError(err) + } + switch file.Name { + case "ca.crt": + bundle.Certificate = fileContent + case "ca.key": + bundle.PrivateKey = fileContent + } + } + + return bundle, nil +} + +func (sdk mgSDK) GetCAToken() (Token, errors.SDKError) { + url := fmt.Sprintf("%s/%s/get-ca/token", sdk.certsURL, certsEndpoint) + _, body, sdkerr := sdk.processRequest(http.MethodGet, url, nil, nil, http.StatusOK) + if sdkerr != nil { + return Token{}, sdkerr + } + + var tk Token + if err := json.Unmarshal(body, &tk); err != nil { + return Token{}, errors.NewSDKError(err) + } + return tk, nil +} + func NewSDK(conf Config) SDK { return &mgSDK{ certsURL: conf.CertsURL, diff --git a/service.go b/service.go index 079840f..dc6d839 100644 --- a/service.go +++ b/service.go @@ -35,6 +35,7 @@ const ( certValidityPeriod = time.Hour * 24 * 90 // 30 days rCertExpiryThreshold = time.Hour * 24 * 30 // 30 days iCertExpiryThreshold = time.Hour * 24 * 10 // 10 days + downloadTokenExpiry = time.Minute * 5 ) type CertType int @@ -99,6 +100,7 @@ var ( ErrCertExpired = errors.New("certificate expired before renewal") ErrCertRevoked = errors.New("certificate has been revoked and cannot be renewed") ErrCertInvalidType = errors.New("invalid cert type") + ErrInvalidLength = errors.New("invalid length of serial numbers") ) type SubjectOptions struct { @@ -266,7 +268,15 @@ func (s *service) ViewCert(ctx context.Context, serialNumber string) (Certificat return cert, nil } -// GetCertDownloadToken generates a download token for a certificate. +func (s *service) ViewCA(ctx context.Context) (Certificate, error) { + cert, err := s.repo.RetrieveCert(ctx, s.intermediateCA.SerialNumber) + if err != nil { + return Certificate{}, errors.Wrap(ErrViewEntity, err) + } + return cert, nil +} + +// RetrieveCertDownloadToken generates a download token for a certificate. // It verifies the token and serial number, and returns a signed JWT token string. // The token is valid for 5 minutes. // Parameters: @@ -277,11 +287,31 @@ func (s *service) ViewCert(ctx context.Context, serialNumber string) (Certificat // - string: the signed JWT token string // - error: an error if the authentication fails or any other error occurs func (s *service) RetrieveCertDownloadToken(ctx context.Context, serialNumber string) (string, error) { - jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{ExpiresAt: time.Now().Add(time.Minute * 5).Unix(), Issuer: Organization, Subject: "certs"}) + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{ExpiresAt: time.Now().Add(downloadTokenExpiry).Unix(), Issuer: Organization, Subject: "certs"}) token, err := jwtToken.SignedString([]byte(serialNumber)) if err != nil { return "", errors.Wrap(ErrGetToken, err) } + + return token, nil +} + +// RetrieveCAToken generates a download token for a certificate. +// It verifies the token and serial number, and returns a signed JWT token string. +// The token is valid for 5 minutes. +// Parameters: +// - ctx: the context.Context object for the request +// +// Returns: +// - string: the signed JWT token string +// - error: an error if the authentication fails or any other error occurs +func (s *service) RetrieveCAToken(ctx context.Context) (string, error) { + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{ExpiresAt: time.Now().Add(downloadTokenExpiry).Unix(), Issuer: Organization, Subject: "certs"}) + token, err := jwtToken.SignedString([]byte(s.intermediateCA.SerialNumber)) + if err != nil { + return "", errors.Wrap(ErrGetToken, err) + } + return token, nil } @@ -414,6 +444,20 @@ func (s *service) GenerateCRL(ctx context.Context, caType CertType) ([]byte, err return pemBytes, nil } +func (s *service) GetSigningCA(ctx context.Context, token string) (Certificate, error) { + if _, err := jwt.ParseWithClaims(token, &jwt.StandardClaims{Issuer: Organization, Subject: "certs"}, func(token *jwt.Token) (interface{}, error) { + return []byte(s.intermediateCA.SerialNumber), nil + }); err != nil { + return Certificate{}, errors.Wrap(err, ErrMalformedEntity) + } + + cert, err := s.repo.RetrieveCert(ctx, s.intermediateCA.SerialNumber) + if err != nil { + return Certificate{}, errors.Wrap(ErrViewEntity, err) + } + return cert, nil +} + func (s *service) generateRootCA(ctx context.Context) (*CA, error) { rootKey, err := rsa.GenerateKey(rand.Reader, PrivateKeyBytes) if err != nil { diff --git a/tracing/certs.go b/tracing/certs.go index 61779af..787cb52 100644 --- a/tracing/certs.go +++ b/tracing/certs.go @@ -47,6 +47,12 @@ func (tm *tracingMiddleware) RetrieveCertDownloadToken(ctx context.Context, seri return tm.svc.RetrieveCertDownloadToken(ctx, serialNumber) } +func (tm *tracingMiddleware) RetrieveCAToken(ctx context.Context) (string, error) { + ctx, span := tm.tracer.Start(ctx, "get_CA_download_token") + defer span.End() + return tm.svc.RetrieveCAToken(ctx) +} + func (tm *tracingMiddleware) IssueCert(ctx context.Context, entityID, ttl string, ipAddrs []string, options certs.SubjectOptions) (certs.Certificate, error) { ctx, span := tm.tracer.Start(ctx, "issue_cert") defer span.End() @@ -82,3 +88,9 @@ func (tm *tracingMiddleware) GenerateCRL(ctx context.Context, caType certs.CertT defer span.End() return tm.svc.GenerateCRL(ctx, caType) } + +func (tm *tracingMiddleware) GetSigningCA(ctx context.Context, token string) (certs.Certificate, error) { + ctx, span := tm.tracer.Start(ctx, "get_signing_ca") + defer span.End() + return tm.svc.GetSigningCA(ctx, token) +}