diff --git a/algo_es.go b/algo_es.go index 8fdde9e..c22b73c 100644 --- a/algo_es.go +++ b/algo_es.go @@ -12,7 +12,7 @@ func NewSignerES(alg Algorithm, key *ecdsa.PrivateKey) (Signer, error) { if key == nil { return nil, ErrInvalidKey } - hash, keySize, curveBits, err := getParamsES(alg) + hash, err := getParamsES(alg) if err != nil { return nil, err } @@ -20,8 +20,7 @@ func NewSignerES(alg Algorithm, key *ecdsa.PrivateKey) (Signer, error) { alg: alg, hash: hash, privateKey: key, - keySize: keySize, - curveBits: curveBits, + signSize: roundBytes(key.PublicKey.Params().BitSize) * 2, }, nil } @@ -30,7 +29,7 @@ func NewVerifierES(alg Algorithm, key *ecdsa.PublicKey) (Verifier, error) { if key == nil { return nil, ErrInvalidKey } - hash, keySize, curveBits, err := getParamsES(alg) + hash, err := getParamsES(alg) if err != nil { return nil, err } @@ -38,21 +37,20 @@ func NewVerifierES(alg Algorithm, key *ecdsa.PublicKey) (Verifier, error) { alg: alg, hash: hash, publickey: key, - keySize: keySize, - curveBits: curveBits, + signSize: roundBytes(key.Params().BitSize) * 2, }, nil } -func getParamsES(alg Algorithm) (crypto.Hash, int, int, error) { +func getParamsES(alg Algorithm) (crypto.Hash, error) { switch alg { case ES256: - return crypto.SHA256, 32, 256, nil + return crypto.SHA256, nil case ES384: - return crypto.SHA384, 48, 384, nil + return crypto.SHA384, nil case ES512: - return crypto.SHA512, 66, 521, nil + return crypto.SHA512, nil default: - return 0, 0, 0, ErrUnsupportedAlg + return 0, ErrUnsupportedAlg } } @@ -61,8 +59,7 @@ type esAlg struct { hash crypto.Hash publickey *ecdsa.PublicKey privateKey *ecdsa.PrivateKey - keySize int - curveBits int + signSize int } func (es esAlg) Algorithm() Algorithm { @@ -70,7 +67,7 @@ func (es esAlg) Algorithm() Algorithm { } func (es esAlg) SignSize() int { - return (es.privateKey.Curve.Params().BitSize + 7) / 4 + return es.signSize } func (es esAlg) Sign(payload []byte) ([]byte, error) { @@ -84,17 +81,17 @@ func (es esAlg) Sign(payload []byte) ([]byte, error) { return nil, err } - keyBytes := es.SignSize() / 2 + pivot := es.SignSize() / 2 rBytes, sBytes := r.Bytes(), s.Bytes() - signature := make([]byte, keyBytes*2) - copy(signature[keyBytes-len(rBytes):], rBytes) - copy(signature[keyBytes*2-len(sBytes):], sBytes) + signature := make([]byte, es.SignSize()) + copy(signature[pivot-len(rBytes):], rBytes) + copy(signature[pivot*2-len(sBytes):], sBytes) return signature, nil } func (es esAlg) Verify(payload, signature []byte) error { - if len(signature) != 2*es.keySize { + if len(signature) != es.SignSize() { return ErrInvalidSignature } @@ -103,8 +100,9 @@ func (es esAlg) Verify(payload, signature []byte) error { return err } - r := big.NewInt(0).SetBytes(signature[:es.keySize]) - s := big.NewInt(0).SetBytes(signature[es.keySize:]) + pivot := es.SignSize() / 2 + r := big.NewInt(0).SetBytes(signature[:pivot]) + s := big.NewInt(0).SetBytes(signature[pivot:]) if !ecdsa.Verify(es.publickey, signed, r, s) { return ErrInvalidSignature @@ -122,3 +120,11 @@ func (es esAlg) sign(payload []byte) ([]byte, error) { signed := hasher.Sum(nil) return signed, nil } + +func roundBytes(n int) int { + res := n / 8 + if n%8 > 0 { + return res + 1 + } + return res +} diff --git a/build_test.go b/build_test.go index c9a56c3..08c76a1 100644 --- a/build_test.go +++ b/build_test.go @@ -1,9 +1,14 @@ package jwt import ( + "crypto/ecdsa" + "crypto/x509" "encoding/base64" + "encoding/json" + "encoding/pem" "errors" "testing" + "time" ) func TestBuild(t *testing.T) { @@ -101,6 +106,87 @@ func TestBuildMalformed(t *testing.T) { ) } +var tests = []struct { + key *ecdsa.PrivateKey + alg Algorithm +}{ + {testKeyEC256, ES256}, + {testKeyEC384, ES384}, + {testKeyEC521, ES512}, +} + +var mybenchClaims = &struct { + StandardClaims +}{ + StandardClaims: StandardClaims{ + Issuer: "benchmark", + IssuedAt: NewNumericDate(time.Now()), + }, +} + +func Test_Two_ECDSA(t *testing.T) { + for _, test := range tests { + signer, err := NewSignerES(test.alg, test.key) + if err != nil { + t.Fatal(err) + } + bui := NewBuilder(signer) + token, err := bui.BuildBytes(mybenchClaims) + if err != nil { + t.Fatal(err) + } + + verifier, err := NewVerifierES(test.alg, &test.key.PublicKey) + if err != nil { + t.Fatal(err) + } + t.Run("check-"+test.alg.String(), func(t *testing.T) { + obj, err := ParseAndVerify(token, verifier) + if err != nil { + t.Fatal(err) + } + err = json.Unmarshal(obj.RawClaims(), new(map[string]interface{})) + if err != nil { + t.Fatal(err) + } + }) + } +} + +func mustParseECKey(s string) *ecdsa.PrivateKey { + block, _ := pem.Decode([]byte(s)) + if block == nil { + panic("invalid PEM") + } + + key, err := x509.ParseECPrivateKey(block.Bytes) + if err != nil { + panic(err) + } + return key +} + +var testKeyEC256 = mustParseECKey(`-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIBOm12aaXvqSzysOSGV2yL/xKY3kCtaOfAPY1KQN2sTJoAoGCCqGSM49 +AwEHoUQDQgAEX0iTLAcGqlWeGIRtIk0G2PRgpf/6gLxOTyMAdriP4NLRkuu+9Idt +y3qmEizRC0N81j84E213/LuqLqnsrgfyiw== +-----END EC PRIVATE KEY-----`) + +var testKeyEC384 = mustParseECKey(`-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDBluSyfK9BEPc9y944ZLahd4xHRVse64iCeEC5gBQ4UM1961bsEthUC +NKXyTGTBuW2gBwYFK4EEACKhZANiAAR3Il6V61OwAnb6oYm4hQ4TVVaGQ2QGzrSi +eYGoRewNhAaZ8wfemWX4fww7yNi6AmUzWV8Su5Qq3dtN3nLpKUEaJrTvfjtowrr/ +ZtU1fZxzI/agEpG2+uLFW6JNdYzp67w= +-----END EC PRIVATE KEY-----`) + +var testKeyEC521 = mustParseECKey(`-----BEGIN EC PRIVATE KEY----- +MIHcAgEBBEIBH31vhkSH+x+J8C/xf/PRj81u3MCqgiaGdW1S1jcjEuikczbbX689 +9ETHGCPtHEWw/Il1RAFaKMvndmfDVd/YapmgBwYFK4EEACOhgYkDgYYABAGNpBDA +Lx6rKQXWdWQR581uw9dTuV8zjmkSpLZ3k0qLHVlOqt00AfEL4NO+E7fxh4SuAZPb +RDMu2lx4lWOM2EyFvgFIyu8xlA9lEg5GKq+A7+y5r99RLughiDd52vGnudMspHEy +x6IpwXzTZR/T8TkluL3jDWtVNFxGBf/aEErnpeLfRQ== +-----END EC PRIVATE KEY-----`) + func toBase64(s string) string { buf := make([]byte, base64EncodedLen(len(s))) base64.RawURLEncoding.Encode(buf, []byte(s))