Skip to content

Commit

Permalink
Replace global map with Cache struct
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-strigo committed Sep 6, 2023
1 parent 90255e8 commit 83ff12f
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 65 deletions.
45 changes: 45 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package traefik_auth_middleware

import (
"fmt"
"sync"
"time"
)

const SIZE = 1024

type Cache struct {
sync.RWMutex

dirty map[string]Token
}

// Get token from cache. If token not found return status false.
func (c *Cache) Get(key string) (token Token, ok bool) {
c.RLock()
token, ok = c.dirty[key]
c.RUnlock()
return token, ok
}

// Store a token inside cache
func (c *Cache) Store(key string, t Token) {
c.Lock()
if c.dirty == nil {
c.dirty = make(map[string]Token, SIZE)
}
c.dirty[key] = t
c.Unlock()
}

// Clears cache of any expired tokens
func (c *Cache) ClearExpired() {
c.Lock()
for k, v := range c.dirty {
if v.ExpirationTime.Before(time.Now()) {
fmt.Println("deleting")
delete(c.dirty, k)
}
}
c.Unlock()
}
63 changes: 63 additions & 0 deletions cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package traefik_auth_middleware

import (
"testing"
"time"
)

func TestEmpty(t *testing.T) {
cache := Cache{}

_, ok := cache.Get("foo")
if ok {
t.Error("Expected get on empty cache to be empty, but got ok")
}

// check that call to ClearExpired doesn't blow up if cache empty
cache.ClearExpired()
}

func TestCache(t *testing.T) {
cache := Cache{}

items := map[string]Token{
"foo": {"fooAccessor", "fooSecret", time.Now()},
"bar": {"barAccessor", "barSecret", time.Now()},
"baz": {"bazAccessor", "bazSecret", time.Now()},
}

for k, v := range items {
cache.Store(k, v)
}

for k, v := range items {
rv, ok := cache.Get(k)
if !ok {
t.Errorf("exected %v to be found in cache, but didn't", k)
}
if rv != v {
t.Errorf("exected %v but got %v", v, rv)
}
}
}

func TestCacheExpiry(t *testing.T) {
cache := Cache{}

items := map[string]Token{
"foo": {"fooAccessor", "fooSecret", time.Now().Add(time.Hour)},
"bar": {"barAccessor", "barSecret", time.Now().Add(time.Hour)},
"baz": {"bazAccessor", "bazSecret", time.Now()},
}

for k, v := range items {
cache.Store(k, v)
}

cache.ClearExpired()

if _, ok := cache.Get("baz"); ok {
t.Errorf("expired item still returned from cache")
}

}
54 changes: 54 additions & 0 deletions nomad.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package traefik_auth_middleware

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
)

type Token struct {
AccessorID string `json:"AccessorID"`
SecretID string `json:"SecretID"`
ExpirationTime time.Time `json:"ExpirationTime"`
}

type LoginRequestBody struct {
AuthMethodName string
LoginToken string
}

// Login to Nomad with jwt and return a Token
func (p *Plugin) login(jwt string) (Token, error) {
req_body, err := json.Marshal(LoginRequestBody{p.config.AuthMethodName, jwt})
if err != nil {
return Token{}, err
}

url, err := url.JoinPath(p.config.NomadEndpoint, "v1", "acl/login")
if err != nil {
return Token{}, err
}

resp, err := p.client.Post(url, "application/json", bytes.NewReader(req_body))
if err != nil {
return Token{}, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return Token{}, fmt.Errorf("unexpected return code (%v) from nomad", resp.StatusCode)
}

resp_body, err := io.ReadAll(resp.Body)
if err != nil {
return Token{}, err
}
var token Token
json.Unmarshal(resp_body, &token)

return token, nil
}
90 changes: 25 additions & 65 deletions plugin.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,24 @@
package traefik_auth_middleware

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"time"
)

const (
CF_HEADER = "Cf-Access-Jwt-Assertion"
NOMAD_HEADER = "X-Nomad-Token"
CF_HEADER = "Cf-Access-Jwt-Assertion"
NOMAD_HEADER = "X-Nomad-Token"
CACHE_CLEAR_CYCLE_HRS = 1
)

var (
Cache map[string]Token
)
var tokenCache Cache

type Config struct {
NomadEndpoint string `json:"nomadEndpoint,omitempty"`
NomadEndpoint string `json:"nomadEndpoint,omitempty"`
AuthMethodName string `json:"authMethodName,omitempty"`
}

Expand All @@ -34,27 +29,35 @@ func CreateConfig() *Config {
}

type Plugin struct {
next http.Handler
name string
config *Config
client *http.Client
logger *log.Logger
next http.Handler
name string
config *Config
client *http.Client
logger *log.Logger
}

// Initiate new plugin instance
func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) {
Cache = make(map[string]Token, 1024)
// Start cache clearing cycle to remove any expired tokens
go func() {
for {
time.Sleep(CACHE_CLEAR_CYCLE_HRS * time.Hour)
tokenCache.ClearExpired()
}
}()

return &Plugin{
next: next,
name: name,
next: next,
name: name,
config: config,
client: &http.Client{},
logger: log.New(os.Stderr, fmt.Sprintf("[%v] " ,name), log.Ltime | log.Lmicroseconds),
logger: log.New(os.Stderr, fmt.Sprintf("[%v] ", name), log.Ltime|log.Lmicroseconds),
}, nil
}

// Handle HTTP request in the middleware chain
func (p *Plugin) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
cfjwt :=req.Header.Get(CF_HEADER)
cfjwt := req.Header.Get(CF_HEADER)
if cfjwt == "" {
p.logger.Println("No Cf-Access-Jwt-Assertion header found")
p.next.ServeHTTP(rw, req)
Expand All @@ -63,7 +66,7 @@ func (p *Plugin) ServeHTTP(rw http.ResponseWriter, req *http.Request) {

// Check if token already cached and valid. If not, reach out to Nomad to
// get a new one and cache it.
token, ok := Cache[cfjwt]
token, ok := tokenCache.Get(cfjwt)
if !ok || time.Now().UTC().After(token.ExpirationTime) {
var err error

Expand All @@ -76,53 +79,10 @@ func (p *Plugin) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
return
}

Cache[cfjwt] = token
tokenCache.Store(cfjwt, token)
}

req.Header.Set(NOMAD_HEADER, token.SecretID)

p.next.ServeHTTP(rw, req)
}

type Token struct {
AccessorID string `json:"AccessorID"`
SecretID string `json:"SecretID"`
ExpirationTime time.Time `json:"ExpirationTime"`
}

type LoginRequestBody struct {
AuthMethodName string
LoginToken string
}

// Login to Nomad with jwt and return a Token
func (p *Plugin) login(jwt string) (Token, error) {
req_body, err := json.Marshal(LoginRequestBody{p.config.AuthMethodName, jwt})
if err != nil {
return Token{}, err
}

url, err := url.JoinPath(p.config.NomadEndpoint, "v1", "acl/login")
if err != nil {
return Token{}, err
}

resp, err := p.client.Post(url, "application/json", bytes.NewReader(req_body))
if err != nil {
return Token{}, err
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return Token{}, fmt.Errorf("unexpected return code (%v) from nomad", resp.StatusCode)
}

resp_body, err := io.ReadAll(resp.Body)
if err != nil {
return Token{}, err
}
var token Token
json.Unmarshal(resp_body, &token)

return token, nil
}

0 comments on commit 83ff12f

Please sign in to comment.