From 2a4de1e96623080101c7bb819c9795503aa8af9d Mon Sep 17 00:00:00 2001 From: Steve Munene Date: Fri, 22 Nov 2024 18:28:47 +0300 Subject: [PATCH] NOISSUE - Fix OCSP with SDK (#52) * Fix ocsp using sdk Signed-off-by: nyagamunene * Update mocks Signed-off-by: nyagamunene * Address comments Signed-off-by: nyagamunene --------- Signed-off-by: nyagamunene --- api/http/common.go | 1 + api/http/transport.go | 33 +++++++++++- cli/certs.go | 21 ++++++-- sdk/mocks/sdk.go | 39 +++++++-------- sdk/sdk.go | 114 +++++++++++++++++++++++++++++++++++++++--- 5 files changed, 175 insertions(+), 33 deletions(-) diff --git a/api/http/common.go b/api/http/common.go index ce3a1a4..afd5d9d 100644 --- a/api/http/common.go +++ b/api/http/common.go @@ -15,6 +15,7 @@ import ( const ( // ContentType represents JSON content type. ContentType = "application/json" + OCSPType = "application/ocsp-response" ) // Response contains HTTP response specific methods. diff --git a/api/http/transport.go b/api/http/transport.go index 44aef4e..1d0df55 100644 --- a/api/http/transport.go +++ b/api/http/transport.go @@ -7,6 +7,7 @@ import ( "archive/zip" "bytes" "context" + "encoding/asn1" "encoding/json" "fmt" "io" @@ -24,6 +25,8 @@ import ( "golang.org/x/crypto/ocsp" ) +var idPKIXOCSPBasic = asn1.ObjectIdentifier([]int{1, 3, 6, 1, 5, 5, 7, 48, 1, 1}) + const ( offsetKey = "offset" limitKey = "limit" @@ -37,6 +40,16 @@ const ( defType = 1 ) +type responseASN1 struct { + Status asn1.Enumerated + Response responseBytes `asn1:"explicit,tag:0,optional"` +} + +type responseBytes struct { + ResponseType asn1.ObjectIdentifier + Response []byte +} + // MakeHandler returns a HTTP handler for API endpoints. func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http.Handler { opts := []kithttp.ServerOption{ @@ -267,12 +280,30 @@ func EncodeResponse(_ context.Context, w http.ResponseWriter, response interface func encodeOSCPResponse(_ context.Context, w http.ResponseWriter, response interface{}) error { res := response.(ocspRes) + if res.template.Certificate == nil { + ocspRes, err := asn1.Marshal(responseASN1{ + Status: asn1.Enumerated(ocsp.Malformed), + Response: responseBytes{ + ResponseType: idPKIXOCSPBasic, + }, + }) + if err != nil { + return err + } + + w.Header().Set("Content-Type", OCSPType) + if _, err := w.Write(ocspRes); err != nil { + return err + } + + return err + } ocspRes, err := ocsp.CreateResponse(res.issuerCert, res.template.Certificate, res.template, res.signer) if err != nil { return err } - w.Header().Set("Content-Type", "application/ocsp-response") + w.Header().Set("Content-Type", OCSPType) if _, err := w.Write(ocspRes); err != nil { return err } diff --git a/cli/certs.go b/cli/certs.go index 7f9b1b9..c1bd3bf 100644 --- a/cli/certs.go +++ b/cli/certs.go @@ -5,6 +5,7 @@ package cli import ( "encoding/json" + "os" ctxsdk "github.com/absmach/certs/sdk" "github.com/spf13/cobra" @@ -105,15 +106,29 @@ var cmdCerts = []cobra.Command{ }, }, { - Use: "ocsp ", + Use: "ocsp ", Short: "OCSP", - Long: `OCSP for a given serial number.`, + Long: `OCSP for a given serial number or certificate.`, Run: func(cmd *cobra.Command, args []string) { if len(args) != 1 { logUsageCmd(*cmd, cmd.Use) return } - response, err := sdk.OCSP(args[0]) + + var serialNumber, certContent string + + if _, statErr := os.Stat(args[0]); statErr == nil { + certBytes, err := os.ReadFile(args[0]) + if err != nil { + logErrorCmd(*cmd, err) + return + } + certContent = string(certBytes) + } else { + serialNumber = args[0] + } + + response, err := sdk.OCSP(serialNumber, certContent) if err != nil { logErrorCmd(*cmd, err) return diff --git a/sdk/mocks/sdk.go b/sdk/mocks/sdk.go index fd994e6..4410524 100644 --- a/sdk/mocks/sdk.go +++ b/sdk/mocks/sdk.go @@ -9,8 +9,6 @@ import ( errors "github.com/absmach/certs/errors" mock "github.com/stretchr/testify/mock" - ocsp "golang.org/x/crypto/ocsp" - sdk "github.com/absmach/certs/sdk" ) @@ -368,29 +366,27 @@ func (_c *MockSDK_ListCerts_Call) RunAndReturn(run func(sdk.PageMetadata) (sdk.C return _c } -// OCSP provides a mock function with given fields: serialNumber -func (_m *MockSDK) OCSP(serialNumber string) (*ocsp.Response, errors.SDKError) { - ret := _m.Called(serialNumber) +// OCSP provides a mock function with given fields: serialNumber, cert +func (_m *MockSDK) OCSP(serialNumber string, cert string) (sdk.OCSPResponse, errors.SDKError) { + ret := _m.Called(serialNumber, cert) if len(ret) == 0 { panic("no return value specified for OCSP") } - var r0 *ocsp.Response + var r0 sdk.OCSPResponse var r1 errors.SDKError - if rf, ok := ret.Get(0).(func(string) (*ocsp.Response, errors.SDKError)); ok { - return rf(serialNumber) + if rf, ok := ret.Get(0).(func(string, string) (sdk.OCSPResponse, errors.SDKError)); ok { + return rf(serialNumber, cert) } - if rf, ok := ret.Get(0).(func(string) *ocsp.Response); ok { - r0 = rf(serialNumber) + if rf, ok := ret.Get(0).(func(string, string) sdk.OCSPResponse); ok { + r0 = rf(serialNumber, cert) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*ocsp.Response) - } + r0 = ret.Get(0).(sdk.OCSPResponse) } - if rf, ok := ret.Get(1).(func(string) errors.SDKError); ok { - r1 = rf(serialNumber) + if rf, ok := ret.Get(1).(func(string, string) errors.SDKError); ok { + r1 = rf(serialNumber, cert) } else { if ret.Get(1) != nil { r1 = ret.Get(1).(errors.SDKError) @@ -407,23 +403,24 @@ type MockSDK_OCSP_Call struct { // OCSP is a helper method to define mock.On call // - serialNumber string -func (_e *MockSDK_Expecter) OCSP(serialNumber interface{}) *MockSDK_OCSP_Call { - return &MockSDK_OCSP_Call{Call: _e.mock.On("OCSP", serialNumber)} +// - cert string +func (_e *MockSDK_Expecter) OCSP(serialNumber interface{}, cert interface{}) *MockSDK_OCSP_Call { + return &MockSDK_OCSP_Call{Call: _e.mock.On("OCSP", serialNumber, cert)} } -func (_c *MockSDK_OCSP_Call) Run(run func(serialNumber string)) *MockSDK_OCSP_Call { +func (_c *MockSDK_OCSP_Call) Run(run func(serialNumber string, cert string)) *MockSDK_OCSP_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string)) + run(args[0].(string), args[1].(string)) }) return _c } -func (_c *MockSDK_OCSP_Call) Return(_a0 *ocsp.Response, _a1 errors.SDKError) *MockSDK_OCSP_Call { +func (_c *MockSDK_OCSP_Call) Return(_a0 sdk.OCSPResponse, _a1 errors.SDKError) *MockSDK_OCSP_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockSDK_OCSP_Call) RunAndReturn(run func(string) (*ocsp.Response, errors.SDKError)) *MockSDK_OCSP_Call { +func (_c *MockSDK_OCSP_Call) RunAndReturn(run func(string, string) (sdk.OCSPResponse, errors.SDKError)) *MockSDK_OCSP_Call { _c.Call.Return(run) return _c } diff --git a/sdk/sdk.go b/sdk/sdk.go index 9614ab7..9bf3ff3 100644 --- a/sdk/sdk.go +++ b/sdk/sdk.go @@ -6,11 +6,15 @@ package sdk import ( "archive/zip" "bytes" + "crypto" "crypto/tls" + "crypto/x509" "encoding/json" + "encoding/pem" "fmt" "io" "log" + "math/big" "net/http" "net/url" "strconv" @@ -24,6 +28,7 @@ import ( const ( certsEndpoint = "certs" issueCertEndpoint = "certs/issue" + emptyOCSPbody = 22 ) const ( @@ -40,6 +45,35 @@ const ( // ContentType represents all possible content types. type ContentType string +type CertStatus int + +const ( + Valid CertStatus = iota + Revoked + Unknown +) + +const ( + valid = "Valid" + revoked = "Revoked" + unknown = "Unknown" +) + +func (c CertStatus) String() string { + switch c { + case Valid: + return valid + case Revoked: + return revoked + default: + return unknown + } +} + +func (c CertStatus) MarshalJSON() ([]byte, error) { + return json.Marshal(c.String()) +} + type PageMetadata struct { Total uint64 `json:"total,omitempty"` Offset uint64 `json:"offset,omitempty"` @@ -105,6 +139,15 @@ type CertificateBundle struct { PrivateKey []byte `json:"private_key"` } +type OCSPResponse struct { + Status CertStatus `json:"status"` + SerialNumber *big.Int `json:"serial_number"` + RevokedAt *time.Time `json:"revoked_at,omitempty"` + ProducedAt *time.Time `json:"produced_at,omitempty"` + Certificate []byte `json:"certificate,omitempty"` + IssuerHash string `json:"issuer_hash,omitempty"` +} + type SDK interface { // IssueCert issues a certificate for a thing required for mTLS. // @@ -165,9 +208,9 @@ type SDK interface { // OCSP checks the revocation status of a certificate // // example: - // response, _ := sdk.OCSP("serialNumber") + // response, _ := sdk.OCSP("serialNumber", "") // fmt.Println(response) - OCSP(serialNumber string) (*ocsp.Response, errors.SDKError) + OCSP(serialNumber, cert string) (OCSPResponse, errors.SDKError) // ViewCA views the signing certificate // @@ -318,18 +361,73 @@ func (sdk mgSDK) RetrieveCertDownloadToken(serialNumber string) (Token, errors.S return tk, nil } -func (sdk mgSDK) OCSP(serialNumber string) (*ocsp.Response, errors.SDKError) { +func (sdk mgSDK) OCSP(serialNumber, cert string) (OCSPResponse, errors.SDKError) { + var sn *big.Int + var ok bool + + if serialNumber == "" && cert == "" { + return OCSPResponse{}, errors.NewSDKError(errors.New("either serial number or certificate must be provided")) + } + + if serialNumber != "" { + sn, ok = new(big.Int).SetString(serialNumber, 10) + if !ok { + return OCSPResponse{}, errors.NewSDKError(errors.New("invalid serial number")) + } + } + + if cert != "" { + block, _ := pem.Decode([]byte(cert)) + if block == nil { + return OCSPResponse{}, errors.NewSDKError(errors.New("failed to decode PEM block")) + } + + parsedCert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return OCSPResponse{}, errors.NewSDKError(err) + } + sn = parsedCert.SerialNumber + } + + req := ocsp.Request{ + SerialNumber: sn, + HashAlgorithm: crypto.SHA256, + IssuerNameHash: nil, + IssuerKeyHash: nil, + } + + requestBody, err := req.Marshal() + if err != nil { + return OCSPResponse{}, errors.NewSDKError(err) + } + url := fmt.Sprintf("%s/%s/ocsp", sdk.certsURL, certsEndpoint) - requestBody := []byte(serialNumber) _, body, sdkerr := sdk.processRequest(http.MethodPost, url, requestBody, nil, http.StatusOK) if sdkerr != nil { - return &ocsp.Response{}, sdkerr + return OCSPResponse{}, sdkerr } - ocspResp, err := ocsp.ParseResponse(body, nil) + + if len(body) == emptyOCSPbody { + return OCSPResponse{ + Status: CertStatus(Unknown), + SerialNumber: sn, + }, nil + } + res, err := ocsp.ParseResponse(body, nil) if err != nil { - return &ocsp.Response{}, errors.NewSDKError(err) + return OCSPResponse{}, errors.NewSDKError(err) + } + + resp := OCSPResponse{ + Status: CertStatus(res.Status), + SerialNumber: res.SerialNumber, + Certificate: res.Certificate.Raw, + RevokedAt: &res.RevokedAt, + IssuerHash: res.IssuerHash.String(), + ProducedAt: &res.ProducedAt, } - return ocspResp, nil + + return resp, nil } func (sdk mgSDK) ViewCA(token string) (Certificate, errors.SDKError) {