Skip to content

Commit

Permalink
Merge pull request #144 from ksysoev/context_getters
Browse files Browse the repository at this point in the history
Add GetClientIP and GetStash functions with tests
  • Loading branch information
ksysoev authored Nov 23, 2024
2 parents bebbb8f + c3919ab commit b8fa728
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 0 deletions.
11 changes: 11 additions & 0 deletions middleware/http/clientip.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ func NewClientIPMiddleware(provider Provider) func(next http.Handler) http.Handl
}
}

// GetClientIP retrieves the client IP address from the provided context.
// It takes a single parameter ctx of type context.Context.
// It returns a string representing the client IP address if found, otherwise an empty string.
func GetClientIP(ctx context.Context) string {
if ip, ok := ctx.Value(ClientIP).(string); ok {
return ip
}

return ""
}

func getIPFromRequest(provider Provider, r *http.Request) string {
var ip string

Expand Down
23 changes: 23 additions & 0 deletions middleware/http/clientip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,26 @@ func TestNewClientIPMiddleware(t *testing.T) {
t.Errorf("Expected IP to be 192.168.0.2, but got %s", ip)
}
}

func TestGetClientIP(t *testing.T) {
// Test with IP in context
ctx := context.WithValue(context.Background(), ClientIP, "192.168.0.1")

if ip := GetClientIP(ctx); ip != "192.168.0.1" {
t.Errorf("Expected IP to be 192.168.0.1, but got %s", ip)
}

// Test with no IP in context
ctx = context.Background()

if ip := GetClientIP(ctx); ip != "" {
t.Errorf("Expected IP to be empty, but got %s", ip)
}

// Test with non-string value in context
ctx = context.WithValue(context.Background(), ClientIP, 12345)

if ip := GetClientIP(ctx); ip != "" {
t.Errorf("Expected IP to be empty, but got %s", ip)
}
}
11 changes: 11 additions & 0 deletions middleware/http/stash.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,14 @@ func NewStashMiddleware() func(next http.Handler) http.Handler {
})
}
}

// GetStash retrieves the stash from the provided context.
// It takes a single parameter ctx of type context.Context.
// It returns a pointer to a sync.Map representing the stash if found, otherwise nil.
func GetStash(ctx context.Context) *sync.Map {
if stash, ok := ctx.Value(Stash).(*sync.Map); ok {
return stash
}

return nil
}
25 changes: 25 additions & 0 deletions middleware/http/stash_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package http

import (
"context"
"net/http"
"net/http/httptest"
"sync"
Expand Down Expand Up @@ -42,3 +43,27 @@ func TestNewStashMiddleware(t *testing.T) {
t.Errorf("Expected status code %d, but got %d", http.StatusOK, resp.StatusCode)
}
}

func TestGetStash(t *testing.T) {
ctx := context.Background()
stash := &sync.Map{}
ctx = context.WithValue(ctx, Stash, stash)

retrievedStash := GetStash(ctx)
if retrievedStash == nil {
t.Errorf("Expected stash to be retrieved from context, but got nil")
}

if retrievedStash != stash {
t.Errorf("Expected retrieved stash to be %v, but got %v", stash, retrievedStash)
}
}

func TestGetStash_NotFound(t *testing.T) {
ctx := context.Background()

retrievedStash := GetStash(ctx)
if retrievedStash != nil {
t.Errorf("Expected nil when stash is not found in context, but got %v", retrievedStash)
}
}

0 comments on commit b8fa728

Please sign in to comment.