diff --git a/internal/ca/ca.go b/internal/ca/ca.go index 66f1b560d4..f3a412804f 100644 --- a/internal/ca/ca.go +++ b/internal/ca/ca.go @@ -8,6 +8,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "errors" "fmt" "time" @@ -34,117 +35,78 @@ type CA struct { // New creates a new CA. func New(namespace string) (*CA, error) { - rootSerialNumber, err := crypto.GenerateCertificateSerialNumber() - if err != nil { - return nil, err - } + now := time.Now() + notBefore := now.Add(-time.Hour) + notAfter := now.AddDate(10, 0, 0) root := &x509.Certificate{ - SerialNumber: rootSerialNumber, Subject: pkix.Name{CommonName: "system:coordinator:root"}, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), + NotBefore: notBefore, + NotAfter: notAfter, IsCA: true, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, } rootPrivKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) if err != nil { - return nil, fmt.Errorf("failed to generate RSA private key: %w", err) + return nil, fmt.Errorf("generating root private key: %w", err) } - rootBytes, err := x509.CreateCertificate(rand.Reader, root, root, &rootPrivKey.PublicKey, rootPrivKey) + rootPEM, err := createCert(root, root, &rootPrivKey.PublicKey, rootPrivKey) if err != nil { - return nil, fmt.Errorf("failed to create root certificate: %w", err) - } - rootPEM := new(bytes.Buffer) - if err := pem.Encode(rootPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: rootBytes, - }); err != nil { - return nil, fmt.Errorf("failed to encode root certificate: %w", err) + return nil, fmt.Errorf("creating root certificate: %w", err) } - intermSerialNumber, err := crypto.GenerateCertificateSerialNumber() - if err != nil { - return nil, err - } intermed := &x509.Certificate{ - SerialNumber: intermSerialNumber, Subject: pkix.Name{CommonName: "system:coordinator:meshCA"}, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), + NotBefore: notBefore, + NotAfter: notAfter, IsCA: true, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, } intermPrivKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) if err != nil { - return nil, fmt.Errorf("failed to generate RSA private key: %w", err) + return nil, fmt.Errorf("generating intermediate private key: %w", err) } - intermBytes, err := x509.CreateCertificate(rand.Reader, intermed, root, &intermPrivKey.PublicKey, rootPrivKey) + intermPEM, err := createCert(intermed, root, &intermPrivKey.PublicKey, rootPrivKey) if err != nil { - return nil, fmt.Errorf("failed to create intermediate certificate: %w", err) - } - intermPEM := new(bytes.Buffer) - if err := pem.Encode(intermPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: intermBytes, - }); err != nil { - return nil, fmt.Errorf("failed to encode intermediate certificate: %w", err) + return nil, fmt.Errorf("creating intermediate certificate: %w", err) } - intermCASerialNumber, err := crypto.GenerateCertificateSerialNumber() - if err != nil { - return nil, err - } meshCA := &x509.Certificate{ - SerialNumber: intermCASerialNumber, Subject: pkix.Name{CommonName: "system:coordinator:meshCA"}, - NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), + NotBefore: notBefore, + NotAfter: notAfter, IsCA: true, KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, BasicConstraintsValid: true, } - meshCABytes, err := x509.CreateCertificate(rand.Reader, meshCA, meshCA, &intermPrivKey.PublicKey, intermPrivKey) + meshCAPEM, err := createCert(meshCA, meshCA, &intermPrivKey.PublicKey, intermPrivKey) if err != nil { - return nil, fmt.Errorf("failed to create meshCA certificate: %w", err) - } - meshCAPEM := new(bytes.Buffer) - if err := pem.Encode(meshCAPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: meshCABytes, - }); err != nil { - return nil, fmt.Errorf("failed to encode meshCA certificate: %w", err) + return nil, fmt.Errorf("creating mesh certificate: %w", err) } return &CA{ rootPrivKey: rootPrivKey, rootCert: root, - rootPEM: rootPEM.Bytes(), + rootPEM: rootPEM, intermPrivKey: intermPrivKey, intermCert: intermed, - intermPEM: intermPEM.Bytes(), + intermPEM: intermPEM, meshCACert: meshCA, - meshCAPEM: meshCAPEM.Bytes(), + meshCAPEM: meshCAPEM, namespace: namespace, }, nil } // NewAttestedMeshCert creates a new attested mesh certificate. func (c *CA) NewAttestedMeshCert(dnsNames []string, extensions []pkix.Extension, subjectPublicKey any) ([]byte, error) { - serialNumber, err := crypto.GenerateCertificateSerialNumber() - if err != nil { - return nil, err - } - now := time.Now() certTemplate := &x509.Certificate{ - SerialNumber: serialNumber, Subject: pkix.Name{CommonName: dnsNames[0]}, Issuer: pkix.Name{CommonName: "system:coordinator:meshCA"}, - NotBefore: now.Add(-2 * time.Hour), - NotAfter: now.Add(354 * 24 * time.Hour), + NotBefore: now.Add(-time.Hour), + NotAfter: now.AddDate(1, 0, 0), ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, KeyUsage: x509.KeyUsageDigitalSignature, BasicConstraintsValid: true, @@ -152,20 +114,12 @@ func (c *CA) NewAttestedMeshCert(dnsNames []string, extensions []pkix.Extension, DNSNames: dnsNames, } - certBytes, err := x509.CreateCertificate(rand.Reader, certTemplate, c.meshCACert, subjectPublicKey, c.intermPrivKey) + certPEM, err := createCert(certTemplate, c.meshCACert, subjectPublicKey, c.intermPrivKey) if err != nil { return nil, fmt.Errorf("failed to create certificate: %w", err) } - certPEM := new(bytes.Buffer) - if err := pem.Encode(certPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - }); err != nil { - return nil, fmt.Errorf("failed to encode certificate: %w", err) - } - - return certPEM.Bytes(), nil + return certPEM, nil } // GetRootCACert returns the root certificate of the CA in PEM format. @@ -182,3 +136,36 @@ func (c *CA) GetIntermCert() []byte { func (c *CA) GetMeshCACert() []byte { return c.meshCAPEM } + +func createCert(template, parent *x509.Certificate, pub, priv any) ([]byte, error) { + if parent == nil { + return nil, errors.New("parent cannot be nil") + } + if template == nil { + return nil, errors.New("cert cannot be nil") + } + if template.SerialNumber != nil { + return nil, errors.New("cert serial number must be nil") + } + + serialNum, err := crypto.GenerateCertificateSerialNumber() + if err != nil { + return nil, fmt.Errorf("generating serial number: %w", err) + } + template.SerialNumber = serialNum + + certBytes, err := x509.CreateCertificate(rand.Reader, template, parent, pub, priv) + if err != nil { + return nil, fmt.Errorf("creating certificate: %w", err) + } + + certPEM := new(bytes.Buffer) + if err := pem.Encode(certPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }); err != nil { + return nil, fmt.Errorf("encoding certificate: %w", err) + } + + return certPEM.Bytes(), nil +} diff --git a/internal/ca/ca_test.go b/internal/ca/ca_test.go index 1a6df09ff0..75da83a315 100644 --- a/internal/ca/ca_test.go +++ b/internal/ca/ca_test.go @@ -1,8 +1,13 @@ package ca import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "crypto/x509" + "crypto/x509/pkix" "encoding/pem" + "math/big" "testing" "github.com/stretchr/testify/assert" @@ -37,3 +42,118 @@ func TestNewCA(t *testing.T) { _, err = cert.Verify(opts) require.NoError(err) } + +func TestAttestedMeshCert(t *testing.T) { + req := require.New(t) + + testCases := map[string]struct { + dnsNames []string + extensions []pkix.Extension + subjectPub any + wantErr bool + }{ + "valid": { + dnsNames: []string{"foo", "bar"}, + extensions: []pkix.Extension{}, + subjectPub: newKey(req).Public(), + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + ca, err := New("namespace") + require.NoError(err) + + cert, err := ca.NewAttestedMeshCert(tc.dnsNames, tc.extensions, tc.subjectPub) + if tc.wantErr { + assert.Error(err) + return + } + assert.NoError(err) + assert.NotNil(cert) + + assertValidPEM(assert, cert) + }) + } +} + +func TestCerateCert(t *testing.T) { + req := require.New(t) + + testCases := map[string]struct { + template *x509.Certificate + parent *x509.Certificate + pub any + priv any + wantErr bool + }{ + "parent signed": { + template: &x509.Certificate{}, + parent: &x509.Certificate{}, + pub: newKey(req).Public(), + priv: newKey(req), + }, + "template nil": { + parent: &x509.Certificate{}, + pub: newKey(req).Public(), + priv: newKey(req), + wantErr: true, + }, + "parent nil": { + template: &x509.Certificate{}, + pub: newKey(req).Public(), + priv: newKey(req), + wantErr: true, + }, + "pub nil": { + template: &x509.Certificate{}, + parent: &x509.Certificate{}, + priv: newKey(req), + wantErr: true, + }, + "priv nil": { + template: &x509.Certificate{}, + parent: &x509.Certificate{}, + pub: newKey(req).Public(), + wantErr: true, + }, + "serial number already set": { + template: &x509.Certificate{SerialNumber: big.NewInt(1)}, + parent: &x509.Certificate{}, + pub: newKey(req).Public(), + priv: newKey(req), + wantErr: true, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + + pem, err := createCert(tc.template, tc.parent, tc.pub, tc.priv) + if tc.wantErr { + assert.Error(err) + return + } + + assert.NoError(err) + assertValidPEM(assert, pem) + }) + } +} + +func newKey(require *require.Assertions) *ecdsa.PrivateKey { + key, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + require.NoError(err) + return key +} + +func assertValidPEM(assert *assert.Assertions, data []byte) { + block, _ := pem.Decode(data) + assert.NotNil(block) + _, err := x509.ParseCertificate(block.Bytes) + assert.NoError(err) +}