Skip to content

Commit

Permalink
refactor: replace map claims with built-in jwt struct
Browse files Browse the repository at this point in the history
  • Loading branch information
ayaanqui committed Jun 3, 2024
1 parent 35e6ed2 commit 72289c2
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 58 deletions.
32 changes: 18 additions & 14 deletions services/user_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package services
import (
"context"
"fmt"
"strconv"
"time"

"github.com/go-jet/jet/v2/postgres"
Expand Down Expand Up @@ -354,14 +355,12 @@ func (Service) VerifyPasswordHash(password string, hash string) bool {

// Generates a JWT with claims, signed with key
func (Service) GenerateJWT(key string, user *gmodel.User) (string, error) {
jwt := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
"id": user.ID,
"name": user.Name,
"email": user.Email,
"authPlatform": (*user.AuthPlatform).String(),
"authStateId": *user.AuthStateID,
"iat": time.Now().Unix(),
"exp": time.Now().Add(time.Hour * 24 * 30).Unix(),
jwt := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.StandardClaims{
Subject: fmt.Sprintf("%d", user.ID),
Audience: user.Email,
Id: fmt.Sprintf("%d", *user.AuthStateID),
IssuedAt: time.Now().Unix(),
ExpiresAt: time.Now().Add(time.Hour * 24 * 30).Unix(),
})
token, err := jwt.SignedString([]byte(key))
if err != nil {
Expand All @@ -384,9 +383,14 @@ func (service Service) VerifyJwt(ctx context.Context, authorization types.Author
return gmodel.User{}, fmt.Errorf("token expired")
}

authStateId := int64(claims["authStateId"].(float64))
userId := int64(claims["id"].(float64))
email := claims["email"].(string)
authStateId, err := strconv.ParseInt(claims.Id, 10, 64)
if err != nil {
return gmodel.User{}, fmt.Errorf("auth state id is invalid")
}
userId, err := strconv.ParseInt(claims.Subject, 10, 64)
if err != nil {
return gmodel.User{}, fmt.Errorf("user id is invalid")
}
query := table.User.
SELECT(
table.User.AllColumns,
Expand All @@ -399,9 +403,9 @@ func (service Service) VerifyJwt(ctx context.Context, authorization types.Author
)).
WHERE(
table.User.ID.
EQ(postgres.Int64(userId)).
AND(table.User.Email.EQ(postgres.String(email))).
AND(table.AuthState.ID.EQ(postgres.Int64(authStateId))),
EQ(postgres.Int(userId)).
AND(table.User.Email.EQ(postgres.String(claims.Audience))).
AND(table.AuthState.ID.EQ(postgres.Int(authStateId))),
).
LIMIT(1)
var user gmodel.User
Expand Down
7 changes: 4 additions & 3 deletions tests/user_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tests

import (
"strconv"
"testing"

"github.com/go-jet/jet/v2/postgres"
Expand Down Expand Up @@ -142,11 +143,11 @@ func TestUser(t *testing.T) {
t.Fatal("invalid jwt", err.Error())
}

claims_user_id, ok := claims["id"].(float64)
if !ok {
claims_user_id, err := strconv.ParseInt(claims.Subject, 10, 64)
if err != nil {
t.Fatal("could not convert claims.id to float64")
}
if int64(claims_user_id) != user1.ID {
if claims_user_id != user1.ID {
t.Fatal("jwt claim user.id does not match")
}
})
Expand Down
86 changes: 45 additions & 41 deletions utils/parsers.go
Original file line number Diff line number Diff line change
@@ -1,41 +1,45 @@
package utils

import (
"encoding/json"
"fmt"
"os"

"github.com/golang-jwt/jwt"
)

// Given a JSON file, map the contents into any struct dest
func FileMapper[T any](filename string, dest T) error {
file, err := os.ReadFile(filename)
if err != nil {
return fmt.Errorf("%s not found", filename)
}
if err = json.Unmarshal(file, dest); err != nil {
return err
}
return nil
}

// Given a raw jwt token and an encryption key return the mapped jwt claims or an error
func GetJwtClaims(jwt_token string, key string) (jwt.MapClaims, error) {
token, token_err := jwt.Parse(jwt_token, func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("invalid signing method")
}
return []byte(key), nil
})
if token_err != nil {
return nil, fmt.Errorf("could not parse jwt token")
}

// Get claims stored in parsed JWT token
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, fmt.Errorf("could not fetch jwt claims")
}
return claims, nil
}
package utils

import (
"encoding/json"
"fmt"
"os"

"github.com/golang-jwt/jwt"
)

// Given a JSON file, map the contents into any struct dest
func FileMapper[T any](filename string, dest T) error {
file, err := os.ReadFile(filename)
if err != nil {
return fmt.Errorf("%s not found", filename)
}
if err = json.Unmarshal(file, dest); err != nil {
return err
}
return nil
}

// Given a raw jwt token and an encryption key return the mapped jwt claims or an error
func GetJwtClaims(jwt_token string, key string) (claims jwt.StandardClaims, err error) {
token, token_err := jwt.Parse(jwt_token, func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("invalid signing method")
}
return []byte(key), nil
})
if token_err != nil {
return jwt.StandardClaims{}, fmt.Errorf("could not parse jwt token")
}

// Get claims_map stored in parsed JWT token
claims_map, ok := token.Claims.(jwt.MapClaims)
if !ok {
return jwt.StandardClaims{}, fmt.Errorf("could not transform claim map to desired object")
}
marshalled_claims, err := json.Marshal(claims_map)
if err != nil {
return jwt.StandardClaims{}, err
}
return claims, json.Unmarshal(marshalled_claims, &claims)
}

0 comments on commit 72289c2

Please sign in to comment.