diff --git a/README.md b/README.md index dea155e..b0a12f1 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,53 @@ A minimal configuration file looks like this: Testing Guidance ================ +Disabling tests +--------------- + +To disable specific tests, set the environment variable `CRYPTO11_SKIP=` where `` is a comma-separated +list of the following options: + +* `CERTS` - disables certificate-related tests. Needed for AWS CloudHSM, which doesn't support certificates. +* `OAEP_LABEL` - disables RSA OAEP encryption tests that use source data encoding parameter (also known as a 'label' +in some crypto libraries). Needed for AWS CloudHSM. +* `DSA` - disables DSA tests. Needed for AWS CloudHSM (and any other tokens not supporting DSA). + +Testing with AWS CloudHSM +------------------------- + +A minimal configuration file for CloudHSM will look like this: + +```json +{ + "Path" : "/opt/cloudhsm/lib/libcloudhsm_pkcs11_standard.so", + "TokenLabel": "cavium", + "Pin" : "username:password", + "UseGCMIVFromHSM" : true, +} +``` + +To run the test suite you must skip unsupported tests: + +``` +CRYPTO11_SKIP=CERTS,OAEP_LABEL,DSA go test -v +``` + +Be sure to take note of the supported mechanisms, key types and other idiosyncrasies described at +https://docs.aws.amazon.com/cloudhsm/latest/userguide/pkcs11-library.html. Here's a collection of things we +noticed when testing with the v2.0.4 PKCS#11 library: + +- 1024-bit RSA keys don't appear to be supported, despite what `C_GetMechanismInfo` tells you. +- The `CKM_RSA_PKCS_OAEP` mechanism doesn't support source data. I.e. when constructing a `CK_RSA_PKCS_OAEP_PARAMS`, +one must set `pSourceData` to `NULL` and `ulSourceDataLen` to zero. +- CloudHSM will generate it's own IV for GCM mode. This is described in their documentation, see footnote 4 on +https://docs.aws.amazon.com/cloudhsm/latest/userguide/pkcs11-mechanisms.html. +- It appears that `CKA_ID` values must be unique, otherwise you get a `CKR_ATTRIBUTE_VALUE_INVALID` error. +- Very rapid session opening can trigger the following error: + ``` + C_OpenSession failed with error CKR_ARGUMENTS_BAD : 0x00000007 + HSM error 8c: HSM Error: Already maximum number of sessions are issued + ``` + Testing with SoftHSM2 --------------------- diff --git a/aead.go b/aead.go index c0c720c..1ca6899 100644 --- a/aead.go +++ b/aead.go @@ -42,6 +42,8 @@ const ( PaddingPKCS ) +var errBadGCMNonceSize = errors.New("nonce slice too small to hold IV") + type genericAead struct { key *SecretKey @@ -49,7 +51,9 @@ type genericAead struct { nonceSize int - makeMech func(nonce []byte, additionalData []byte) ([]*pkcs11.Mechanism, error) + // Note - if the GCMParams result is non-nil, the caller must call Free() on the params when + // finished. + makeMech func(nonce []byte, additionalData []byte) ([]*pkcs11.Mechanism, *pkcs11.GCMParams, error) } // NewGCM returns a given cipher wrapped in Galois Counter Mode, with the standard @@ -66,9 +70,9 @@ func (key *SecretKey) NewGCM() (cipher.AEAD, error) { key: key, overhead: 16, nonceSize: 12, - makeMech: func(nonce []byte, additionalData []byte) ([]*pkcs11.Mechanism, error) { + makeMech: func(nonce []byte, additionalData []byte) ([]*pkcs11.Mechanism, *pkcs11.GCMParams, error) { params := pkcs11.NewGCMParams(nonce, additionalData, 16*8 /*bits*/) - return []*pkcs11.Mechanism{pkcs11.NewMechanism(key.Cipher.GCMMech, params)}, nil + return []*pkcs11.Mechanism{pkcs11.NewMechanism(key.Cipher.GCMMech, params)}, params, nil }, } return g, nil @@ -96,12 +100,12 @@ func (key *SecretKey) NewCBC(paddingMode PaddingMode) (cipher.AEAD, error) { key: key, overhead: 0, nonceSize: key.BlockSize(), - makeMech: func(nonce []byte, additionalData []byte) ([]*pkcs11.Mechanism, error) { + makeMech: func(nonce []byte, additionalData []byte) ([]*pkcs11.Mechanism, *pkcs11.GCMParams, error) { if len(additionalData) > 0 { - return nil, errors.New("additional data not supported for CBC mode") + return nil, nil, errors.New("additional data not supported for CBC mode") } - return []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcsMech, nonce)}, nil + return []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcsMech, nonce)}, nil, nil }, } @@ -117,12 +121,15 @@ func (g genericAead) Overhead() int { } func (g genericAead) Seal(dst, nonce, plaintext, additionalData []byte) []byte { + var result []byte if err := g.key.context.withSession(func(session *pkcs11Session) (err error) { - var mech []*pkcs11.Mechanism - if mech, err = g.makeMech(nonce, additionalData); err != nil { - return + mech, params, err := g.makeMech(nonce, additionalData) + if err != nil { + return err } + defer params.Free() + if err = session.ctx.EncryptInit(session.handle, mech, g.key.handle); err != nil { err = fmt.Errorf("C_EncryptInit: %v", err) return @@ -131,6 +138,15 @@ func (g genericAead) Seal(dst, nonce, plaintext, additionalData []byte) []byte { err = fmt.Errorf("C_Encrypt: %v", err) return } + + if g.key.context.cfg.UseGCMIVFromHSM { + if len(nonce) != len(params.IV()) { + return errBadGCMNonceSize + } + + copy(nonce, params.IV()) + } + return }); err != nil { panic(err) @@ -143,10 +159,12 @@ func (g genericAead) Seal(dst, nonce, plaintext, additionalData []byte) []byte { func (g genericAead) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) { var result []byte if err := g.key.context.withSession(func(session *pkcs11Session) (err error) { - var mech []*pkcs11.Mechanism - if mech, err = g.makeMech(nonce, additionalData); err != nil { + mech, params, err := g.makeMech(nonce, additionalData) + if err != nil { return } + defer params.Free() + if err = session.ctx.DecryptInit(session.handle, mech, g.key.handle); err != nil { err = fmt.Errorf("C_DecryptInit: %v", err) return diff --git a/attributes.go b/attributes.go index 8bbf1a8..17ae575 100644 --- a/attributes.go +++ b/attributes.go @@ -3,6 +3,7 @@ package crypto11 import ( "errors" "fmt" + "strings" "github.com/miekg/pkcs11" ) @@ -13,7 +14,7 @@ type AttributeType = uint // Attribute represents a PKCS#11 CK_ATTRIBUTE type. type Attribute = pkcs11.Attribute -//noinspection GoUnusedConst +//noinspection GoUnusedConst,GoDeprecation const ( CkaClass = AttributeType(0x00000000) CkaToken = AttributeType(0x00000001) @@ -191,6 +192,19 @@ func (a AttributeSet) Set(attributeType AttributeType, value interface{}) error return nil } +// cloneFrom make this AttributeSet a clone of the supplied set. Values are deep copied. +func (a AttributeSet) cloneFrom(set AttributeSet) { + for key := range a { + delete(a, key) + } + + // Use Copy to do the deep cloning for us + c := set.Copy() + for k, v := range c { + a[k] = v + } +} + // AddIfNotPresent adds the attributes if the Attribute Type is not already present in the AttributeSet. func (a AttributeSet) AddIfNotPresent(additional []*Attribute) { for _, additionalAttr := range additional { @@ -221,6 +235,20 @@ func (a AttributeSet) Copy() AttributeSet { return b } +// Unset removes an attribute from the attributes set. If the set does not contain the attribute, this +// is a no-op. +func (a AttributeSet) Unset(attributeType AttributeType) { + delete(a, attributeType) +} + +func (a AttributeSet) String() string { + result := new(strings.Builder) + for attr, value := range a { + _, _ = fmt.Fprintf(result, "%s: %x\n", attributeTypeString(attr), value.Value) + } + return result.String() +} + // NewAttributeSetWithID is a helper function that populates a new slice of Attributes with the provided ID. // This function returns an error if the ID is an empty slice. func NewAttributeSetWithID(id []byte) (AttributeSet, error) { @@ -246,3 +274,237 @@ func NewAttributeSetWithIDAndLabel(id, label []byte) (a AttributeSet, err error) _ = a.Set(CkaLabel, label) // error not possible with []byte return a, nil } + +func attributeTypeString(a AttributeType) string { + //noinspection GoDeprecation + switch a { + case CkaClass: + return "CkaClass" + case CkaToken: + return "CkaToken" + case CkaPrivate: + return "CkaPrivate" + case CkaLabel: + return "CkaLabel" + case CkaApplication: + return "CkaApplication" + case CkaValue: + return "CkaValue" + case CkaObjectId: + return "CkaObjectId" + case CkaCertificateType: + return "CkaCertificateType" + case CkaIssuer: + return "CkaIssuer" + case CkaSerialNumber: + return "CkaSerialNumber" + case CkaAcIssuer: + return "CkaAcIssuer" + case CkaOwner: + return "CkaOwner" + case CkaAttrTypes: + return "CkaAttrTypes" + case CkaTrusted: + return "CkaTrusted" + case CkaCertificateCategory: + return "CkaCertificateCategory" + case CkaJavaMIDPSecurityDomain: + return "CkaJavaMIDPSecurityDomain" + case CkaUrl: + return "CkaUrl" + case CkaHashOfSubjectPublicKey: + return "CkaHashOfSubjectPublicKey" + case CkaHashOfIssuerPublicKey: + return "CkaHashOfIssuerPublicKey" + case CkaNameHashAlgorithm: + return "CkaNameHashAlgorithm" + case CkaCheckValue: + return "CkaCheckValue" + + case CkaKeyType: + return "CkaKeyType" + case CkaSubject: + return "CkaSubject" + case CkaId: + return "CkaId" + case CkaSensitive: + return "CkaSensitive" + case CkaEncrypt: + return "CkaEncrypt" + case CkaDecrypt: + return "CkaDecrypt" + case CkaWrap: + return "CkaWrap" + case CkaUnwrap: + return "CkaUnwrap" + case CkaSign: + return "CkaSign" + case CkaSignRecover: + return "CkaSignRecover" + case CkaVerify: + return "CkaVerify" + case CkaVerifyRecover: + return "CkaVerifyRecover" + case CkaDerive: + return "CkaDerive" + case CkaStartDate: + return "CkaStartDate" + case CkaEndDate: + return "CkaEndDate" + case CkaModulus: + return "CkaModulus" + case CkaModulusBits: + return "CkaModulusBits" + case CkaPublicExponent: + return "CkaPublicExponent" + case CkaPrivateExponent: + return "CkaPrivateExponent" + case CkaPrime1: + return "CkaPrime1" + case CkaPrime2: + return "CkaPrime2" + case CkaExponent1: + return "CkaExponent1" + case CkaExponent2: + return "CkaExponent2" + case CkaCoefficient: + return "CkaCoefficient" + case CkaPublicKeyInfo: + return "CkaPublicKeyInfo" + case CkaPrime: + return "CkaPrime" + case CkaSubprime: + return "CkaSubprime" + case CkaBase: + return "CkaBase" + + case CkaPrimeBits: + return "CkaPrimeBits" + case CkaSubprimeBits: + return "CkaSubprimeBits" + + case CkaValueBits: + return "CkaValueBits" + case CkaValueLen: + return "CkaValueLen" + case CkaExtractable: + return "CkaExtractable" + case CkaLocal: + return "CkaLocal" + case CkaNeverExtractable: + return "CkaNeverExtractable" + case CkaAlwaysSensitive: + return "CkaAlwaysSensitive" + case CkaKeyGenMechanism: + return "CkaKeyGenMechanism" + + case CkaModifiable: + return "CkaModifiable" + case CkaCopyable: + return "CkaCopyable" + + case CkaDestroyable: + return "CkaDestroyable" + + case CkaEcParams: + return "CkaEcParams" + + case CkaEcPoint: + return "CkaEcPoint" + + case CkaSecondaryAuth: + return "CkaSecondaryAuth" + case CkaAuthPinFlags: + return "CkaAuthPinFlags" + + case CkaAlwaysAuthenticate: + return "CkaAlwaysAuthenticate" + + case CkaWrapWithTrusted: + return "CkaWrapWithTrusted" + + case ckfArrayAttribute: + return "ckfArrayAttribute" + + case CkaWrapTemplate: + return "CkaWrapTemplate" + case CkaUnwrapTemplate: + return "CkaUnwrapTemplate" + + case CkaOtpFormat: + return "CkaOtpFormat" + case CkaOtpLength: + return "CkaOtpLength" + case CkaOtpTimeInterval: + return "CkaOtpTimeInterval" + case CkaOtpUserFriendlyMode: + return "CkaOtpUserFriendlyMode" + case CkaOtpChallengeRequirement: + return "CkaOtpChallengeRequirement" + case CkaOtpTimeRequirement: + return "CkaOtpTimeRequirement" + case CkaOtpCounterRequirement: + return "CkaOtpCounterRequirement" + case CkaOtpPinRequirement: + return "CkaOtpPinRequirement" + case CkaOtpCounter: + return "CkaOtpCounter" + case CkaOtpTime: + return "CkaOtpTime" + case CkaOtpUserIdentifier: + return "CkaOtpUserIdentifier" + case CkaOtpServiceIdentifier: + return "CkaOtpServiceIdentifier" + case CkaOtpServiceLogo: + return "CkaOtpServiceLogo" + case CkaOtpServiceLogoType: + return "CkaOtpServiceLogoType" + + case CkaGOSTR3410Params: + return "CkaGOSTR3410Params" + case CkaGOSTR3411Params: + return "CkaGOSTR3411Params" + case CkaGOST28147Params: + return "CkaGOST28147Params" + + case CkaHwFeatureType: + return "CkaHwFeatureType" + case CkaResetOnInit: + return "CkaResetOnInit" + case CkaHasReset: + return "CkaHasReset" + + case CkaPixelX: + return "CkaPixelX" + case CkaPixelY: + return "CkaPixelY" + case CkaResolution: + return "CkaResolution" + case CkaCharRows: + return "CkaCharRows" + case CkaCharColumns: + return "CkaCharColumns" + case CkaColor: + return "CkaColor" + case CkaBitsPerPixel: + return "CkaBitsPerPixel" + case CkaCharSets: + return "CkaCharSets" + case CkaEncodingMethods: + return "CkaEncodingMethods" + case CkaMimeTypes: + return "CkaMimeTypes" + case CkaMechanismType: + return "CkaMechanismType" + case CkaRequiredCmsAttributes: + return "CkaRequiredCmsAttributes" + case CkaDefaultCmsAttributes: + return "CkaDefaultCmsAttributes" + case CkaSupportedCmsAttributes: + return "CkaSupportedCmsAttributes" + case CkaAllowedMechanisms: + return "CkaAllowedMechanisms" + default: + return "Unknown" + } +} diff --git a/certificates_test.go b/certificates_test.go index 49840de..cade98f 100644 --- a/certificates_test.go +++ b/certificates_test.go @@ -36,6 +36,8 @@ import ( ) func TestCertificate(t *testing.T) { + skipTest(t, skipTestCert) + ctx, err := ConfigureFromFile("config") require.NoError(t, err) @@ -70,6 +72,8 @@ func TestCertificate(t *testing.T) { // Test that provided attributes override default values func TestCertificateAttributes(t *testing.T) { + skipTest(t, skipTestCert) + ctx, err := ConfigureFromFile("config") require.NoError(t, err) @@ -103,6 +107,8 @@ func TestCertificateAttributes(t *testing.T) { } func TestCertificateRequiredArgs(t *testing.T) { + skipTest(t, skipTestCert) + ctx, err := ConfigureFromFile("config") require.NoError(t, err) diff --git a/close_test.go b/close_test.go index d3e7cfc..c1acfc9 100644 --- a/close_test.go +++ b/close_test.go @@ -24,59 +24,13 @@ package crypto11 import ( "crypto/dsa" "crypto/elliptic" - "fmt" - "math/rand" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestClose(t *testing.T) { - // Verify that close and re-open works. - - ctx, err := ConfigureFromFile("config") - require.NoError(t, err) - - const pSize = dsa.L1024N160 - id := randomBytes() - key, err := ctx.GenerateDSAKeyPair(id, dsaSizes[pSize]) - require.NoError(t, err) - require.NotNil(t, key) - - require.NoError(t, ctx.Close()) - - for i := 0; i < 5; i++ { - ctx, err := ConfigureFromFile("config") - require.NoError(t, err) - - key2, err := ctx.FindKeyPair(id, nil) - require.NoError(t, err) - - testDsaSigning(t, key2.(*pkcs11PrivateKeyDSA), pSize, fmt.Sprintf("close%d", i)) - - if i == 4 { - err = key2.Delete() - require.NoError(t, err) - } - - require.NoError(t, ctx.Close()) - } -} - -// randomBytes returns 32 random bytes. -func randomBytes() []byte { - result := make([]byte, 32) - rand.Read(result) - return result -} - -func init() { - rand.Seed(time.Now().UnixNano()) -} - func TestErrorAfterClosed(t *testing.T) { ctx, err := ConfigureFromFile("config") require.NoError(t, err) diff --git a/crypto11.go b/crypto11.go index d6f7b00..b37ebed 100644 --- a/crypto11.go +++ b/crypto11.go @@ -242,6 +242,12 @@ type Config struct { // LoginNotSupported should be set to true for tokens that do not support logging in. LoginNotSupported bool + + // UseGCMIVFromHSM should be set to true for tokens such as CloudHSM, which ignore the supplied IV for + // GCM mode and generate their own. In this case, the token will write the IV used into the CK_GCM_PARAMS. + // If UseGCMIVFromHSM is true, we will copy this IV and overwrite the 'nonce' slice passed to Seal and Open. It + // is therefore necessary that the nonce is the correct length (12 bytes for CloudHSM). + UseGCMIVFromHSM bool } // refCount counts the number of contexts using a particular P11 library. It must not be read or modified diff --git a/crypto11_test.go b/crypto11_test.go index 2c1297d..0a4cce4 100644 --- a/crypto11_test.go +++ b/crypto11_test.go @@ -22,7 +22,6 @@ package crypto11 import ( - "crypto/dsa" "encoding/json" "fmt" "log" @@ -39,37 +38,28 @@ import ( ) func TestKeysPersistAcrossContexts(t *testing.T) { - ctx, err := configureWithPin(t) + // Verify that close and re-open works. + ctx, err := ConfigureFromFile("config") require.NoError(t, err) - defer func() { - err = ctx.Close() - require.NoError(t, err) - }() - - // Generate a key and and close a session - const pSize = dsa.L1024N160 id := randomBytes() - key, err := ctx.GenerateDSAKeyPair(id, dsaSizes[pSize]) - require.NoError(t, err) - require.NotNil(t, key) + _, err = ctx.GenerateRSAKeyPair(id, rsaSize) + if err != nil { + _ = ctx.Close() + t.Fatal(err) + } - err = ctx.Close() - require.NoError(t, err) + require.NoError(t, ctx.Close()) - // Reopen a session and try to find a key. - // Valid session must enlist a key. - // If login is not performed than it will fail. - ctx, err = configureWithPin(t) + ctx, err = ConfigureFromFile("config") require.NoError(t, err) key2, err := ctx.FindKeyPair(id, nil) require.NoError(t, err) - testDsaSigning(t, key2.(*pkcs11PrivateKeyDSA), pSize, fmt.Sprintf("close%d", 0)) - - err = key2.Delete() - require.NoError(t, err) + testRsaSigning(t, key2, false) + _ = key2.Delete() + require.NoError(t, ctx.Close()) } func configureWithPin(t *testing.T) (*Context, error) { @@ -287,3 +277,14 @@ func TestNoLogin(t *testing.T) { assert.Equal(t, pkcs11.Error(pkcs11.CKR_USER_NOT_LOGGED_IN), p11Err) } + +// randomBytes returns 32 random bytes. +func randomBytes() []byte { + result := make([]byte, 32) + rand.Read(result) + return result +} + +func init() { + rand.Seed(time.Now().UnixNano()) +} diff --git a/dsa_test.go b/dsa_test.go index b155a01..ab0d73e 100644 --- a/dsa_test.go +++ b/dsa_test.go @@ -101,6 +101,8 @@ func TestNativeDSA(t *testing.T) { } func TestHardDSA(t *testing.T) { + skipTest(t, skipTestDSA) + ctx, err := ConfigureFromFile("config") require.NoError(t, err) @@ -115,9 +117,8 @@ func TestHardDSA(t *testing.T) { label := randomBytes() key, err := ctx.GenerateDSAKeyPairWithLabel(id, label, params) - require.NoError(t, err) - require.NotNil(t, key) - defer key.Delete() + require.NoError(t, err, "Failed for key size %s", parameterSizeToString(pSize)) + defer func(k Signer) { _ = k.Delete() }(key) testDsaSigning(t, key, pSize, "hard1") @@ -131,6 +132,21 @@ func TestHardDSA(t *testing.T) { } } +func parameterSizeToString(s dsa.ParameterSizes) string { + switch s { + case dsa.L1024N160: + return "L1024N160" + case dsa.L2048N224: + return "L2048N224" + case dsa.L2048N256: + return "L2048N256" + case dsa.L3072N256: + return "L3072N256" + default: + return "unknown" + } +} + func testDsaSigning(t *testing.T, key crypto.Signer, psize dsa.ParameterSizes, what string) { testDsaSigningWithHash(t, key, crypto.SHA1, psize, what) testDsaSigningWithHash(t, key, crypto.SHA224, psize, what) diff --git a/ecdsa_test.go b/ecdsa_test.go index dde920e..43407e3 100644 --- a/ecdsa_test.go +++ b/ecdsa_test.go @@ -31,6 +31,10 @@ import ( _ "crypto/sha512" "testing" + "github.com/miekg/pkcs11" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) @@ -50,11 +54,11 @@ func TestNativeECDSA(t *testing.T) { t.Errorf("crypto.ecdsa.GenerateKey: %v", err) return } - testEcdsaSigning(t, key, crypto.SHA1) - testEcdsaSigning(t, key, crypto.SHA224) - testEcdsaSigning(t, key, crypto.SHA256) - testEcdsaSigning(t, key, crypto.SHA384) - testEcdsaSigning(t, key, crypto.SHA512) + testEcdsaSigning(t, key, crypto.SHA1, curve.Params().Name, "SHA-1") + testEcdsaSigning(t, key, crypto.SHA224, curve.Params().Name, "SHA-224") + testEcdsaSigning(t, key, crypto.SHA256, curve.Params().Name, "SHA-256") + testEcdsaSigning(t, key, crypto.SHA384, curve.Params().Name, "SHA-384") + testEcdsaSigning(t, key, crypto.SHA512, curve.Params().Name, "SHA-512") } } @@ -74,34 +78,47 @@ func TestHardECDSA(t *testing.T) { key, err := ctx.GenerateECDSAKeyPairWithLabel(id, label, curve) require.NoError(t, err) require.NotNil(t, key) - defer key.Delete() + defer func(k Signer) { _ = k.Delete() }(key) - testEcdsaSigning(t, key, crypto.SHA1) - testEcdsaSigning(t, key, crypto.SHA224) - testEcdsaSigning(t, key, crypto.SHA256) - testEcdsaSigning(t, key, crypto.SHA384) - testEcdsaSigning(t, key, crypto.SHA512) + testEcdsaSigning(t, key, crypto.SHA1, curve.Params().Name, "SHA-1") + testEcdsaSigning(t, key, crypto.SHA224, curve.Params().Name, "SHA-224") + testEcdsaSigning(t, key, crypto.SHA256, curve.Params().Name, "SHA-256") + testEcdsaSigning(t, key, crypto.SHA384, curve.Params().Name, "SHA-384") + testEcdsaSigning(t, key, crypto.SHA512, curve.Params().Name, "SHA-512") key2, err := ctx.FindKeyPair(id, nil) require.NoError(t, err) - testEcdsaSigning(t, key2.(*pkcs11PrivateKeyECDSA), crypto.SHA256) + testEcdsaSigning(t, key2.(*pkcs11PrivateKeyECDSA), crypto.SHA256, curve.Params().Name, "SHA-256") key3, err := ctx.FindKeyPair(nil, label) require.NoError(t, err) - testEcdsaSigning(t, key3.(crypto.Signer), crypto.SHA384) + testEcdsaSigning(t, key3.(crypto.Signer), crypto.SHA384, curve.Params().Name, "SHA-384") } } -func testEcdsaSigning(t *testing.T, key crypto.Signer, hashFunction crypto.Hash) { +func testEcdsaSigning(t *testing.T, key crypto.Signer, hashFunction crypto.Hash, curveName, hashName string) { plaintext := []byte("sign me with ECDSA") h := hashFunction.New() _, err := h.Write(plaintext) require.NoError(t, err) - plaintextHash := h.Sum([]byte{}) // weird API + plaintextHash := h.Sum(nil) sigDER, err := key.Sign(rand.Reader, plaintextHash, nil) - require.NoError(t, err) + + p11Err, ok := err.(pkcs11.Error) + if ok && p11Err == pkcs11.CKR_KEY_SIZE_RANGE { + // Returned by CloudHSM (at least), for key sizes it doesn't support. + t.Logf("Skipping unsupported curve %s and hash %s", curveName, hashName) + return + } + + assert.NoErrorf(t, err, "Sign failed for curve %s and hash %s", curveName, hashName) + if err != nil { + // We assert and return, so that errors are more informative over a range of curves + // and hashes. + return + } var sig dsaSignature err = sig.unmarshalDER(sigDER) diff --git a/hmac_test.go b/hmac_test.go index 6762c92..403baee 100644 --- a/hmac_test.go +++ b/hmac_test.go @@ -57,6 +57,8 @@ func TestHmac(t *testing.T) { func testHmac(t *testing.T, ctx *Context, keytype int, mech int, length int, xlength int, full bool) { + skipIfMechUnsupported(t, ctx, uint(mech)) + id := randomBytes() key, err := ctx.GenerateSecretKey(id, 256, Ciphers[keytype]) require.NoError(t, err) diff --git a/keys.go b/keys.go index 88e3d8c..b0976a0 100644 --- a/keys.go +++ b/keys.go @@ -23,6 +23,7 @@ package crypto11 import ( "crypto" + "github.com/miekg/pkcs11" "github.com/pkg/errors" ) @@ -122,11 +123,39 @@ func (c *Context) makeKeyPair(session *pkcs11Session, privHandle *pkcs11.ObjectH return nil, errNoCkaId } + var pubHandle *pkcs11.ObjectHandle + // Find the public half which has a matching CKA_ID - pubHandle, err := findKey(session, id, label, uintPtr(pkcs11.CKO_PUBLIC_KEY), &keyType) + pubHandle, err = findKey(session, id, label, uintPtr(pkcs11.CKO_PUBLIC_KEY), &keyType) if err != nil { - return nil, err + p11Err, ok := err.(pkcs11.Error) + + if len(label) == 0 && ok && p11Err == pkcs11.CKR_TEMPLATE_INCONSISTENT { + // This probably means we are using a token that doesn't like us passing empty attributes in a template. + // For instance CloudHSM cannot search for a key with CKA_LABEL="". So if the private key doesn't have a + // label, we need to pass nil into findKeys, then match against the first key without a label. + + pubHandles, err := findKeys(session, id, nil, uintPtr(pkcs11.CKO_PUBLIC_KEY), &keyType) + if err != nil { + return nil, err + } + + for _, handle := range pubHandles { + template := []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_LABEL, nil)} + template, err = session.ctx.GetAttributeValue(session.handle, handle, template) + if err != nil { + return nil, err + } + if len(template[0].Value) == 0 { + pubHandle = &handle + break + } + } + } else { + return nil, err + } } + if pubHandle == nil { // We can't return a Signer if we don't have private and public key. Treat it as an error. return nil, errNoPublicHalf @@ -308,7 +337,11 @@ func (c *Context) FindKeyPairsWithAttributes(attributes AttributeSet) (signer [] return nil }) - return keys, err + if err != nil { + return nil, err + } + + return keys, nil } // FindAllKeyPairs retrieves all existing asymmetric key pairs, or a nil slice if none can be found. @@ -445,7 +478,10 @@ func (c *Context) FindKeysWithAttributes(attributes AttributeSet) ([]*SecretKey, return nil }) - return keys, err + if err != nil { + return nil, err + } + return keys, nil } // FindAllKeyPairs retrieves all existing symmetric keys, or a nil slice if none can be found. @@ -563,4 +599,4 @@ func (c *Context) GetPubAttribute(key interface{}, attribute AttributeType) (a * } return set[attribute], nil -} \ No newline at end of file +} diff --git a/keys_test.go b/keys_test.go index 3fb0798..ca8ab6a 100644 --- a/keys_test.go +++ b/keys_test.go @@ -5,6 +5,8 @@ import ( "crypto/rsa" "testing" + "github.com/miekg/pkcs11" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -39,31 +41,30 @@ func TestFindKeysRequiresIdOrLabel(t *testing.T) { func TestFindingKeysWithAttributes(t *testing.T) { withContext(t, func(ctx *Context) { - id := randomBytes() - id2 := randomBytes() - - key, err := ctx.GenerateSecretKey(id, 128, CipherAES) - require.NoError(t, err) - defer key.Delete() + label := randomBytes() + label2 := randomBytes() - key, err = ctx.GenerateSecretKey(id2, 128, CipherAES) + key, err := ctx.GenerateSecretKeyWithLabel(randomBytes(), label, 128, CipherAES) require.NoError(t, err) - defer key.Delete() + defer func(k *SecretKey) { _ = k.Delete() }(key) - key, err = ctx.GenerateSecretKey(id2, 256, CipherAES) + key, err = ctx.GenerateSecretKeyWithLabel(randomBytes(), label2, 128, CipherAES) require.NoError(t, err) - defer key.Delete() + defer func(k *SecretKey) { _ = k.Delete() }(key) - attrs, err := NewAttributeSetWithID(id) + key, err = ctx.GenerateSecretKeyWithLabel(randomBytes(), label2, 256, CipherAES) require.NoError(t, err) + defer func(k *SecretKey) { _ = k.Delete() }(key) + attrs := NewAttributeSet() + _ = attrs.Set(CkaLabel, label) keys, err := ctx.FindKeysWithAttributes(attrs) - require.Len(t, keys, 1) - - attrs, err = NewAttributeSetWithID(id2) require.NoError(t, err) + require.Len(t, keys, 1) + _ = attrs.Set(CkaLabel, label2) keys, err = ctx.FindKeysWithAttributes(attrs) + require.NoError(t, err) require.Len(t, keys, 2) attrs = NewAttributeSet() @@ -71,6 +72,7 @@ func TestFindingKeysWithAttributes(t *testing.T) { require.NoError(t, err) keys, err = ctx.FindKeysWithAttributes(attrs) + require.NoError(t, err) require.Len(t, keys, 2) attrs = NewAttributeSet() @@ -78,44 +80,47 @@ func TestFindingKeysWithAttributes(t *testing.T) { require.NoError(t, err) keys, err = ctx.FindKeysWithAttributes(attrs) + require.NoError(t, err) require.Len(t, keys, 1) }) } func TestFindingKeyPairsWithAttributes(t *testing.T) { withContext(t, func(ctx *Context) { - id := randomBytes() - id2 := randomBytes() - key, err := ctx.GenerateRSAKeyPair(id, 1024) - require.NoError(t, err) - defer key.Delete() + // Note: we use common labels, not IDs in this test code. AWS CloudHSM + // does not accept two keys with the same ID. + + label := randomBytes() + label2 := randomBytes() - key, err = ctx.GenerateRSAKeyPair(id2, 1024) + key, err := ctx.GenerateRSAKeyPairWithLabel(randomBytes(), label, rsaSize) require.NoError(t, err) - defer key.Delete() + defer func(k Signer) { _ = k.Delete() }(key) - key, err = ctx.GenerateRSAKeyPair(id2, 2048) + key, err = ctx.GenerateRSAKeyPairWithLabel(randomBytes(), label2, rsaSize) require.NoError(t, err) - defer key.Delete() + defer func(k Signer) { _ = k.Delete() }(key) - attrs, err := NewAttributeSetWithID(id) + key, err = ctx.GenerateRSAKeyPairWithLabel(randomBytes(), label2, rsaSize) require.NoError(t, err) + defer func(k Signer) { _ = k.Delete() }(key) + attrs := NewAttributeSet() + _ = attrs.Set(CkaLabel, label) keys, err := ctx.FindKeyPairsWithAttributes(attrs) - require.Len(t, keys, 1) - - attrs, err = NewAttributeSetWithID(id2) require.NoError(t, err) + require.Len(t, keys, 1) + _ = attrs.Set(CkaLabel, label2) keys, err = ctx.FindKeyPairsWithAttributes(attrs) + require.NoError(t, err) require.Len(t, keys, 2) attrs = NewAttributeSet() - err = attrs.Set(CkaPublicExponent, []byte{1, 0, 1}) - require.NoError(t, err) - + _ = attrs.Set(CkaKeyType, pkcs11.CKK_RSA) keys, err = ctx.FindKeyPairsWithAttributes(attrs) + require.NoError(t, err) require.Len(t, keys, 3) }) } @@ -127,7 +132,7 @@ func TestFindingAllKeys(t *testing.T) { key, err := ctx.GenerateSecretKey(id, 128, CipherAES) require.NoError(t, err) - defer key.Delete() + defer func(k *SecretKey) { _ = k.Delete() }(key) } keys, err := ctx.FindAllKeys() @@ -142,10 +147,10 @@ func TestFindingAllKeyPairs(t *testing.T) { withContext(t, func(ctx *Context) { for i := 1; i <= 5; i++ { id := randomBytes() - key, err := ctx.GenerateRSAKeyPair(id, 1024) + key, err := ctx.GenerateRSAKeyPair(id, rsaSize) require.NoError(t, err) - defer key.Delete() + defer func(k Signer) { _ = k.Delete() }(key) } keys, err := ctx.FindAllKeyPairs() @@ -160,16 +165,16 @@ func TestGettingPrivateKeyAttributes(t *testing.T) { withContext(t, func(ctx *Context) { id := randomBytes() - key, err := ctx.GenerateRSAKeyPair(id, 1024) + key, err := ctx.GenerateRSAKeyPair(id, rsaSize) require.NoError(t, err) - defer key.Delete() + defer func(k Signer) { _ = k.Delete() }(key) attrs, err := ctx.GetAttributes(key, []AttributeType{CkaModulus}) require.NoError(t, err) require.NotNil(t, attrs) require.Len(t, attrs, 1) - require.Len(t, attrs[CkaModulus].Value, 128) + require.Len(t, attrs[CkaModulus].Value, 256) }) } @@ -177,16 +182,16 @@ func TestGettingPublicKeyAttributes(t *testing.T) { withContext(t, func(ctx *Context) { id := randomBytes() - key, err := ctx.GenerateRSAKeyPair(id, 1024) + key, err := ctx.GenerateRSAKeyPair(id, rsaSize) require.NoError(t, err) - defer key.Delete() + defer func(k Signer) { _ = k.Delete() }(key) attrs, err := ctx.GetPubAttributes(key, []AttributeType{CkaModulusBits}) require.NoError(t, err) require.NotNil(t, attrs) require.Len(t, attrs, 1) - require.Equal(t, uint(1024), bytesToUlong(attrs[CkaModulusBits].Value)) + require.Equal(t, uint(rsaSize), bytesToUlong(attrs[CkaModulusBits].Value)) }) } @@ -196,7 +201,7 @@ func TestGettingSecretKeyAttributes(t *testing.T) { key, err := ctx.GenerateSecretKey(id, 128, CipherAES) require.NoError(t, err) - defer key.Delete() + defer func(k *SecretKey) { _ = k.Delete() }(key) attrs, err := ctx.GetAttributes(key, []AttributeType{CkaValueLen}) require.NoError(t, err) @@ -209,7 +214,7 @@ func TestGettingSecretKeyAttributes(t *testing.T) { func TestGettingUnsupportedKeyTypeAttributes(t *testing.T) { withContext(t, func(ctx *Context) { - key, err := rsa.GenerateKey(rand.Reader, 1024) + key, err := rsa.GenerateKey(rand.Reader, rsaSize) require.NoError(t, err) _, err = ctx.GetAttributes(key, []AttributeType{CkaModulusBits}) diff --git a/rand_test.go b/rand_test.go index f0fea63..45124f0 100644 --- a/rand_test.go +++ b/rand_test.go @@ -39,8 +39,8 @@ func TestRandomReader(t *testing.T) { reader, err := ctx.NewRandomReader() require.NoError(t, err) - var a [32768]byte - for _, size := range []int{1, 16, 32, 256, 347, 4096, 32768} { + var a [8192]byte + for _, size := range []int{1, 16, 32, 256, 347, 4096, 8192} { n, err := reader.Read(a[:size]) require.NoError(t, err) require.Equal(t, size, n) diff --git a/rsa.go b/rsa.go index c6fb989..234fd0a 100644 --- a/rsa.go +++ b/rsa.go @@ -27,7 +27,6 @@ import ( "errors" "io" "math/big" - "unsafe" "github.com/miekg/pkcs11" ) @@ -210,29 +209,25 @@ func decryptPKCS1v15(session *pkcs11Session, key *pkcs11PrivateKeyRSA, ciphertex return session.ctx.Decrypt(session.handle, ciphertext) } -func decryptOAEP(session *pkcs11Session, key *pkcs11PrivateKeyRSA, ciphertext []byte, hashFunction crypto.Hash, label []byte) ([]byte, error) { - var err error - var hMech, mgf, sourceData, sourceDataLen uint - if hMech, mgf, _, err = hashToPKCS11(hashFunction); err != nil { +func decryptOAEP(session *pkcs11Session, key *pkcs11PrivateKeyRSA, ciphertext []byte, hashFunction crypto.Hash, + label []byte) ([]byte, error) { + + hashAlg, mgfAlg, _, err := hashToPKCS11(hashFunction) + if err != nil { return nil, err } - if len(label) > 0 { - sourceData = uint(uintptr(unsafe.Pointer(&label[0]))) - sourceDataLen = uint(len(label)) - } - parameters := concat(ulongToBytes(hMech), - ulongToBytes(mgf), - ulongToBytes(pkcs11.CKZ_DATA_SPECIFIED), - ulongToBytes(sourceData), - ulongToBytes(sourceDataLen)) - mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_RSA_PKCS_OAEP, parameters)} - if err = session.ctx.DecryptInit(session.handle, mech, key.handle); err != nil { + + mech := pkcs11.NewMechanism(pkcs11.CKM_RSA_PKCS_OAEP, + pkcs11.NewOAEPParams(hashAlg, mgfAlg, pkcs11.CKZ_DATA_SPECIFIED, label)) + + err = session.ctx.DecryptInit(session.handle, []*pkcs11.Mechanism{mech}, key.handle) + if err != nil { return nil, err } return session.ctx.Decrypt(session.handle, ciphertext) } -func hashToPKCS11(hashFunction crypto.Hash) (uint, uint, uint, error) { +func hashToPKCS11(hashFunction crypto.Hash) (hashAlg uint, mgfAlg uint, hashLen uint, err error) { switch hashFunction { case crypto.SHA1: return pkcs11.CKM_SHA_1, pkcs11.CKG_MGF1_SHA1, 20, nil diff --git a/rsa_test.go b/rsa_test.go index 9934c1e..c5ac06f 100644 --- a/rsa_test.go +++ b/rsa_test.go @@ -29,14 +29,14 @@ import ( _ "crypto/sha1" _ "crypto/sha256" _ "crypto/sha512" - "fmt" "testing" "github.com/miekg/pkcs11" "github.com/stretchr/testify/require" ) -var rsaSizes = []int{1024, 2048} +// Set to 2048, as most tokens will support this. 1024 not supported by some tokens (e.g. Amazon CloudHSM). +const rsaSize = 2048 func TestNativeRSA(t *testing.T) { @@ -47,68 +47,58 @@ func TestNativeRSA(t *testing.T) { require.NoError(t, ctx.Close()) }() - for _, nbits := range rsaSizes { - t.Run(fmt.Sprintf("%v", nbits), func(t *testing.T) { - key, err := rsa.GenerateKey(rand.Reader, nbits) - require.NoError(t, err) + key, err := rsa.GenerateKey(rand.Reader, rsaSize) + require.NoError(t, err) - err = key.Validate() - require.NoError(t, err) + err = key.Validate() + require.NoError(t, err) - t.Run("Sign", func(t *testing.T) { testRsaSigning(t, key, nbits, true) }) - t.Run("Encrypt", func(t *testing.T) { testRsaEncryption(t, key, nbits, true) }) - }) - } + t.Run("Sign", func(t *testing.T) { testRsaSigning(t, key, true) }) + t.Run("Encrypt", func(t *testing.T) { testRsaEncryption(t, key, true) }) } func TestHardRSA(t *testing.T) { ctx, err := ConfigureFromFile("config") require.NoError(t, err) - defer func() { require.NoError(t, ctx.Close()) }() - for _, nbits := range rsaSizes { - id := randomBytes() - label := randomBytes() - - t.Run(fmt.Sprintf("%v", nbits), func(t *testing.T) { - - key, err := ctx.GenerateRSAKeyPairWithLabel(id, label, nbits) - require.NoError(t, err) - require.NotNil(t, key) - defer key.Delete() - - var key2, key3 crypto.PrivateKey - - t.Run("Sign", func(t *testing.T) { testRsaSigning(t, key, nbits, false) }) - t.Run("Encrypt", func(t *testing.T) { testRsaEncryption(t, key, nbits, false) }) - t.Run("FindId", func(t *testing.T) { - key2, err = ctx.FindKeyPair(id, nil) - require.NoError(t, err) - }) - t.Run("SignId", func(t *testing.T) { - if key2 == nil { - t.SkipNow() - } - testRsaSigning(t, key2.(*pkcs11PrivateKeyRSA), nbits, false) - }) - t.Run("FindLabel", func(t *testing.T) { - key3, err = ctx.FindKeyPair(nil, label) - require.NoError(t, err) - }) - t.Run("SignLabel", func(t *testing.T) { - if key3 == nil { - t.SkipNow() - } - testRsaSigning(t, key3.(crypto.Signer), nbits, false) - }) - }) - } + id := randomBytes() + label := randomBytes() + + key, err := ctx.GenerateRSAKeyPairWithLabel(id, label, rsaSize) + require.NoError(t, err) + require.NotNil(t, key) + defer func() { _ = key.Delete() }() + + var key2, key3 crypto.PrivateKey + + t.Run("Sign", func(t *testing.T) { testRsaSigning(t, key, false) }) + t.Run("Encrypt", func(t *testing.T) { testRsaEncryption(t, key, false) }) + t.Run("FindId", func(t *testing.T) { + key2, err = ctx.FindKeyPair(id, nil) + require.NoError(t, err) + }) + t.Run("SignId", func(t *testing.T) { + if key2 == nil { + t.SkipNow() + } + testRsaSigning(t, key2.(*pkcs11PrivateKeyRSA), false) + }) + t.Run("FindLabel", func(t *testing.T) { + key3, err = ctx.FindKeyPair(nil, label) + require.NoError(t, err) + }) + t.Run("SignLabel", func(t *testing.T) { + if key3 == nil { + t.SkipNow() + } + testRsaSigning(t, key3.(crypto.Signer), false) + }) } -func testRsaSigning(t *testing.T, key crypto.Signer, nbits int, native bool) { +func testRsaSigning(t *testing.T, key crypto.Signer, native bool) { t.Run("SHA1", func(t *testing.T) { testRsaSigningPKCS1v15(t, key, crypto.SHA1) }) t.Run("SHA224", func(t *testing.T) { testRsaSigningPKCS1v15(t, key, crypto.SHA224) }) t.Run("SHA256", func(t *testing.T) { testRsaSigningPKCS1v15(t, key, crypto.SHA256) }) @@ -118,13 +108,7 @@ func testRsaSigning(t *testing.T, key crypto.Signer, nbits int, native bool) { t.Run("PSSSHA224", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA224, native) }) t.Run("PSSSHA256", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA256, native) }) t.Run("PSSSHA384", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA384, native) }) - t.Run("PSSSHA512", func(t *testing.T) { - if nbits > 1024 { - testRsaSigningPSS(t, key, crypto.SHA512, native) - } else { - t.Skipf("key too smol for SHA512 with sLen=hLen") - } - }) + t.Run("PSSSHA512", func(t *testing.T) { testRsaSigningPSS(t, key, crypto.SHA512, native) }) } func testRsaSigningPKCS1v15(t *testing.T, key crypto.Signer, hashFunction crypto.Hash) { @@ -167,32 +151,23 @@ func testRsaSigningPSS(t *testing.T, key crypto.Signer, hashFunction crypto.Hash require.NoError(t, err) } -func testRsaEncryption(t *testing.T, key crypto.Decrypter, nbits int, native bool) { +func testRsaEncryption(t *testing.T, key crypto.Decrypter, native bool) { t.Run("PKCS1v15", func(t *testing.T) { testRsaEncryptionPKCS1v15(t, key) }) t.Run("OAEPSHA1", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA1, []byte{}, native) }) t.Run("OAEPSHA224", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA224, []byte{}, native) }) t.Run("OAEPSHA256", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA256, []byte{}, native) }) t.Run("OAEPSHA384", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA384, []byte{}, native) }) - t.Run("OAEPSHA512", func(t *testing.T) { - if nbits > 1024 { - testRsaEncryptionOAEP(t, key, crypto.SHA512, []byte{}, native) - } else { - t.Skipf("key too small for SHA512") - } - }) - t.Run("OAEPSHA1Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA1, []byte{1, 2, 3, 4}, native) }) - t.Run("OAEPSHA224Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA224, []byte{5, 6, 7, 8}, native) }) - t.Run("OAEPSHA256Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA256, []byte{9}, native) }) - t.Run("OAEPSHA384Label", func(t *testing.T) { - testRsaEncryptionOAEP(t, key, crypto.SHA384, []byte{10, 11, 12, 13, 14, 15}, native) - }) - t.Run("OAEPSHA512Label", func(t *testing.T) { - if nbits > 1024 { - testRsaEncryptionOAEP(t, key, crypto.SHA512, []byte{16, 17, 18}, native) - } else { - t.Skipf("key too small for SHA512") - } - }) + t.Run("OAEPSHA512", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA512, []byte{}, native) }) + + if !shouldSkipTest(skipTestOAEPLabel) { + t.Run("OAEPSHA1Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA1, []byte{1, 2, 3, 4}, native) }) + t.Run("OAEPSHA224Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA224, []byte{5, 6, 7, 8}, native) }) + t.Run("OAEPSHA256Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA256, []byte{9}, native) }) + t.Run("OAEPSHA384Label", func(t *testing.T) { + testRsaEncryptionOAEP(t, key, crypto.SHA384, []byte{10, 11, 12, 13, 14, 15}, native) + }) + t.Run("OAEPSHA512Label", func(t *testing.T) { testRsaEncryptionOAEP(t, key, crypto.SHA512, []byte{16, 17, 18}, native) }) + } } func testRsaEncryptionPKCS1v15(t *testing.T, key crypto.Decrypter) { @@ -265,7 +240,7 @@ func skipIfMechUnsupported(t *testing.T, ctx *Context, wantMech uint) { return } } - t.Skipf("mechanism %v not supported", wantMech) + t.Skipf("mechanism 0x%x not supported", wantMech) } func TestRsaRequiredArgs(t *testing.T) { diff --git a/skip_test.go b/skip_test.go new file mode 100644 index 0000000..40a0f5f --- /dev/null +++ b/skip_test.go @@ -0,0 +1,31 @@ +package crypto11 + +import ( + "os" + "strings" + "testing" +) + +const skipTestEnv = "CRYPTO11_SKIP" +const skipTestCert = "CERTS" +const skipTestOAEPLabel = "OAEP_LABEL" +const skipTestDSA = "DSA" + +// skipTest tests whether the CRYPTO11_SKIP environment variable contains +// flagName. If so, it skips the test. +func skipTest(t *testing.T, flagName string) { + if shouldSkipTest(flagName) { + t.Logf("Skipping test due to %s flag", flagName) + t.SkipNow() + } +} + +func shouldSkipTest(flagName string) bool { + thingsToSkip := strings.Split(os.Getenv(skipTestEnv), ",") + for _, s := range thingsToSkip { + if s == flagName { + return true + } + } + return false +} diff --git a/symmetric.go b/symmetric.go index bb61a08..4a01248 100644 --- a/symmetric.go +++ b/symmetric.go @@ -294,22 +294,24 @@ func (c *Context) GenerateSecretKeyWithAttributes(template AttributeSet, bits in // CKK_*_HMAC exists but there is no specific corresponding CKM_*_KEY_GEN // mechanism. Therefore we attempt both CKM_GENERIC_SECRET_KEY_GEN and // vendor-specific mechanisms. - for _, genMech := range cipher.GenParams { - template.AddIfNotPresent([]*pkcs11.Attribute{ - pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_SECRET_KEY), - pkcs11.NewAttribute(pkcs11.CKA_KEY_TYPE, genMech.KeyType), - pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), - pkcs11.NewAttribute(pkcs11.CKA_SIGN, cipher.MAC), - pkcs11.NewAttribute(pkcs11.CKA_VERIFY, cipher.MAC), - pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, cipher.Encrypt), - pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, cipher.Encrypt), - pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, true), - pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, false), - }) - - if bits > 0 { - template.Set(pkcs11.CKA_VALUE_LEN, bits/8) - } + + template.AddIfNotPresent([]*pkcs11.Attribute{ + pkcs11.NewAttribute(pkcs11.CKA_CLASS, pkcs11.CKO_SECRET_KEY), + pkcs11.NewAttribute(pkcs11.CKA_TOKEN, true), + pkcs11.NewAttribute(pkcs11.CKA_SIGN, cipher.MAC), + pkcs11.NewAttribute(pkcs11.CKA_VERIFY, cipher.MAC), + pkcs11.NewAttribute(pkcs11.CKA_ENCRYPT, cipher.Encrypt), // Not supported on CloudHSM + pkcs11.NewAttribute(pkcs11.CKA_DECRYPT, cipher.Encrypt), // Not supported on CloudHSM + pkcs11.NewAttribute(pkcs11.CKA_SENSITIVE, true), + pkcs11.NewAttribute(pkcs11.CKA_EXTRACTABLE, false), + }) + if bits > 0 { + _ = template.Set(pkcs11.CKA_VALUE_LEN, bits/8) // safe for an int + } + + for n, genMech := range cipher.GenParams { + + _ = template.Set(CkaKeyType, genMech.KeyType) mech := []*pkcs11.Mechanism{pkcs11.NewMechanism(genMech.GenMech, nil)} @@ -319,8 +321,33 @@ func (c *Context) GenerateSecretKeyWithAttributes(template AttributeSet, bits in return nil } - // nShield returns this if if doesn't like the CKK/CKM combination. - if e, ok := err.(pkcs11.Error); ok && e == pkcs11.CKR_TEMPLATE_INCONSISTENT { + // As a special case, AWS CloudHSM does not accept CKA_ENCRYPT and CKA_DECRYPT on a + // Generic Secret key. If we are in that special case, try again without those attributes. + if e, ok := err.(pkcs11.Error); ok && e == pkcs11.CKR_ARGUMENTS_BAD && genMech.GenMech == pkcs11.CKM_GENERIC_SECRET_KEY_GEN { + adjustedTemplate := template.Copy() + adjustedTemplate.Unset(CkaEncrypt) + adjustedTemplate.Unset(CkaDecrypt) + + privHandle, err = session.ctx.GenerateKey(session.handle, mech, adjustedTemplate.ToSlice()) + if err == nil { + // Store the actual attributes + template.cloneFrom(adjustedTemplate) + + k = &SecretKey{pkcs11Object{privHandle, c}, cipher} + return nil + } + } + + if n == len(cipher.GenParams)-1 { + // If we have tried all available gen params, we should return a sensible error. So we skip the + // retry logic below and return directly. + return err + } + + // nShield returns CKR_TEMPLATE_INCONSISTENT if if doesn't like the CKK/CKM combination. + // AWS CloudHSM returns CKR_ATTRIBUTE_VALUE_INVALID in the same circumstances. + if e, ok := err.(pkcs11.Error); ok && + e == pkcs11.CKR_TEMPLATE_INCONSISTENT || e == pkcs11.CKR_ATTRIBUTE_VALUE_INVALID { continue } diff --git a/symmetric_test.go b/symmetric_test.go index bbe9f37..265e49f 100644 --- a/symmetric_test.go +++ b/symmetric_test.go @@ -46,6 +46,9 @@ func TestHardSymmetric(t *testing.T) { } func testHardSymmetric(t *testing.T, ctx *Context, keytype int, bits int) { + for _, p := range Ciphers[keytype].GenParams { + skipIfMechUnsupported(t, ctx, p.GenMech) + } id := randomBytes() key, err := ctx.GenerateSecretKey(id, bits, Ciphers[keytype]) @@ -59,7 +62,10 @@ func testHardSymmetric(t *testing.T, ctx *Context, keytype int, bits int) { require.NoError(t, err) }) - t.Run("Block", func(t *testing.T) { testSymmetricBlock(t, key, key2) }) + t.Run("Block", func(t *testing.T) { + skipIfMechUnsupported(t, key.context, key.Cipher.ECBMech) + testSymmetricBlock(t, key, key2) + }) iv := make([]byte, key.BlockSize()) for i := range iv { @@ -67,11 +73,13 @@ func testHardSymmetric(t *testing.T, ctx *Context, keytype int, bits int) { } t.Run("CBC", func(t *testing.T) { + // By using cipher.NewCBCEncrypter, this test will actually use ECB mode on the key. + skipIfMechUnsupported(t, key2.context, key2.Cipher.ECBMech) testSymmetricMode(t, cipher.NewCBCEncrypter(key2, iv), cipher.NewCBCDecrypter(key2, iv)) }) t.Run("CBCClose", func(t *testing.T) { - + skipIfMechUnsupported(t, key2.context, key2.Cipher.CBCMech) enc, err := key2.NewCBCEncrypterCloser(iv) require.NoError(t, err) @@ -84,6 +92,7 @@ func testHardSymmetric(t *testing.T, ctx *Context, keytype int, bits int) { }) t.Run("CBCNoClose", func(t *testing.T) { + skipIfMechUnsupported(t, key2.context, key2.Cipher.CBCMech) enc, err := key2.NewCBCEncrypter(iv) require.NoError(t, err) @@ -127,6 +136,14 @@ func testHardSymmetric(t *testing.T, ctx *Context, keytype int, bits int) { } func testSymmetricBlock(t *testing.T, encryptKey cipher.Block, decryptKey cipher.Block) { + // The functions in cipher.Block have no error returns, so they panic if they encounter + // a problem. We catch these panics here, so the test can fail nicely + defer func() { + if cause := recover(); cause != nil { + t.Fatalf("Caught panic: %q", cause) + } + }() + b := encryptKey.BlockSize() input := make([]byte, 3*b) middle := make([]byte, 3*b) @@ -176,6 +193,14 @@ func testSymmetricBlock(t *testing.T, encryptKey cipher.Block, decryptKey cipher } func testSymmetricMode(t *testing.T, encrypt cipher.BlockMode, decrypt cipher.BlockMode) { + // The functions in cipher.Block have no error returns, so they panic if they encounter + // a problem. We catch these panics here, so the test can fail nicely + defer func() { + if cause := recover(); cause != nil { + t.Fatalf("Caught panic: %q", cause) + } + }() + input := make([]byte, 256) middle := make([]byte, 256) output := make([]byte, 256) diff --git a/thread_test.go b/thread_test.go index 7398644..77f5d4e 100644 --- a/thread_test.go +++ b/thread_test.go @@ -33,6 +33,9 @@ var threadCount = 32 var signaturesPerThread = 256 func TestThreadedRSA(t *testing.T) { + if testing.Short() { + t.Skip() + } ctx, err := ConfigureFromFile("config") require.NoError(t, err) @@ -42,9 +45,9 @@ func TestThreadedRSA(t *testing.T) { }() id := randomBytes() - key, err := ctx.GenerateRSAKeyPair(id, 1024) + key, err := ctx.GenerateRSAKeyPair(id, rsaSize) require.NoError(t, err) - defer key.Delete() + defer func(k Signer) { _ = k.Delete() }(key) done := make(chan int) started := time.Now() @@ -53,6 +56,9 @@ func TestThreadedRSA(t *testing.T) { for i := 0; i < threadCount; i++ { go signingRoutine(t, key, done) + + // CloudHSM falls over if you create sessions too quickly + time.Sleep(50 * time.Millisecond) } t.Logf("Waiting for %v threads", threadCount) for i := 0; i < threadCount; i++ { @@ -69,7 +75,9 @@ func TestThreadedRSA(t *testing.T) { func signingRoutine(t *testing.T, key crypto.Signer, done chan int) { for i := 0; i < signaturesPerThread; i++ { testRsaSigningPKCS1v15(t, key, crypto.SHA1) + + // CloudHSM falls over if you create sessions too quickly + time.Sleep(50 * time.Millisecond) } done <- 1 - }