Skip to content

Commit

Permalink
add RSA PSS (PS)
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg committed Jan 4, 2020
1 parent 128633e commit 08ce406
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 0 deletions.
126 changes: 126 additions & 0 deletions algo_ps.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package jwt

import (
"crypto"
"crypto/rand"
"crypto/rsa"
)

var (
optsPS256 = &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
Hash: crypto.SHA256,
}

optsPS384 = &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
Hash: crypto.SHA384,
}

optsPS512 = &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
Hash: crypto.SHA512,
}
)

var _ Signer = (*psAlg)(nil)

type psAlg struct {
alg Algorithm
hash crypto.Hash
publickey *rsa.PublicKey
privateKey *rsa.PrivateKey
opts *rsa.PSSOptions
}

// NewPS256 returns new PS256 Signer using RSA PSS and SHA256 hash.
//
// Both public and private keys must not be nil.
//
func NewPS256(publicKey *rsa.PublicKey, privateKey *rsa.PrivateKey) (Signer, error) {
if publicKey == nil || privateKey == nil {
return nil, ErrInvalidKey
}
return &psAlg{
alg: PS256,
hash: crypto.SHA256,
publickey: publicKey,
privateKey: privateKey,
opts: optsPS256,
}, nil
}

// NewPS384 returns new PS384 Signer using RSA PSS and SHA384 hash.
//
// Both public and private keys must not be nil.
//
func NewPS384(publicKey *rsa.PublicKey, privateKey *rsa.PrivateKey) (Signer, error) {
if publicKey == nil || privateKey == nil {
return nil, ErrInvalidKey
}
return &psAlg{
alg: PS384,
hash: crypto.SHA384,
publickey: publicKey,
privateKey: privateKey,
opts: optsPS384,
}, nil
}

// NewPS512 returns new PS512 Signer using RSA PSS and SHA512 hash.
//
// Both public and private keys must not be nil.
//
func NewPS512(publicKey *rsa.PublicKey, privateKey *rsa.PrivateKey) (Signer, error) {
if publicKey == nil || privateKey == nil {
return nil, ErrInvalidKey
}
return &psAlg{
alg: PS512,
hash: crypto.SHA512,
publickey: publicKey,
privateKey: privateKey,
opts: optsPS512,
}, nil
}

func (h psAlg) Algorithm() Algorithm {
return h.alg
}

func (h psAlg) Sign(payload []byte) ([]byte, error) {
signed, err := h.sign(payload)
if err != nil {
return nil, err
}

signature, err := rsa.SignPSS(rand.Reader, h.privateKey, h.hash, signed, h.opts)
if err != nil {
return nil, err
}
return signature, nil
}

func (h psAlg) Verify(payload, signature []byte) error {
signed, err := h.sign(payload)
if err != nil {
return err
}

err = rsa.VerifyPSS(h.publickey, h.hash, signed, signature, h.opts)
if err != nil {
return ErrInvalidSignature
}
return nil
}

func (h psAlg) sign(payload []byte) ([]byte, error) {
hasher := h.hash.New()

_, err := hasher.Write(payload)
if err != nil {
return nil, err
}
signed := hasher.Sum(nil)
return signed, nil
}
102 changes: 102 additions & 0 deletions algo_ps_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package jwt

import (
"testing"
)

func TestPS256_WithValidSignature(t *testing.T) {
f := func(signer Signer, claims BinaryMarshaler) {
t.Helper()

tokenBuilder := NewTokenBuilder(signer)
token, _ := tokenBuilder.Build(claims)

err := signer.Verify(token.Payload(), token.Signature())
if err != nil {
t.Errorf("want no err, got: `%v`", err)
}
}

f(
getSigner(NewPS256(rsaPublicKey1, rsaPrivateKey1)),
&StandardClaims{},
)
f(
getSigner(NewPS384(rsaPublicKey1, rsaPrivateKey1)),
&StandardClaims{},
)
f(
getSigner(NewPS512(rsaPublicKey1, rsaPrivateKey1)),
&StandardClaims{},
)

f(
getSigner(NewPS256(rsaPublicKey1, rsaPrivateKey1)),
&customClaims{
TestField: "foo",
},
)
f(
getSigner(NewPS384(rsaPublicKey1, rsaPrivateKey1)),
&customClaims{
TestField: "bar",
},
)
f(
getSigner(NewPS512(rsaPublicKey1, rsaPrivateKey1)),
&customClaims{
TestField: "baz",
},
)
}

func TestPS384_WithInvalidSignature(t *testing.T) {
f := func(signer, verifier Signer, claims BinaryMarshaler) {
t.Helper()

tokenBuilder := NewTokenBuilder(signer)
token, _ := tokenBuilder.Build(claims)

err := verifier.Verify(token.Payload(), token.Signature())
if err == nil {
t.Errorf("want %v, got nil", ErrInvalidSignature)
}
}
f(
getSigner(NewPS256(rsaPublicKey1, rsaPrivateKey1)),
getSigner(NewPS256(rsaPublicKey2, rsaPrivateKey2)),
&StandardClaims{},
)
f(
getSigner(NewPS384(rsaPublicKey1, rsaPrivateKey1)),
getSigner(NewPS384(rsaPublicKey2, rsaPrivateKey2)),
&StandardClaims{},
)
f(
getSigner(NewPS512(rsaPublicKey1, rsaPrivateKey1)),
getSigner(NewPS512(rsaPublicKey2, rsaPrivateKey2)),
&StandardClaims{},
)

f(
getSigner(NewPS256(rsaPublicKey1, rsaPrivateKey1)),
getSigner(NewPS256(rsaPublicKey2, rsaPrivateKey2)),
&customClaims{
TestField: "foo",
},
)
f(
getSigner(NewPS384(rsaPublicKey1, rsaPrivateKey1)),
getSigner(NewPS384(rsaPublicKey2, rsaPrivateKey2)),
&customClaims{
TestField: "bar",
},
)
f(
getSigner(NewPS512(rsaPublicKey1, rsaPrivateKey1)),
getSigner(NewPS512(rsaPublicKey2, rsaPrivateKey2)),
&customClaims{
TestField: "baz",
},
)
}

0 comments on commit 08ce406

Please sign in to comment.