Skip to content

Commit

Permalink
Optimise builder (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg authored May 9, 2020
1 parent 1b483c8 commit 378bc1f
Show file tree
Hide file tree
Showing 10 changed files with 117 additions and 46 deletions.
3 changes: 3 additions & 0 deletions algo.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
// Signer is used to sign tokens.
type Signer interface {
Algorithm() Algorithm
SignSize() int
Sign(payload []byte) ([]byte, error)
}

Expand All @@ -20,6 +21,8 @@ type Verifier interface {
// Algorithm for signing and verifying.
type Algorithm string

func (a Algorithm) String() string { return string(a) }

// Algorithm names for signing and verifying.
const (
EdDSA Algorithm = "EdDSA"
Expand Down
4 changes: 4 additions & 0 deletions algo_eddsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ func (h edDSAAlg) Algorithm() Algorithm {
return h.alg
}

func (h edDSAAlg) SignSize() int {
return ed25519.SignatureSize
}

func (h edDSAAlg) Sign(payload []byte) ([]byte, error) {
return ed25519.Sign(h.privateKey, payload), nil
}
Expand Down
4 changes: 4 additions & 0 deletions algo_es.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ func (h esAlg) Algorithm() Algorithm {
return h.alg
}

func (h esAlg) SignSize() int {
return 2 * h.curveBits
}

func (h esAlg) Sign(payload []byte) ([]byte, error) {
signed, err := h.sign(payload)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions algo_hs.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ func (h hsAlg) Algorithm() Algorithm {
return h.alg
}

func (h hsAlg) SignSize() int {
return h.hash.Size()
}

func (h hsAlg) Sign(payload []byte) ([]byte, error) {
return h.sign(payload)
}
Expand Down
4 changes: 4 additions & 0 deletions algo_ps.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ type psAlg struct {
opts *rsa.PSSOptions
}

func (h psAlg) SignSize() int {
return h.privateKey.Size()
}

func (h psAlg) Algorithm() Algorithm {
return h.alg
}
Expand Down
4 changes: 4 additions & 0 deletions algo_rs.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ func (h rsAlg) Algorithm() Algorithm {
return h.alg
}

func (h rsAlg) SignSize() int {
return h.privateKey.Size()
}

func (h rsAlg) Sign(payload []byte) ([]byte, error) {
signed, err := h.sign(payload)
if err != nil {
Expand Down
121 changes: 83 additions & 38 deletions build.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ var (

// Builder is used to create a new token.
type Builder struct {
signer Signer
header Header
signer Signer
header Header
headerRaw []byte
}

// BuildBytes is used to create and encode JWT with a provided claims.
Expand All @@ -30,12 +31,12 @@ func Build(signer Signer, claims interface{}) (*Token, error) {
func NewBuilder(signer Signer) *Builder {
b := &Builder{
signer: signer,

header: Header{
Algorithm: signer.Algorithm(),
Type: "JWT",
},
}
b.headerRaw = encodeHeader(&b.header)
return b
}

Expand All @@ -49,72 +50,116 @@ func (b *Builder) BuildBytes(claims interface{}) ([]byte, error) {
}

// Build used to create and encode JWT with a provided claims.
// If claims param is of type []byte then we treat it as a marshaled JSON.
// In other words you can pass already marshaled claims.
func (b *Builder) Build(claims interface{}) (*Token, error) {
rawClaims, encodedClaims, err := encodeClaims(claims)
rawClaims, err := encodeClaims(claims)
if err != nil {
return nil, err
}

encodedHeader := encodeHeader(&b.header)
payload := concatParts(encodedHeader, encodedClaims)
lenH := len(b.headerRaw)
lenC := base64EncodedLen(len(rawClaims))
lenS := base64EncodedLen(b.signer.SignSize())

raw := make([]byte, lenH+1+lenC+1+lenS)
idx := 0
idx += copy(raw[idx:], b.headerRaw)
raw[idx] = '.'
idx++
base64Encode(raw[idx:], rawClaims)
idx += lenC

raw, signature, err := signPayload(b.signer, payload)
signature, err := b.signer.Sign(raw[:idx])
if err != nil {
return nil, err
}
raw[idx] = '.'
idx++
base64Encode(raw[idx:], signature)
idx += lenS

token := &Token{
raw: raw,
payload: payload,
payload: raw[:lenH+1+lenC],
signature: signature,
header: b.header,
claims: rawClaims,
}
return token, nil
}

func encodeClaims(claims interface{}) (raw, encoded []byte, err error) {
raw, err = json.Marshal(claims)
if err != nil {
return nil, nil, err
func encodeClaims(claims interface{}) ([]byte, error) {
switch claims := claims.(type) {
case []byte:
return claims, nil
default:
return json.Marshal(claims)
}

encoded = make([]byte, base64EncodedLen(len(raw)))
base64Encode(encoded, raw)

return raw, encoded, nil
}

func encodeHeader(header *Header) []byte {
if header.Type == "JWT" && header.ContentType == "" {
switch header.Algorithm {
case EdDSA:
return []byte(encHeaderEdDSA)

case HS256:
return []byte(encHeaderHS256)
case HS384:
return []byte(encHeaderHS384)
case HS512:
return []byte(encHeaderHS512)

case RS256:
return []byte(encHeaderRS256)
case RS384:
return []byte(encHeaderRS384)
case RS512:
return []byte(encHeaderRS512)

case ES256:
return []byte(encHeaderES256)
case ES384:
return []byte(encHeaderES384)
case ES512:
return []byte(encHeaderES512)

case PS256:
return []byte(encHeaderPS256)
case PS384:
return []byte(encHeaderPS384)
case PS512:
return []byte(encHeaderPS512)

default:
// another algorithm? encode below
}
}
// returned err is always nil, see *Header.MarshalJSON
buf, _ := header.MarshalJSON()
buf, _ := json.Marshal(header)

encoded := make([]byte, base64EncodedLen(len(buf)))
base64Encode(encoded, buf)

return encoded
}

func signPayload(signer Signer, payload []byte) (signed, signature []byte, err error) {
signature, err = signer.Sign(payload)
if err != nil {
return nil, nil, err
}
const (
encHeaderEdDSA = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9"

encodedSignature := make([]byte, base64EncodedLen(len(signature)))
base64Encode(encodedSignature, signature)
encHeaderHS256 = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
encHeaderHS384 = "eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9"
encHeaderHS512 = "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9"

signed = concatParts(payload, encodedSignature)

return signed, signature, nil
}
encHeaderRS256 = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
encHeaderRS384 = "eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCJ9"
encHeaderRS512 = "eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCJ9"

func concatParts(a, b []byte) []byte {
buf := make([]byte, len(a)+1+len(b))
buf[len(a)] = '.'
encHeaderES256 = "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9"
encHeaderES384 = "eyJhbGciOiJFUzM4NCIsInR5cCI6IkpXVCJ9"
encHeaderES512 = "eyJhbGciOiJFUzUxMiIsInR5cCI6IkpXVCJ9"

copy(buf[:len(a)], a)
copy(buf[len(a)+1:], b)

return buf
}
encHeaderPS256 = "eyJhbGciOiJQUzI1NiIsInR5cCI6IkpXVCJ9"
encHeaderPS384 = "eyJhbGciOiJQUzM4NCIsInR5cCI6IkpXVCJ9"
encHeaderPS512 = "eyJhbGciOiJQUzUxMiIsInR5cCI6IkpXVCJ9"
)
7 changes: 5 additions & 2 deletions build_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func TestBuild(t *testing.T) {

raw := string(token)
if raw != want {
t.Errorf("want %v, got %v", want, raw)
t.Errorf("want %v,\n got %v", want, raw)
}
}

Expand All @@ -43,7 +43,7 @@ func TestBuildHeader(t *testing.T) {
want = toBase64(want)
raw := string(token.RawHeader())
if raw != want {
t.Errorf("want %v, got %v", want, raw)
t.Errorf("\nwant %v,\n got %v", want, raw)
}
}

Expand Down Expand Up @@ -109,6 +109,9 @@ func toBase64(s string) string {

type badSigner struct{}

func (badSigner) SignSize() int {
return 0
}
func (badSigner) Algorithm() Algorithm {
return "bad"
}
Expand Down
2 changes: 1 addition & 1 deletion jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ type Header struct {
}

// MarshalJSON implements the json.Marshaler interface.
func (h *Header) MarshalJSON() (data []byte, err error) {
func (h *Header) MarshalJSON() ([]byte, error) {
buf := bytes.Buffer{}
buf.WriteString(`{"alg":"`)
buf.WriteString(string(h.Algorithm))
Expand Down
10 changes: 5 additions & 5 deletions jwt_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/cristalhq/jwt/v2"
)

func BenchmarkEDSA(b *testing.B) {
func BenchmarkAlgEDSA(b *testing.B) {
pubKey, privKey, keyErr := ed25519.GenerateKey(rand.Reader)
if keyErr != nil {
b.Fatal(keyErr)
Expand All @@ -35,7 +35,7 @@ func BenchmarkEDSA(b *testing.B) {
})
}

func BenchmarkES(b *testing.B) {
func BenchmarkAlgES(b *testing.B) {
esAlgos := map[jwt.Algorithm]elliptic.Curve{
jwt.ES256: elliptic.P256(),
jwt.ES384: elliptic.P384(),
Expand Down Expand Up @@ -64,7 +64,7 @@ func BenchmarkES(b *testing.B) {
}
}

func BenchmarkPS(b *testing.B) {
func BenchmarkAlgPS(b *testing.B) {
psAlgos := []jwt.Algorithm{jwt.PS256, jwt.PS384, jwt.PS512}
for _, algo := range psAlgos {
key, keyErr := rsa.GenerateKey(rand.Reader, 2048)
Expand All @@ -89,7 +89,7 @@ func BenchmarkPS(b *testing.B) {
}
}

func BenchmarkRS(b *testing.B) {
func BenchmarkAlgRS(b *testing.B) {
rsAlgos := []jwt.Algorithm{jwt.RS256, jwt.RS384, jwt.RS512}
for _, algo := range rsAlgos {
key, keyErr := rsa.GenerateKey(rand.Reader, 2048)
Expand All @@ -114,7 +114,7 @@ func BenchmarkRS(b *testing.B) {
}
}

func BenchmarkHS(b *testing.B) {
func BenchmarkAlgHS(b *testing.B) {
key := []byte("12345")
hsAlgos := []jwt.Algorithm{jwt.HS256, jwt.HS384, jwt.HS512}
for _, algo := range hsAlgos {
Expand Down

0 comments on commit 378bc1f

Please sign in to comment.