Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cli: support for authenticating with private keys and certificates stored in PKCS #11 backend #771

Merged
merged 8 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions cli/internal/certcache/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"errors"

"github.com/edgelesssys/marblerun/cli/internal/file"
"github.com/edgelesssys/marblerun/cli/internal/pkcs11"
"github.com/edgelesssys/marblerun/util"
"github.com/spf13/afero"
"github.com/spf13/pflag"
Expand Down Expand Up @@ -43,21 +44,50 @@ func LoadCoordinatorCachedCert(flags *pflag.FlagSet, fs afero.Fs) (root, interme
}

// LoadClientCert parses the command line flags to load a TLS client certificate.
func LoadClientCert(flags *pflag.FlagSet) (*tls.Certificate, error) {
// The returned cancel function must be called only after the certificate is no longer needed.
func LoadClientCert(flags *pflag.FlagSet) (crt *tls.Certificate, cancel func() error, err error) {
certFile, err := flags.GetString("cert")
if err != nil {
return nil, err
return nil, nil, err
}
keyFile, err := flags.GetString("key")
if err != nil {
return nil, err
return nil, nil, err
}

pkcs11ConfigFile, err := flags.GetString("pkcs11-config")
if err != nil {
return nil, nil, err
}
pkcs11KeyID, err := flags.GetString("pkcs11-key-id")
if err != nil {
return nil, nil, err
}
pkcs11KeyLabel, err := flags.GetString("pkcs11-key-label")
if err != nil {
return nil, nil, err
}
clientCert, err := tls.LoadX509KeyPair(certFile, keyFile)
pkcs11CertID, err := flags.GetString("pkcs11-cert-id")
if err != nil {
return nil, err
return nil, nil, err
}
pkcs11CertLabel, err := flags.GetString("pkcs11-cert-label")
if err != nil {
return nil, nil, err
}

var clientCert tls.Certificate
switch {
case pkcs11ConfigFile != "":
clientCert, cancel, err = pkcs11.LoadX509KeyPair(pkcs11ConfigFile, pkcs11KeyID, pkcs11KeyLabel, pkcs11CertID, pkcs11CertLabel)
case certFile != "" && keyFile != "":
clientCert, err = tls.LoadX509KeyPair(certFile, keyFile)
cancel = func() error { return nil }
default:
err = errors.New("neither PKCS#11 nor file-based client certificate can be loaded with the provided flags")
}

return &clientCert, nil
return &clientCert, cancel, err
}

func saveCert(fh *file.Handler, root, intermediate *x509.Certificate) error {
Expand Down
21 changes: 21 additions & 0 deletions cli/internal/cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/edgelesssys/marblerun/cli/internal/kube"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"k8s.io/apimachinery/pkg/util/version"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/tools/clientcmd"
Expand All @@ -32,6 +33,26 @@ func webhookDNSName(namespace string) string {
return "marble-injector." + namespace
}

func addClientAuthFlags(cmd *cobra.Command, flags *pflag.FlagSet) {
flags.StringP("cert", "c", "", "PEM encoded admin certificate file")
flags.StringP("key", "k", "", "PEM encoded admin key file")
cmd.MarkFlagsRequiredTogether("key", "cert")

flags.String("pkcs11-config", "", "Path to a PKCS#11 configuration file to load the client certificate with")
flags.String("pkcs11-key-id", "", "ID of the private key in the PKCS#11 token")
flags.String("pkcs11-key-label", "", "Label of the private key in the PKCS#11 token")
flags.String("pkcs11-cert-id", "", "ID of the certificate in the PKCS#11 token")
flags.String("pkcs11-cert-label", "", "Label of the certificate in the PKCS#11 token")
must(cobra.MarkFlagFilename(flags, "pkcs11-config", "json"))
cmd.MarkFlagsOneRequired("pkcs11-key-id", "pkcs11-key-label", "cert")
cmd.MarkFlagsOneRequired("pkcs11-cert-id", "pkcs11-cert-label", "cert")

cmd.MarkFlagsMutuallyExclusive("pkcs11-config", "cert")
cmd.MarkFlagsMutuallyExclusive("pkcs11-config", "key")
cmd.MarkFlagsOneRequired("pkcs11-config", "cert")
cmd.MarkFlagsOneRequired("pkcs11-config", "key")
}

// parseRestFlags parses the command line flags used to configure the REST client.
func parseRestFlags(cmd *cobra.Command) (api.VerifyOptions, string, error) {
eraConfig, err := cmd.Flags().GetString("era-config")
Expand Down
37 changes: 21 additions & 16 deletions cli/internal/cmd/manifestUpdate.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,7 @@ An admin certificate specified in the original manifest is needed to verify the
Args: cobra.ExactArgs(2),
RunE: runUpdateApply,
}

cmd.Flags().StringP("cert", "c", "", "PEM encoded admin certificate file (required)")
must(cmd.MarkFlagRequired("cert"))
cmd.Flags().StringP("key", "k", "", "PEM encoded admin key file (required)")
must(cmd.MarkFlagRequired("key"))
addClientAuthFlags(cmd, cmd.Flags())

return cmd
}
Expand All @@ -66,11 +62,8 @@ All participants must use the same manifest to acknowledge the pending update.
Args: cobra.ExactArgs(2),
RunE: runUpdateAcknowledge,
}
addClientAuthFlags(cmd, cmd.Flags())

