Skip to content

Commit

Permalink
more strict key validation (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg authored Mar 3, 2021
1 parent 47423dc commit 7521642
Show file tree
Hide file tree
Showing 12 changed files with 226 additions and 42 deletions.
20 changes: 20 additions & 0 deletions algo.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,26 @@ type Algorithm string

func (a Algorithm) String() string { return string(a) }

// keySize of the algorithm's key (if exist). Is similar to Signer.SignSize.
func (a Algorithm) keySize() int { return algsKeySize[a] }

var algsKeySize = map[Algorithm]int{
// for EdDSA private and public key have different sizes, so 0
// for HS there is no limits for key size, so 0

RS256: 256,
RS384: 384,
RS512: 512,

ES256: 64,
ES384: 96,
ES512: 132,

PS256: 256,
PS384: 384,
PS512: 512,
}

// Algorithm names for signing and verifying.
const (
EdDSA Algorithm = "EdDSA"
Expand Down
10 changes: 8 additions & 2 deletions algo_eddsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ import (

// NewSignerEdDSA returns a new ed25519-based signer.
func NewSignerEdDSA(key ed25519.PrivateKey) (Signer, error) {
if len(key) == 0 || len(key) != ed25519.PrivateKeySize {
if len(key) == 0 {
return nil, ErrNilKey
}
if len(key) != ed25519.PrivateKeySize {
return nil, ErrInvalidKey
}
return &edDSAAlg{
Expand All @@ -17,7 +20,10 @@ func NewSignerEdDSA(key ed25519.PrivateKey) (Signer, error) {

// NewVerifierEdDSA returns a new ed25519-based verifier.
func NewVerifierEdDSA(key ed25519.PublicKey) (Verifier, error) {
if len(key) == 0 || len(key) != ed25519.PublicKeySize {
if len(key) == 0 {
return nil, ErrNilKey
}
if len(key) != ed25519.PublicKeySize {
return nil, ErrInvalidKey
}
return &edDSAAlg{
Expand Down
17 changes: 17 additions & 0 deletions algo_eddsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package jwt
import (
"crypto/ed25519"
"crypto/rand"
"errors"
"testing"
)

Expand Down Expand Up @@ -49,6 +50,22 @@ func TestEdDSA(t *testing.T) {
f(ed25519OtherPrivateKey, ed25519PublicKey, false)
}

func TestEdDSA_BadKeys(t *testing.T) {
f := func(err, wantErr error) {
if !errors.Is(err, wantErr) {
t.Fatalf("expected %v, got %v", wantErr, err)
}
}

f(getSignerError(NewSignerEdDSA(nil)), ErrNilKey)
priv := ed25519.PrivateKey(make([]byte, 72))
f(getSignerError(NewSignerEdDSA(priv)), ErrInvalidKey)

f(getVerifierError(NewVerifierEdDSA(nil)), ErrNilKey)
pub := ed25519.PublicKey(make([]byte, 72))
f(getVerifierError(NewVerifierEdDSA(pub)), ErrInvalidKey)
}

func ed25519Sign(t *testing.T, privateKey ed25519.PrivateKey, payload string) []byte {
t.Helper()

Expand Down
32 changes: 19 additions & 13 deletions algo_es.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ import (
// NewSignerES returns a new ECDSA-based signer.
func NewSignerES(alg Algorithm, key *ecdsa.PrivateKey) (Signer, error) {
if key == nil {
return nil, ErrInvalidKey
return nil, ErrNilKey
}
hash, ok := getParamsES(alg)
if !ok {
return nil, ErrUnsupportedAlg
hash, err := getParamsES(alg, roundBytes(key.PublicKey.Params().BitSize)*2)
if err != nil {
return nil, err
}
return &esAlg{
alg: alg,
Expand All @@ -27,11 +27,11 @@ func NewSignerES(alg Algorithm, key *ecdsa.PrivateKey) (Signer, error) {
// NewVerifierES returns a new ECDSA-based verifier.
func NewVerifierES(alg Algorithm, key *ecdsa.PublicKey) (Verifier, error) {
if key == nil {
return nil, ErrInvalidKey
return nil, ErrNilKey
}
hash, ok := getParamsES(alg)
if !ok {
return nil, ErrUnsupportedAlg
hash, err := getParamsES(alg, roundBytes(key.Params().BitSize)*2)
if err != nil {
return nil, err
}
return &esAlg{
alg: alg,
Expand All @@ -41,17 +41,23 @@ func NewVerifierES(alg Algorithm, key *ecdsa.PublicKey) (Verifier, error) {
}, nil
}

func getParamsES(alg Algorithm) (crypto.Hash, bool) {
func getParamsES(alg Algorithm, size int) (crypto.Hash, error) {
var hash crypto.Hash
switch alg {
case ES256:
return crypto.SHA256, true
hash = crypto.SHA256
case ES384:
return crypto.SHA384, true
hash = crypto.SHA384
case ES512:
return crypto.SHA512, true
hash = crypto.SHA512
default:
return 0, false
return 0, ErrUnsupportedAlg
}

if alg.keySize() != size {
return 0, ErrInvalidKey
}
return hash, nil
}

type esAlg struct {
Expand Down
37 changes: 37 additions & 0 deletions algo_es_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"errors"
"testing"
)

Expand Down Expand Up @@ -63,6 +64,42 @@ func TestES(t *testing.T) {
f(ES512, ecdsaOtherPrivateKey521, ecdsaPublicKey521, false)
}

func TestES_BadKeys(t *testing.T) {
f := func(err, wantErr error) {
t.Helper()

if !errors.Is(err, wantErr) {
t.Fatalf("expected %v, got %v", wantErr, err)
}
}

f(getSignerError(NewSignerES(ES256, nil)), ErrNilKey)
f(getSignerError(NewSignerES(ES384, nil)), ErrNilKey)
f(getSignerError(NewSignerES(ES512, nil)), ErrNilKey)

f(getSignerError(NewSignerES("foo", ecdsaPrivateKey384)), ErrUnsupportedAlg)

f(getSignerError(NewSignerES(ES256, ecdsaPrivateKey384)), ErrInvalidKey)
f(getSignerError(NewSignerES(ES256, ecdsaPrivateKey521)), ErrInvalidKey)
f(getSignerError(NewSignerES(ES384, ecdsaPrivateKey256)), ErrInvalidKey)
f(getSignerError(NewSignerES(ES384, ecdsaPrivateKey521)), ErrInvalidKey)
f(getSignerError(NewSignerES(ES512, ecdsaPrivateKey256)), ErrInvalidKey)
f(getSignerError(NewSignerES(ES512, ecdsaPrivateKey384)), ErrInvalidKey)

f(getVerifierError(NewVerifierES(ES256, nil)), ErrNilKey)
f(getVerifierError(NewVerifierES(ES384, nil)), ErrNilKey)
f(getVerifierError(NewVerifierES(ES512, nil)), ErrNilKey)

f(getVerifierError(NewVerifierES("boo", ecdsaPublicKey384)), ErrUnsupportedAlg)

f(getVerifierError(NewVerifierES(ES256, ecdsaPublicKey384)), ErrInvalidKey)
f(getVerifierError(NewVerifierES(ES256, ecdsaPublicKey521)), ErrInvalidKey)
f(getVerifierError(NewVerifierES(ES384, ecdsaPublicKey256)), ErrInvalidKey)
f(getVerifierError(NewVerifierES(ES384, ecdsaPublicKey521)), ErrInvalidKey)
f(getVerifierError(NewVerifierES(ES512, ecdsaPublicKey256)), ErrInvalidKey)
f(getVerifierError(NewVerifierES(ES512, ecdsaPublicKey384)), ErrInvalidKey)
}

func esSign(t *testing.T, alg Algorithm, privateKey *ecdsa.PrivateKey, payload string) []byte {
t.Helper()

Expand Down
2 changes: 1 addition & 1 deletion algo_hs.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type hmacAlgo interface {

func newHS(alg Algorithm, key []byte) (hmacAlgo, error) {
if len(key) == 0 {
return nil, ErrInvalidKey
return nil, ErrNilKey
}
hash, ok := getHashHMAC(alg)
if !ok {
Expand Down
33 changes: 20 additions & 13 deletions algo_ps.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (
// NewSignerPS returns a new RSA-PSS-based signer.
func NewSignerPS(alg Algorithm, key *rsa.PrivateKey) (Signer, error) {
if key == nil {
return nil, ErrInvalidKey
return nil, ErrNilKey
}
hash, opts, ok := getParamsPS(alg)
if !ok {
return nil, ErrUnsupportedAlg
hash, opts, err := getParamsPS(alg, key.Size())
if err != nil {
return nil, err
}
return &psAlg{
alg: alg,
Expand All @@ -26,11 +26,11 @@ func NewSignerPS(alg Algorithm, key *rsa.PrivateKey) (Signer, error) {
// NewVerifierPS returns a new RSA-PSS-based signer.
func NewVerifierPS(alg Algorithm, key *rsa.PublicKey) (Verifier, error) {
if key == nil {
return nil, ErrInvalidKey
return nil, ErrNilKey
}
hash, opts, ok := getParamsPS(alg)
if !ok {
return nil, ErrUnsupportedAlg
hash, opts, err := getParamsPS(alg, key.Size())
if err != nil {
return nil, err
}
return &psAlg{
alg: alg,
Expand All @@ -40,17 +40,24 @@ func NewVerifierPS(alg Algorithm, key *rsa.PublicKey) (Verifier, error) {
}, nil
}

func getParamsPS(alg Algorithm) (crypto.Hash, *rsa.PSSOptions, bool) {
func getParamsPS(alg Algorithm, size int) (crypto.Hash, *rsa.PSSOptions, error) {
var hash crypto.Hash
var opts *rsa.PSSOptions
switch alg {
case PS256:
return crypto.SHA256, optsPS256, true
hash, opts = crypto.SHA256, optsPS256
case PS384:
return crypto.SHA384, optsPS384, true
hash, opts = crypto.SHA384, optsPS384
case PS512:
return crypto.SHA512, optsPS512, true
hash, opts = crypto.SHA512, optsPS512
default:
return 0, nil, false
return 0, nil, ErrUnsupportedAlg
}

if alg.keySize() != size {
return 0, nil, ErrInvalidKey
}
return hash, opts, nil
}

var (
Expand Down
37 changes: 37 additions & 0 deletions algo_ps_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package jwt

import (
"crypto/rsa"
"errors"
"testing"
)

Expand Down Expand Up @@ -35,6 +36,42 @@ func TestPS(t *testing.T) {
f(PS512, rsaOtherPrivateKey512, rsaPublicKey512, false)
}

func TestPS_BadKeys(t *testing.T) {
f := func(err, wantErr error) {
t.Helper()

if !errors.Is(err, wantErr) {
t.Fatalf("expected %v, got %v", wantErr, err)
}
}

f(getSignerError(NewSignerPS(PS256, nil)), ErrNilKey)
f(getSignerError(NewSignerPS(PS384, nil)), ErrNilKey)
f(getSignerError(NewSignerPS(PS512, nil)), ErrNilKey)

f(getSignerError(NewSignerPS("foo", rsaPrivateKey384)), ErrUnsupportedAlg)

f(getSignerError(NewSignerPS(PS256, rsaPrivateKey384)), ErrInvalidKey)
f(getSignerError(NewSignerPS(PS256, rsaPrivateKey512)), ErrInvalidKey)
f(getSignerError(NewSignerPS(PS384, rsaPrivateKey256)), ErrInvalidKey)
f(getSignerError(NewSignerPS(PS384, rsaPrivateKey512)), ErrInvalidKey)
f(getSignerError(NewSignerPS(PS512, rsaPrivateKey256)), ErrInvalidKey)
f(getSignerError(NewSignerPS(PS512, rsaPrivateKey384)), ErrInvalidKey)

f(getVerifierError(NewVerifierPS(PS256, nil)), ErrNilKey)
f(getVerifierError(NewVerifierPS(PS384, nil)), ErrNilKey)
f(getVerifierError(NewVerifierPS(PS512, nil)), ErrNilKey)

f(getVerifierError(NewVerifierPS("boo", rsaPublicKey384)), ErrUnsupportedAlg)

f(getVerifierError(NewVerifierPS(PS256, rsaPublicKey384)), ErrInvalidKey)
f(getVerifierError(NewVerifierPS(PS256, rsaPublicKey512)), ErrInvalidKey)
f(getVerifierError(NewVerifierPS(PS384, rsaPublicKey256)), ErrInvalidKey)
f(getVerifierError(NewVerifierPS(PS384, rsaPublicKey512)), ErrInvalidKey)
f(getVerifierError(NewVerifierPS(PS512, rsaPublicKey256)), ErrInvalidKey)
f(getVerifierError(NewVerifierPS(PS512, rsaPublicKey384)), ErrInvalidKey)
}

func psSign(t *testing.T, alg Algorithm, privateKey *rsa.PrivateKey, payload string) []byte {
t.Helper()

Expand Down
32 changes: 19 additions & 13 deletions algo_rs.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ import (
// NewSignerRS returns a new RSA-based signer.
func NewSignerRS(alg Algorithm, key *rsa.PrivateKey) (Signer, error) {
if key == nil {
return nil, ErrInvalidKey
return nil, ErrNilKey
}
hash, ok := getHashRSA(alg)
if !ok {
return nil, ErrUnsupportedAlg
hash, err := getHashRS(alg, key.Size())
if err != nil {
return nil, err
}
return &rsAlg{
alg: alg,
Expand All @@ -25,11 +25,11 @@ func NewSignerRS(alg Algorithm, key *rsa.PrivateKey) (Signer, error) {
// NewVerifierRS returns a new RSA-based verifier.
func NewVerifierRS(alg Algorithm, key *rsa.PublicKey) (Verifier, error) {
if key == nil {
return nil, ErrInvalidKey
return nil, ErrNilKey
}
hash, ok := getHashRSA(alg)
if !ok {
return nil, ErrUnsupportedAlg
hash, err := getHashRS(alg, key.Size())
if err != nil {
return nil, err
}
return &rsAlg{
alg: alg,
Expand All @@ -38,17 +38,23 @@ func NewVerifierRS(alg Algorithm, key *rsa.PublicKey) (Verifier, error) {
}, nil
}

func getHashRSA(alg Algorithm) (crypto.Hash, bool) {
func getHashRS(alg Algorithm, size int) (crypto.Hash, error) {
var hash crypto.Hash
switch alg {
case RS256:
return crypto.SHA256, true
hash = crypto.SHA256
case RS384:
return crypto.SHA384, true
hash = crypto.SHA384
case RS512:
return crypto.SHA512, true
hash = crypto.SHA512
default:
return 0, false
return 0, ErrUnsupportedAlg
}

if alg.keySize() != size {
return 0, ErrInvalidKey
}
return hash, nil
}

type rsAlg struct {
Expand Down
Loading

0 comments on commit 7521642

Please sign in to comment.