Skip to content

Commit

Permalink
handle multiple private key types
Browse files Browse the repository at this point in the history
Signed-off-by: nyagamunene <[email protected]>
  • Loading branch information
nyagamunene committed Nov 29, 2024
1 parent 290d256 commit 459ffb3
Show file tree
Hide file tree
Showing 11 changed files with 86 additions and 41 deletions.
15 changes: 10 additions & 5 deletions api/http/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"crypto"
"crypto/x509"
"encoding/pem"
"fmt"

Check failure on line 11 in api/http/endpoint.go

View workflow job for this annotation

GitHub Actions / Lint and Build

"fmt" imported and not used (typecheck)

Check failure on line 11 in api/http/endpoint.go

View workflow job for this annotation

GitHub Actions / Lint and Build

"fmt" imported and not used) (typecheck)

Check failure on line 11 in api/http/endpoint.go

View workflow job for this annotation

GitHub Actions / Lint and Build

"fmt" imported and not used) (typecheck)

Check failure on line 11 in api/http/endpoint.go

View workflow job for this annotation

GitHub Actions / Lint and Build

"fmt" imported and not used) (typecheck)
"math/rand"
"strings"
"time"
Expand Down Expand Up @@ -316,15 +317,15 @@ func createCSREndpoint(svc certs.Service) endpoint.Endpoint {
if err := req.validate(); err != nil {
return createCSRRes{created: false}, err
}

csr, err := svc.CreateCSR(ctx, req.Metadata, req.privKey)
if err != nil {
return createCSRRes{created: false}, err
}

return createCSRRes{
created: true,
CSR: csr,
created: true,
CSR: csr.CSR,
PrivateKey: csr.PrivateKey,
}, nil
}
}
Expand All @@ -342,8 +343,12 @@ func signCSREndpoint(svc certs.Service) endpoint.Endpoint {
}

return signCSRRes{
crt: cert,
signed: true,
SerialNumber: cert.SerialNumber,
Certificate: string(cert.Certificate),
Revoked: cert.Revoked,
ExpiryTime: cert.ExpiryTime,
EntityID: cert.EntityID,
signed: true,
}, nil
}
}
6 changes: 2 additions & 4 deletions api/http/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
package http