cmd.Flags().StringP("cert", "c", "", "PEM encoded admin certificate file (required)")
must(cmd.MarkFlagRequired("cert"))
cmd.Flags().StringP("key", "k", "", "PEM encoded admin key file (required)")
must(cmd.MarkFlagRequired("key"))
return cmd
}

Expand All @@ -83,11 +76,8 @@ func newUpdateCancel() *cobra.Command {
Args: cobra.ExactArgs(1),
RunE: runUpdateCancel,
}
addClientAuthFlags(cmd, cmd.Flags())

cmd.Flags().StringP("cert", "c", "", "PEM encoded admin certificate file (required)")
must(cmd.MarkFlagRequired("cert"))
cmd.Flags().StringP("key", "k", "", "PEM encoded admin key file (required)")
must(cmd.MarkFlagRequired("key"))
return cmd
}

Expand Down Expand Up @@ -116,10 +106,15 @@ func runUpdateApply(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
keyPair, err := certcache.LoadClientCert(cmd.Flags())
keyPair, cancel, err := certcache.LoadClientCert(cmd.Flags())
if err != nil {
return err
}
defer func() {
if err := cancel(); err != nil {
cmd.PrintErrf("Failed to close PKCS #11 session: %s\n", err)
}
}()

manifest, err := loadManifestFile(file.New(manifestFile, fs))
if err != nil {
Expand All @@ -142,10 +137,15 @@ func runUpdateAcknowledge(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
keyPair, err := certcache.LoadClientCert(cmd.Flags())
keyPair, cancel, err := certcache.LoadClientCert(cmd.Flags())
if err != nil {
return err
}
defer func() {
if err := cancel(); err != nil {
cmd.PrintErrf("Failed to close PKCS #11 session: %s\n", err)
}
}()

manifest, err := loadManifestFile(file.New(manifestFile, fs))
if err != nil {
Expand Down Expand Up @@ -177,10 +177,15 @@ func runUpdateCancel(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
keyPair, err := certcache.LoadClientCert(cmd.Flags())
keyPair, cancel, err := certcache.LoadClientCert(cmd.Flags())
if err != nil {
return err
}
defer func() {
if err := cancel(); err != nil {
cmd.PrintErrf("Failed to close PKCS #11 session: %s\n", err)
}
}()

if err := api.ManifestUpdateCancel(cmd.Context(), hostname, root, keyPair); err != nil {
return fmt.Errorf("canceling update: %w", err)
Expand Down
6 changes: 1 addition & 5 deletions cli/internal/cmd/secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@ func NewSecretCmd() *cobra.Command {
Manage secrets for the MarbleRun Coordinator.
Set or retrieve a secret defined in the manifest.`,
}

cmd.PersistentFlags().StringP("cert", "c", "", "PEM encoded MarbleRun user certificate file (required)")
cmd.PersistentFlags().StringP("key", "k", "", "PEM encoded MarbleRun user key file (required)")
must(cmd.MarkPersistentFlagRequired("key"))
must(cmd.MarkPersistentFlagRequired("cert"))
addClientAuthFlags(cmd, cmd.PersistentFlags())

cmd.AddCommand(newSecretSet())
cmd.AddCommand(newSecretGet())
Expand Down
7 changes: 6 additions & 1 deletion cli/internal/cmd/secretGet.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,15 @@ func runSecretGet(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
keyPair, err := certcache.LoadClientCert(cmd.Flags())
keyPair, cancel, err := certcache.LoadClientCert(cmd.Flags())
if err != nil {
return err
}
defer func() {
if err := cancel(); err != nil {
cmd.PrintErrf("Failed to close PKCS #11 session: %s\n", err)
}
}()

getSecrets := func(ctx context.Context) (map[string]manifest.Secret, error) {
return api.SecretGet(ctx, hostname, root, keyPair, secretIDs)
Expand Down
7 changes: 6 additions & 1 deletion cli/internal/cmd/secretSet.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,15 @@ func runSecretSet(cmd *cobra.Command, args []string) error {
if err != nil {
return err
}
keyPair, err := certcache.LoadClientCert(cmd.Flags())
keyPair, cancel, err := certcache.LoadClientCert(cmd.Flags())
if err != nil {
return err
}
defer func() {
if err := cancel(); err != nil {
cmd.PrintErrf("Failed to close PKCS #11 session: %s\n", err)
}
}()

if err := api.SecretSet(cmd.Context(), hostname, root, keyPair, newSecrets); err != nil {
return err
Expand Down
82 changes: 82 additions & 0 deletions cli/internal/pkcs11/pkcs11.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
Copyright (c) Edgeless Systems GmbH

SPDX-License-Identifier: BUSL-1.1
*/

package pkcs11

import (
"crypto"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"

"github.com/ThalesGroup/crypto11"
)

// LoadX509KeyPair loads a [tls.Certificate] using the provided PKCS#11 configuration file.
// The returned cancel function must be called to release the PKCS#11 resources only after the certificate is no longer needed.
func LoadX509KeyPair(pkcs11ConfigPath string, keyID, keyLabel, certID, certLabel string) (crt tls.Certificate, cancel func() error, err error) {
pkcs11, err := crypto11.ConfigureFromFile(pkcs11ConfigPath)
if err != nil {
return crt, nil, err
}
defer func() {
if err != nil {
err = errors.Join(err, pkcs11.Close())
}
}()

var keyIDBytes, keyLabelBytes, certIDBytes, certLabelBytes []byte
if keyID != "" {
keyIDBytes = []byte(keyID)
}
if keyLabel != "" {
keyLabelBytes = []byte(keyLabel)
}
if certID != "" {
certIDBytes = []byte(certID)
}
if certLabel != "" {
certLabelBytes = []byte(certLabel)
}

privateKey, err := loadPrivateKey(pkcs11, keyIDBytes, keyLabelBytes)
if err != nil {
return crt, nil, err
}
cert, err := loadCertificate(pkcs11, certIDBytes, certLabelBytes)
if err != nil {
return crt, nil, err
}

return tls.Certificate{
Certificate: [][]byte{cert.Raw},
PrivateKey: privateKey,
Leaf: cert,
}, pkcs11.Close, nil
}

func loadPrivateKey(pkcs11 *crypto11.Context, id, label []byte) (crypto.Signer, error) {
priv, err := pkcs11.FindKeyPair(id, label)
if err != nil {
return nil, err
}
if priv == nil {
return nil, fmt.Errorf("no key pair found for id \"%s\" and label \"%s\"", id, label)
}
return priv, nil
}

func loadCertificate(pkcs11 *crypto11.Context, id, label []byte) (*x509.Certificate, error) {
cert, err := pkcs11.FindCertificate(id, label, nil)
if err != nil {
return nil, err
}
if cert == nil {
return nil, fmt.Errorf("no certificate found for id \"%s\" and label \"%s\"", id, label)
}
return cert, nil
}
Loading