Skip to content

Commit

Permalink
Refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
minhduc140583 committed Apr 14, 2024
1 parent 14c9f15 commit 0c2f6c7
Show file tree
Hide file tree
Showing 8 changed files with 1,267 additions and 181 deletions.
20 changes: 16 additions & 4 deletions authorizer/session_authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"time"
)

type ICacheService interface {
type CachePort interface {
Get(ctx context.Context, key string) (string, error)
Remove(ctx context.Context, key string) (bool, error)
Expire(ctx context.Context, key string, timeToLive time.Duration) (bool, error)
Expand All @@ -28,14 +28,14 @@ type SessionAuthorizer struct {
DecodeSessionID func(value string) (string, error)
EncodeSessionID func(sid string) string
VerifyToken func(tokenString string, secret string) (map[string]interface{}, int64, int64, error)
Cache ICacheService
Cache CachePort
sessionExpiredTime time.Duration
LogError func(ctx context.Context, msg string, opts ...map[string]interface{})
}

func NewSessionAuthorizer(secretKey string, verifyToken func(tokenString string, secret string) (map[string]interface{}, int64, int64, error),
refreshExpire func(w http.ResponseWriter, sessionId string) error,
cache ICacheService, sessionExpiredTime time.Duration, logError func(ctx context.Context, msg string, opts ...map[string]interface{}), singleSession bool,
cache CachePort, sessionExpiredTime time.Duration, logError func(ctx context.Context, msg string, opts ...map[string]interface{}), singleSession bool,
encodeSessionID func(sid string) string,
decodeSessionID func(value string) (string, error),
opts ...string) *SessionAuthorizer {
Expand Down Expand Up @@ -139,6 +139,9 @@ func (h *SessionAuthorizer) Authorize(next http.Handler, skipRefreshTTL bool) ht
return
}
ip := getForwardedRemoteIp(r)
if len(ip) == 0 {
ip = getRemoteIp(r)
}
sid, ok := uData[h.SId]
if !ok || sid != sessionId ||
getValue(uData, "userAgent") != r.UserAgent() ||
Expand Down Expand Up @@ -180,6 +183,9 @@ func (h *SessionAuthorizer) Verify(next http.Handler, skipRefreshTTL bool, sessi
return
}
ip := getForwardedRemoteIp(r)
if len(ip) == 0 {
ip = getRemoteIp(r)
}
ctx = context.WithValue(ctx, "ip", ip)
for k, e := range payload {
if len(k) > 0 {
Expand Down Expand Up @@ -232,7 +238,13 @@ func getForwardedRemoteIp(r *http.Request) string {
}
return ""
}

func getRemoteIp(r *http.Request) string {
remoteIP, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
remoteIP = r.RemoteAddr
}
return remoteIP
}
func getValue(data map[string]interface{}, key string) string {
if value, ok := data[key]; ok {
return value.(string)
Expand Down
Loading

0 comments on commit 0c2f6c7

Please sign in to comment.