diff --git a/build.go b/build.go index 57ea53f..f96439f 100644 --- a/build.go +++ b/build.go @@ -36,7 +36,7 @@ func NewBuilder(signer Signer) *Builder { Type: "JWT", }, } - b.headerRaw = encodeHeader(&b.header) + b.headerRaw = encodeHeader(b.header) return b } @@ -77,7 +77,6 @@ func (b *Builder) Build(claims interface{}) (*Token, error) { raw[idx] = '.' idx++ base64Encode(raw[idx:], signature) - idx += lenS token := &Token{ raw: raw, @@ -98,43 +97,12 @@ func encodeClaims(claims interface{}) ([]byte, error) { } } -func encodeHeader(header *Header) []byte { +func encodeHeader(header Header) []byte { if header.Type == "JWT" && header.ContentType == "" { - switch header.Algorithm { - case EdDSA: - return []byte(encHeaderEdDSA) - - case HS256: - return []byte(encHeaderHS256) - case HS384: - return []byte(encHeaderHS384) - case HS512: - return []byte(encHeaderHS512) - - case RS256: - return []byte(encHeaderRS256) - case RS384: - return []byte(encHeaderRS384) - case RS512: - return []byte(encHeaderRS512) - - case ES256: - return []byte(encHeaderES256) - case ES384: - return []byte(encHeaderES384) - case ES512: - return []byte(encHeaderES512) - - case PS256: - return []byte(encHeaderPS256) - case PS384: - return []byte(encHeaderPS384) - case PS512: - return []byte(encHeaderPS512) - - default: - // another algorithm? encode below + if h := getPredefinedHeader(header); h != nil { + return h } + // another algorithm? encode below } // returned err is always nil, see *Header.MarshalJSON buf, _ := json.Marshal(header) @@ -144,6 +112,44 @@ func encodeHeader(header *Header) []byte { return encoded } +func getPredefinedHeader(header Header) []byte { + switch header.Algorithm { + case EdDSA: + return []byte(encHeaderEdDSA) + + case HS256: + return []byte(encHeaderHS256) + case HS384: + return []byte(encHeaderHS384) + case HS512: + return []byte(encHeaderHS512) + + case RS256: + return []byte(encHeaderRS256) + case RS384: + return []byte(encHeaderRS384) + case RS512: + return []byte(encHeaderRS512) + + case ES256: + return []byte(encHeaderES256) + case ES384: + return []byte(encHeaderES384) + case ES512: + return []byte(encHeaderES512) + + case PS256: + return []byte(encHeaderPS256) + case PS384: + return []byte(encHeaderPS384) + case PS512: + return []byte(encHeaderPS512) + + default: + return nil + } +} + const ( encHeaderEdDSA = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9" diff --git a/example_test.go b/example_test.go index 627ea7c..16d6cc1 100644 --- a/example_test.go +++ b/example_test.go @@ -27,7 +27,8 @@ func Example_JWT() { builder := jwt.NewBuilder(signer) // 4. and build a token - token, err := builder.Build(claims) + token, errBuild := builder.Build(claims) + checkErr(errBuild) // 5. here is your token :) var _ []byte = token.Raw() // or just token.String() for string