Skip to content

Commit

Permalink
Code simplifications (#141)
Browse files Browse the repository at this point in the history
  • Loading branch information
cristaloleg authored May 6, 2023
1 parent 2a32199 commit 17cf6fb
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 84 deletions.
41 changes: 19 additions & 22 deletions algo_eddsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,32 @@ import (

// NewSignerEdDSA returns a new ed25519-based signer.
func NewSignerEdDSA(key ed25519.PrivateKey) (*EdDSAAlg, error) {
if len(key) == 0 {
switch {
case len(key) == 0:
return nil, ErrNilKey
}
if len(key) != ed25519.PrivateKeySize {
case len(key) != ed25519.PrivateKeySize:
return nil, ErrInvalidKey
default:
return &EdDSAAlg{
publicKey: nil,
privateKey: key,
}, nil
}
return &EdDSAAlg{
publicKey: nil,
privateKey: key,
}, nil
}

// NewVerifierEdDSA returns a new ed25519-based verifier.
func NewVerifierEdDSA(key ed25519.PublicKey) (*EdDSAAlg, error) {
if len(key) == 0 {
switch {
case len(key) == 0:
return nil, ErrNilKey
}
if len(key) != ed25519.PublicKeySize {
case len(key) != ed25519.PublicKeySize:
return nil, ErrInvalidKey
default:
return &EdDSAAlg{
publicKey: key,
privateKey: nil,
}, nil
}
return &EdDSAAlg{
publicKey: key,
privateKey: nil,
}, nil
}

type EdDSAAlg struct {
Expand All @@ -55,14 +57,9 @@ func (ed *EdDSAAlg) Verify(token *Token) error {
return ErrUninitializedToken
case !constTimeAlgEqual(token.Header().Algorithm, EdDSA):
return ErrAlgorithmMismatch
default:
return ed.verify(token.PayloadPart(), token.Signature())
}
}

func (ed *EdDSAAlg) verify(payload, signature []byte) error {
if !ed25519.Verify(ed.publicKey, payload, signature) {
case !ed25519.Verify(ed.publicKey, token.PayloadPart(), token.Signature()):
return ErrInvalidSignature
default:
return nil
}
return nil
}
4 changes: 2 additions & 2 deletions algo_es.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ func (es *ESAlg) Sign(payload []byte) ([]byte, error) {
return nil, err
}

r, s, errSign := ecdsa.Sign(rand.Reader, es.privateKey, digest)
r, s, err := ecdsa.Sign(rand.Reader, es.privateKey, digest)
if err != nil {
return nil, errSign
return nil, err
}

pivot := es.SignSize() / 2
Expand Down
10 changes: 5 additions & 5 deletions algo_ps.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ func (ps *PSAlg) Sign(payload []byte) ([]byte, error) {
return nil, err
}

signature, errSign := rsa.SignPSS(rand.Reader, ps.privateKey, ps.hash, digest, ps.opts)
if errSign != nil {
return nil, errSign
signature, err := rsa.SignPSS(rand.Reader, ps.privateKey, ps.hash, digest, ps.opts)
if err != nil {
return nil, err
}
return signature, nil
}
Expand All @@ -116,8 +116,8 @@ func (ps *PSAlg) verify(payload, signature []byte) error {
return err
}

errVerify := rsa.VerifyPSS(ps.publicKey, ps.hash, digest, signature, ps.opts)
if errVerify != nil {
err = rsa.VerifyPSS(ps.publicKey, ps.hash, digest, signature, ps.opts)
if err != nil {
return ErrInvalidSignature
}
return nil
Expand Down
10 changes: 5 additions & 5 deletions algo_rs.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ func (rs *RSAlg) Sign(payload []byte) ([]byte, error) {
return nil, err
}

signature, errSign := rsa.SignPKCS1v15(rand.Reader, rs.privateKey, rs.hash, digest)
if errSign != nil {
return nil, errSign
signature, err := rsa.SignPKCS1v15(rand.Reader, rs.privateKey, rs.hash, digest)
if err != nil {
return nil, err
}
return signature, nil
}
Expand All @@ -100,8 +100,8 @@ func (rs *RSAlg) verify(payload, signature []byte) error {
return err
}

errVerify := rsa.VerifyPKCS1v15(rs.publicKey, rs.hash, digest, signature)
if errVerify != nil {
err = rsa.VerifyPKCS1v15(rs.publicKey, rs.hash, digest, signature)
if err != nil {
return ErrInvalidSignature
}
return nil
Expand Down
77 changes: 31 additions & 46 deletions build.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ func NewBuilder(signer Signer, opts ...BuilderOption) *Builder {
// If claims param is of type []byte or string then it's treated as a marshaled JSON.
// In other words you can pass already marshaled claims.
func (b *Builder) Build(claims interface{}) (*Token, error) {
rawClaims, errClaims := encodeClaims(claims)
if errClaims != nil {
return nil, errClaims
rawClaims, err := encodeClaims(claims)
if err != nil {
return nil, err
}

lenH := len(b.headerRaw)
Expand All @@ -68,9 +68,9 @@ func (b *Builder) Build(claims interface{}) (*Token, error) {
idx += lenC

// calculate signature of already written 'header.claims'
rawSignature, errSign := b.signer.Sign(token[:idx])
if errSign != nil {
return nil, errSign
rawSignature, err := b.signer.Sign(token[:idx])
if err != nil {
return nil, err
}

// add '.' and append encoded signature
Expand Down Expand Up @@ -102,7 +102,7 @@ func encodeClaims(claims interface{}) ([]byte, error) {

func encodeHeader(header Header) []byte {
if header.Type == "JWT" && header.ContentType == "" && header.KeyID == "" {
if h := getPredefinedHeader(header); h != "" {
if h := predefinedHeaders[header.Algorithm]; h != "" {
return []byte(h)
}
// another algorithm? encode below
Expand All @@ -115,45 +115,30 @@ func encodeHeader(header Header) []byte {
return encoded
}

func getPredefinedHeader(header Header) string {
switch header.Algorithm {
case EdDSA:
return "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9"

case HS256:
return "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
case HS384:
return "eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9"
case HS512:
return "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9"

case RS256:
return "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9"
case RS384:
return "eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCJ9"
case RS512:
return "eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCJ9"

case ES256:
return "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9"
case ES384:
return "eyJhbGciOiJFUzM4NCIsInR5cCI6IkpXVCJ9"
case ES512:
return "eyJhbGciOiJFUzUxMiIsInR5cCI6IkpXVCJ9"

case PS256:
return "eyJhbGciOiJQUzI1NiIsInR5cCI6IkpXVCJ9"
case PS384:
return "eyJhbGciOiJQUzM4NCIsInR5cCI6IkpXVCJ9"
case PS512:
return "eyJhbGciOiJQUzUxMiIsInR5cCI6IkpXVCJ9"
func b64Encode(dst, src []byte) {
base64.RawURLEncoding.Encode(dst, src)
}

default:
return ""
}
func b64EncodedLen(n int) int {
return base64.RawURLEncoding.EncodedLen(n)
}

var (
b64Encode = base64.RawURLEncoding.Encode
b64EncodedLen = base64.RawURLEncoding.EncodedLen
)
var predefinedHeaders = map[Algorithm]string{
EdDSA: "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9",

HS256: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9",
HS384: "eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9",
HS512: "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9",

RS256: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9",
RS384: "eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCJ9",
RS512: "eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCJ9",

ES256: "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9",
ES384: "eyJhbGciOiJFUzM4NCIsInR5cCI6IkpXVCJ9",
ES512: "eyJhbGciOiJFUzUxMiIsInR5cCI6IkpXVCJ9",

PS256: "eyJhbGciOiJQUzI1NiIsInR5cCI6IkpXVCJ9",
PS384: "eyJhbGciOiJQUzM4NCIsInR5cCI6IkpXVCJ9",
PS512: "eyJhbGciOiJQUzUxMiIsInR5cCI6IkpXVCJ9",
}
6 changes: 3 additions & 3 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"github.com/cristalhq/jwt/v5"
)

func ExampleSignAndVerify() {
func Example_signAndVerify() {
// create a Signer (HMAC in this example)
key := []byte(`secret`)
signer, err := jwt.NewSignerHS(jwt.HS256, key)
Expand Down Expand Up @@ -63,7 +63,7 @@ func ExampleSignAndVerify() {
// Output:
}

func ExampleBuild() {
func ExampleBuilder() {
key := []byte(`secret`)
signer, _ := jwt.NewSignerHS(jwt.HS256, key)
builder := jwt.NewBuilder(signer)
Expand Down Expand Up @@ -103,7 +103,7 @@ type userClaims struct {
Email string `json:"email"`
}

func ExampleBuild_WithUserClaims() {
func ExampleBuilder_withUserClaims() {
key := []byte(`secret`)
signer, _ := jwt.NewSignerHS(jwt.HS256, key)
builder := jwt.NewBuilder(signer)
Expand Down
4 changes: 3 additions & 1 deletion parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,4 +80,6 @@ func parse(token []byte) (*Token, error) {
return tk, nil
}

var b64Decode = base64.RawURLEncoding.Decode
func b64Decode(dst, src []byte) (n int, err error) {
return base64.RawURLEncoding.Decode(dst, src)
}

0 comments on commit 17cf6fb

Please sign in to comment.