From 17cf6fb7952ce5a3a12ada0478ce2c6b076f6b9a Mon Sep 17 00:00:00 2001 From: Oleg Kovalov Date: Sat, 6 May 2023 19:26:04 +0200 Subject: [PATCH] Code simplifications (#141) --- algo_eddsa.go | 41 ++++++++++++-------------- algo_es.go | 4 +-- algo_ps.go | 10 +++---- algo_rs.go | 10 +++---- build.go | 77 ++++++++++++++++++++----------------------------- example_test.go | 6 ++-- parse.go | 4 ++- 7 files changed, 68 insertions(+), 84 deletions(-) diff --git a/algo_eddsa.go b/algo_eddsa.go index 43f86fd..fad1a7a 100644 --- a/algo_eddsa.go +++ b/algo_eddsa.go @@ -6,30 +6,32 @@ import ( // NewSignerEdDSA returns a new ed25519-based signer. func NewSignerEdDSA(key ed25519.PrivateKey) (*EdDSAAlg, error) { - if len(key) == 0 { + switch { + case len(key) == 0: return nil, ErrNilKey - } - if len(key) != ed25519.PrivateKeySize { + case len(key) != ed25519.PrivateKeySize: return nil, ErrInvalidKey + default: + return &EdDSAAlg{ + publicKey: nil, + privateKey: key, + }, nil } - return &EdDSAAlg{ - publicKey: nil, - privateKey: key, - }, nil } // NewVerifierEdDSA returns a new ed25519-based verifier. func NewVerifierEdDSA(key ed25519.PublicKey) (*EdDSAAlg, error) { - if len(key) == 0 { + switch { + case len(key) == 0: return nil, ErrNilKey - } - if len(key) != ed25519.PublicKeySize { + case len(key) != ed25519.PublicKeySize: return nil, ErrInvalidKey + default: + return &EdDSAAlg{ + publicKey: key, + privateKey: nil, + }, nil } - return &EdDSAAlg{ - publicKey: key, - privateKey: nil, - }, nil } type EdDSAAlg struct { @@ -55,14 +57,9 @@ func (ed *EdDSAAlg) Verify(token *Token) error { return ErrUninitializedToken case !constTimeAlgEqual(token.Header().Algorithm, EdDSA): return ErrAlgorithmMismatch - default: - return ed.verify(token.PayloadPart(), token.Signature()) - } -} - -func (ed *EdDSAAlg) verify(payload, signature []byte) error { - if !ed25519.Verify(ed.publicKey, payload, signature) { + case !ed25519.Verify(ed.publicKey, token.PayloadPart(), token.Signature()): return ErrInvalidSignature + default: + return nil } - return nil } diff --git a/algo_es.go b/algo_es.go index 09c9faf..578ab60 100644 --- a/algo_es.go +++ b/algo_es.go @@ -86,9 +86,9 @@ func (es *ESAlg) Sign(payload []byte) ([]byte, error) { return nil, err } - r, s, errSign := ecdsa.Sign(rand.Reader, es.privateKey, digest) + r, s, err := ecdsa.Sign(rand.Reader, es.privateKey, digest) if err != nil { - return nil, errSign + return nil, err } pivot := es.SignSize() / 2 diff --git a/algo_ps.go b/algo_ps.go index 31e739a..e38a7e1 100644 --- a/algo_ps.go +++ b/algo_ps.go @@ -92,9 +92,9 @@ func (ps *PSAlg) Sign(payload []byte) ([]byte, error) { return nil, err } - signature, errSign := rsa.SignPSS(rand.Reader, ps.privateKey, ps.hash, digest, ps.opts) - if errSign != nil { - return nil, errSign + signature, err := rsa.SignPSS(rand.Reader, ps.privateKey, ps.hash, digest, ps.opts) + if err != nil { + return nil, err } return signature, nil } @@ -116,8 +116,8 @@ func (ps *PSAlg) verify(payload, signature []byte) error { return err } - errVerify := rsa.VerifyPSS(ps.publicKey, ps.hash, digest, signature, ps.opts) - if errVerify != nil { + err = rsa.VerifyPSS(ps.publicKey, ps.hash, digest, signature, ps.opts) + if err != nil { return ErrInvalidSignature } return nil diff --git a/algo_rs.go b/algo_rs.go index 511d6b2..6e07c14 100644 --- a/algo_rs.go +++ b/algo_rs.go @@ -76,9 +76,9 @@ func (rs *RSAlg) Sign(payload []byte) ([]byte, error) { return nil, err } - signature, errSign := rsa.SignPKCS1v15(rand.Reader, rs.privateKey, rs.hash, digest) - if errSign != nil { - return nil, errSign + signature, err := rsa.SignPKCS1v15(rand.Reader, rs.privateKey, rs.hash, digest) + if err != nil { + return nil, err } return signature, nil } @@ -100,8 +100,8 @@ func (rs *RSAlg) verify(payload, signature []byte) error { return err } - errVerify := rsa.VerifyPKCS1v15(rs.publicKey, rs.hash, digest, signature) - if errVerify != nil { + err = rsa.VerifyPKCS1v15(rs.publicKey, rs.hash, digest, signature) + if err != nil { return ErrInvalidSignature } return nil diff --git a/build.go b/build.go index cde41fe..3a64ab5 100644 --- a/build.go +++ b/build.go @@ -48,9 +48,9 @@ func NewBuilder(signer Signer, opts ...BuilderOption) *Builder { // If claims param is of type []byte or string then it's treated as a marshaled JSON. // In other words you can pass already marshaled claims. func (b *Builder) Build(claims interface{}) (*Token, error) { - rawClaims, errClaims := encodeClaims(claims) - if errClaims != nil { - return nil, errClaims + rawClaims, err := encodeClaims(claims) + if err != nil { + return nil, err } lenH := len(b.headerRaw) @@ -68,9 +68,9 @@ func (b *Builder) Build(claims interface{}) (*Token, error) { idx += lenC // calculate signature of already written 'header.claims' - rawSignature, errSign := b.signer.Sign(token[:idx]) - if errSign != nil { - return nil, errSign + rawSignature, err := b.signer.Sign(token[:idx]) + if err != nil { + return nil, err } // add '.' and append encoded signature @@ -102,7 +102,7 @@ func encodeClaims(claims interface{}) ([]byte, error) { func encodeHeader(header Header) []byte { if header.Type == "JWT" && header.ContentType == "" && header.KeyID == "" { - if h := getPredefinedHeader(header); h != "" { + if h := predefinedHeaders[header.Algorithm]; h != "" { return []byte(h) } // another algorithm? encode below @@ -115,45 +115,30 @@ func encodeHeader(header Header) []byte { return encoded } -func getPredefinedHeader(header Header) string { - switch header.Algorithm { - case EdDSA: - return "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9" - - case HS256: - return "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" - case HS384: - return "eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9" - case HS512: - return "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9" - - case RS256: - return "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9" - case RS384: - return "eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCJ9" - case RS512: - return "eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCJ9" - - case ES256: - return "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9" - case ES384: - return "eyJhbGciOiJFUzM4NCIsInR5cCI6IkpXVCJ9" - case ES512: - return "eyJhbGciOiJFUzUxMiIsInR5cCI6IkpXVCJ9" - - case PS256: - return "eyJhbGciOiJQUzI1NiIsInR5cCI6IkpXVCJ9" - case PS384: - return "eyJhbGciOiJQUzM4NCIsInR5cCI6IkpXVCJ9" - case PS512: - return "eyJhbGciOiJQUzUxMiIsInR5cCI6IkpXVCJ9" +func b64Encode(dst, src []byte) { + base64.RawURLEncoding.Encode(dst, src) +} - default: - return "" - } +func b64EncodedLen(n int) int { + return base64.RawURLEncoding.EncodedLen(n) } -var ( - b64Encode = base64.RawURLEncoding.Encode - b64EncodedLen = base64.RawURLEncoding.EncodedLen -) +var predefinedHeaders = map[Algorithm]string{ + EdDSA: "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9", + + HS256: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9", + HS384: "eyJhbGciOiJIUzM4NCIsInR5cCI6IkpXVCJ9", + HS512: "eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCJ9", + + RS256: "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9", + RS384: "eyJhbGciOiJSUzM4NCIsInR5cCI6IkpXVCJ9", + RS512: "eyJhbGciOiJSUzUxMiIsInR5cCI6IkpXVCJ9", + + ES256: "eyJhbGciOiJFUzI1NiIsInR5cCI6IkpXVCJ9", + ES384: "eyJhbGciOiJFUzM4NCIsInR5cCI6IkpXVCJ9", + ES512: "eyJhbGciOiJFUzUxMiIsInR5cCI6IkpXVCJ9", + + PS256: "eyJhbGciOiJQUzI1NiIsInR5cCI6IkpXVCJ9", + PS384: "eyJhbGciOiJQUzM4NCIsInR5cCI6IkpXVCJ9", + PS512: "eyJhbGciOiJQUzUxMiIsInR5cCI6IkpXVCJ9", +} diff --git a/example_test.go b/example_test.go index 817deeb..dbfa5ba 100644 --- a/example_test.go +++ b/example_test.go @@ -8,7 +8,7 @@ import ( "github.com/cristalhq/jwt/v5" ) -func ExampleSignAndVerify() { +func Example_signAndVerify() { // create a Signer (HMAC in this example) key := []byte(`secret`) signer, err := jwt.NewSignerHS(jwt.HS256, key) @@ -63,7 +63,7 @@ func ExampleSignAndVerify() { // Output: } -func ExampleBuild() { +func ExampleBuilder() { key := []byte(`secret`) signer, _ := jwt.NewSignerHS(jwt.HS256, key) builder := jwt.NewBuilder(signer) @@ -103,7 +103,7 @@ type userClaims struct { Email string `json:"email"` } -func ExampleBuild_WithUserClaims() { +func ExampleBuilder_withUserClaims() { key := []byte(`secret`) signer, _ := jwt.NewSignerHS(jwt.HS256, key) builder := jwt.NewBuilder(signer) diff --git a/parse.go b/parse.go index 528e22d..29b83cd 100644 --- a/parse.go +++ b/parse.go @@ -80,4 +80,6 @@ func parse(token []byte) (*Token, error) { return tk, nil } -var b64Decode = base64.RawURLEncoding.Decode +func b64Decode(dst, src []byte) (n int, err error) { + return base64.RawURLEncoding.Decode(dst, src) +}