import (
"crypto/rsa"

"github.com/absmach/certs"
"github.com/absmach/certs/errors"
"golang.org/x/crypto/ocsp"
Expand Down Expand Up @@ -92,8 +90,8 @@ func (req ocspReq) validate() error {

type createCSRReq struct {
Metadata certs.CSRMetadata `json:"metadata"`
PrivateKey []byte `json:"private_Key"`
privKey *rsa.PrivateKey
PrivateKey string `json:"private_key"`
privKey any
}

func (req createCSRReq) validate() error {
Expand Down
14 changes: 9 additions & 5 deletions api/http/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"net/http"
"time"

"github.com/absmach/certs"
"golang.org/x/crypto/ocsp"
)

Expand Down Expand Up @@ -204,8 +203,9 @@ type fileDownloadRes struct {
}

type createCSRRes struct {
certs.CSR
created bool
CSR []byte `json:"csr"`
PrivateKey []byte `json:"private_key"`
created bool
}

func (res createCSRRes) Code() int {
Expand All @@ -225,8 +225,12 @@ func (res createCSRRes) Empty() bool {
}

type signCSRRes struct {
crt certs.Certificate
signed bool
SerialNumber string `json:"serial_number"`
Certificate string `json:"certificate,omitempty"`
Revoked bool `json:"revoked"`
ExpiryTime time.Time `json:"expiry_time"`
EntityID string `json:"entity_id"`
signed bool
}

func (res signCSRRes) Code() int {
Expand Down
10 changes: 7 additions & 3 deletions api/http/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ func MakeHandler(svc certs.Service, logger *slog.Logger, instanceID string) http
opts...,
), "download_ca").ServeHTTP)
r.Route("/csrs", func(r chi.Router) {
r.Post("/create", otelhttp.NewHandler(kithttp.NewServer(
r.Post("/", otelhttp.NewHandler(kithttp.NewServer(
createCSREndpoint(svc),
decodeCreateCSR,
EncodeResponse,
Expand Down Expand Up @@ -286,12 +286,16 @@ func decodeCreateCSR(_ context.Context, r *http.Request) (interface{}, error) {
return nil, err
}

block, _ := pem.Decode(req.PrivateKey)
block, _ := pem.Decode([]byte(req.PrivateKey))
if block != nil {
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)

privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
fmt.Println(err)

if err != nil {
return nil, errors.Wrap(ErrInvalidRequest, err)
}

req.privKey = privateKey
}

Expand Down
2 changes: 1 addition & 1 deletion api/logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func (lm *loggingMiddleware) GetChainCA(ctx context.Context, token string) (cert
return lm.svc.GetChainCA(ctx, token)
}

func (lm *loggingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey) (csr certs.CSR, err error) {
func (lm *loggingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey any) (csr certs.CSR, err error) {
defer func(begin time.Time) {
message := fmt.Sprintf("Method create_csr took %s to complete", time.Since(begin))
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion api/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ func (mm *metricsMiddleware) GetChainCA(ctx context.Context, token string) (cert
return mm.svc.GetChainCA(ctx, token)
}

func (mm *metricsMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey) (certs.CSR, error) {
func (mm *metricsMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey any) (certs.CSR, error) {
defer func(begin time.Time) {
mm.counter.With("method", "create_csr").Add(1)
mm.latency.With("method", "create_csr").Observe(time.Since(begin).Seconds())
Expand Down
2 changes: 1 addition & 1 deletion certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ type Service interface {
RemoveCert(ctx context.Context, entityId string) error

// CreateCSR creates a new Certificate Signing Request
CreateCSR(ctx context.Context, metadata CSRMetadata, privKey *rsa.PrivateKey) (CSR, error)
CreateCSR(ctx context.Context, metadata CSRMetadata, privKey any) (CSR, error)

// SignCSR parses and signs a CSR
SignCSR(ctx context.Context, entityID, ttl string, csr CSR) (Certificate, error)
Expand Down
5 changes: 3 additions & 2 deletions cli/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"os"

"github.com/absmach/certs/errors"
ctxsdk "github.com/absmach/certs/sdk"
"github.com/spf13/cobra"
)
Expand Down Expand Up @@ -243,14 +244,14 @@ var cmdCerts = []cobra.Command{
Short: "Create CSR",
Long: `Creates a CSR.`,
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 0 {
if len(args) != 2 {
logUsageCmd(*cmd, cmd.Use)
return
}

var pm ctxsdk.PageMetadata
if err := json.Unmarshal([]byte(args[0]), &pm); err != nil {
logErrorCmd(*cmd, err)
logErrorCmd(*cmd, errors.New("here 1"))
return
}

Expand Down
41 changes: 26 additions & 15 deletions sdk/sdk.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,17 +549,19 @@ func (sdk mgSDK) GetCAToken() (Token, errors.SDKError) {

func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKError) {
r := csrReq{
Organization: pm.Organization,
OrganizationalUnit: pm.OrganizationalUnit,
Country: pm.Country,
Province: pm.Province,
Locality: pm.Locality,
StreetAddress: pm.StreetAddress,
PostalCode: pm.PostalCode,
DNSNames: pm.DNSNames,
IPAddresses: pm.IPAddresses,
EmailAddresses: pm.EmailAddresses,
PrivateKey: privKey,
Metadata: meta{
Organization: pm.Organization,
OrganizationalUnit: pm.OrganizationalUnit,
Country: pm.Country,
Province: pm.Province,
Locality: pm.Locality,
StreetAddress: pm.StreetAddress,
PostalCode: pm.PostalCode,
DNSNames: pm.DNSNames,
IPAddresses: pm.IPAddresses,
EmailAddresses: pm.EmailAddresses,
},
PrivateKey: privKey,
}
d, err := json.Marshal(r)
if err != nil {
Expand All @@ -568,7 +570,7 @@ func (sdk mgSDK) CreateCSR(pm PageMetadata, privKey []byte) (CSR, errors.SDKErro
url := fmt.Sprintf("%s/%s/%s", sdk.certsURL, certsEndpoint, csrEndpoint)
_, body, sdkerr := sdk.processRequest(http.MethodPost, url, d, nil, http.StatusOK)
if sdkerr != nil {
return CSR{}, sdkerr
return CSR{}, errors.NewSDKError(err)
}

var csr CSR
Expand All @@ -587,11 +589,16 @@ func (sdk mgSDK) SignCSR(entityID, ttl string, csr []byte) (Certificate, errors.
return Certificate{}, errors.NewSDKError(err)
}

_, _, sdkerr := sdk.processRequest(http.MethodPatch, url, nil, nil, http.StatusOK)
_, body, sdkerr := sdk.processRequest(http.MethodPost, url, nil, nil, http.StatusOK)
if sdkerr != nil {
return Certificate{}, sdkerr
}
return Certificate{}, nil

var cert Certificate
if err := json.Unmarshal(body, &cert); err != nil {
return Certificate{}, errors.NewSDKError(err)
}
return cert, nil
}

func NewSDK(conf Config) SDK {
Expand Down Expand Up @@ -703,6 +710,11 @@ type certReq struct {
}

type csrReq struct {
Metadata meta `json:"metadata"`
PrivateKey []byte `json:"private_key"`
}

type meta struct {
Organization []string `json:"organization"`
OrganizationalUnit []string `json:"organizational_unit"`
Country []string `json:"country"`
Expand All @@ -713,5 +725,4 @@ type csrReq struct {
DNSNames []string `json:"dns_names"`
IPAddresses []string `json:"ip_addresses"`
EmailAddresses []string `json:"email_addresses"`
PrivateKey []byte `json:"private_key"`
}
28 changes: 25 additions & 3 deletions service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package certs

import (
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
Expand Down Expand Up @@ -402,7 +404,7 @@ func (s *service) GetChainCA(ctx context.Context, token string) (Certificate, er
return s.getConcatCAs(ctx)
}

func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, privKey *rsa.PrivateKey) (CSR, error) {
func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, privKey any) (CSR, error) {
template := &x509.CertificateRequest{
Subject: pkix.Name{
CommonName: metadata.CommonName,
Expand Down Expand Up @@ -435,9 +437,29 @@ func (s *service) CreateCSR(ctx context.Context, metadata CSRMetadata, privKey *
Bytes: csrBytes,
})

var privKeyBytes []byte
var privKeyType string
switch key := privKey.(type) {
case *rsa.PrivateKey:
privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key)
privKeyType = "RSA PRIVATE KEY"
case *ecdsa.PrivateKey:
privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key)
privKeyType = "EC PRIVATE KEY"
case ed25519.PrivateKey:
privKeyBytes, err = x509.MarshalPKCS8PrivateKey(key)
privKeyType = "PRIVATE KEY"
default:
return CSR{}, errors.Wrap(ErrCreateEntity, errors.New("unsupported private key type"))
}

if err != nil {
return CSR{}, errors.Wrap(ErrCreateEntity, err)
}

privKeyPEM := pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(privKey),
Type: privKeyType,
Bytes: privKeyBytes,
})

csr := CSR{
Expand Down
2 changes: 1 addition & 1 deletion tracing/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (tm *tracingMiddleware) GetChainCA(ctx context.Context, token string) (cert
return tm.svc.GetChainCA(ctx, token)
}

func (tm *tracingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey *rsa.PrivateKey) (certs.CSR, error) {
func (tm *tracingMiddleware) CreateCSR(ctx context.Context, metadata certs.CSRMetadata, privKey any) (certs.CSR, error) {
ctx, span := tm.tracer.Start(ctx, "create_csr")
defer span.End()
return tm.svc.CreateCSR(ctx, metadata, privKey)
Expand Down

0 comments on commit 459ffb3

Please sign in to comment.