diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..fa21c7a --- /dev/null +++ b/cache.go @@ -0,0 +1,45 @@ +package traefik_auth_middleware + +import ( + "fmt" + "sync" + "time" +) + +const SIZE = 1024 + +type Cache struct { + sync.RWMutex + + dirty map[string]Token +} + +// Get token from cache. If token not found return status false. +func (c *Cache) Get(key string) (token Token, ok bool) { + c.RLock() + token, ok = c.dirty[key] + c.RUnlock() + return token, ok +} + +// Store a token inside cache +func (c *Cache) Store(key string, t Token) { + c.Lock() + if c.dirty == nil { + c.dirty = make(map[string]Token, SIZE) + } + c.dirty[key] = t + c.Unlock() +} + +// Clears cache of any expired tokens +func (c *Cache) ClearExpired() { + c.Lock() + for k, v := range c.dirty { + if v.ExpirationTime.Before(time.Now()) { + fmt.Println("deleting") + delete(c.dirty, k) + } + } + c.Unlock() +} diff --git a/cache_test.go b/cache_test.go new file mode 100644 index 0000000..72f72f1 --- /dev/null +++ b/cache_test.go @@ -0,0 +1,63 @@ +package traefik_auth_middleware + +import ( + "testing" + "time" +) + +func TestEmpty(t *testing.T) { + cache := Cache{} + + _, ok := cache.Get("foo") + if ok { + t.Error("Expected get on empty cache to be empty, but got ok") + } + + // check that call to ClearExpired doesn't blow up if cache empty + cache.ClearExpired() +} + +func TestCache(t *testing.T) { + cache := Cache{} + + items := map[string]Token{ + "foo": {"fooAccessor", "fooSecret", time.Now()}, + "bar": {"barAccessor", "barSecret", time.Now()}, + "baz": {"bazAccessor", "bazSecret", time.Now()}, + } + + for k, v := range items { + cache.Store(k, v) + } + + for k, v := range items { + rv, ok := cache.Get(k) + if !ok { + t.Errorf("exected %v to be found in cache, but didn't", k) + } + if rv != v { + t.Errorf("exected %v but got %v", v, rv) + } + } +} + +func TestCacheExpiry(t *testing.T) { + cache := Cache{} + + items := map[string]Token{ + "foo": {"fooAccessor", "fooSecret", time.Now().Add(time.Hour)}, + "bar": {"barAccessor", "barSecret", time.Now().Add(time.Hour)}, + "baz": {"bazAccessor", "bazSecret", time.Now()}, + } + + for k, v := range items { + cache.Store(k, v) + } + + cache.ClearExpired() + + if _, ok := cache.Get("baz"); ok { + t.Errorf("expired item still returned from cache") + } + +} diff --git a/nomad.go b/nomad.go new file mode 100644 index 0000000..d728ee6 --- /dev/null +++ b/nomad.go @@ -0,0 +1,54 @@ +package traefik_auth_middleware + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" +) + +type Token struct { + AccessorID string `json:"AccessorID"` + SecretID string `json:"SecretID"` + ExpirationTime time.Time `json:"ExpirationTime"` +} + +type LoginRequestBody struct { + AuthMethodName string + LoginToken string +} + +// Login to Nomad with jwt and return a Token +func (p *Plugin) login(jwt string) (Token, error) { + req_body, err := json.Marshal(LoginRequestBody{p.config.AuthMethodName, jwt}) + if err != nil { + return Token{}, err + } + + url, err := url.JoinPath(p.config.NomadEndpoint, "v1", "acl/login") + if err != nil { + return Token{}, err + } + + resp, err := p.client.Post(url, "application/json", bytes.NewReader(req_body)) + if err != nil { + return Token{}, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return Token{}, fmt.Errorf("unexpected return code (%v) from nomad", resp.StatusCode) + } + + resp_body, err := io.ReadAll(resp.Body) + if err != nil { + return Token{}, err + } + var token Token + json.Unmarshal(resp_body, &token) + + return token, nil +} diff --git a/plugin.go b/plugin.go index 3496882..3bc1913 100644 --- a/plugin.go +++ b/plugin.go @@ -1,29 +1,24 @@ package traefik_auth_middleware import ( - "bytes" "context" - "encoding/json" "fmt" - "io" "log" "net/http" - "net/url" "os" "time" ) const ( - CF_HEADER = "Cf-Access-Jwt-Assertion" - NOMAD_HEADER = "X-Nomad-Token" + CF_HEADER = "Cf-Access-Jwt-Assertion" + NOMAD_HEADER = "X-Nomad-Token" + CACHE_CLEAR_CYCLE_HRS = 1 ) -var ( - Cache map[string]Token -) +var tokenCache Cache type Config struct { - NomadEndpoint string `json:"nomadEndpoint,omitempty"` + NomadEndpoint string `json:"nomadEndpoint,omitempty"` AuthMethodName string `json:"authMethodName,omitempty"` } @@ -34,27 +29,35 @@ func CreateConfig() *Config { } type Plugin struct { - next http.Handler - name string - config *Config - client *http.Client - logger *log.Logger + next http.Handler + name string + config *Config + client *http.Client + logger *log.Logger } +// Initiate new plugin instance func New(ctx context.Context, next http.Handler, config *Config, name string) (http.Handler, error) { - Cache = make(map[string]Token, 1024) + // Start cache clearing cycle to remove any expired tokens + go func() { + for { + time.Sleep(CACHE_CLEAR_CYCLE_HRS * time.Hour) + tokenCache.ClearExpired() + } + }() + return &Plugin{ - next: next, - name: name, + next: next, + name: name, config: config, client: &http.Client{}, - logger: log.New(os.Stderr, fmt.Sprintf("[%v] " ,name), log.Ltime | log.Lmicroseconds), + logger: log.New(os.Stderr, fmt.Sprintf("[%v] ", name), log.Ltime|log.Lmicroseconds), }, nil } // Handle HTTP request in the middleware chain func (p *Plugin) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - cfjwt :=req.Header.Get(CF_HEADER) + cfjwt := req.Header.Get(CF_HEADER) if cfjwt == "" { p.logger.Println("No Cf-Access-Jwt-Assertion header found") p.next.ServeHTTP(rw, req) @@ -63,7 +66,7 @@ func (p *Plugin) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // Check if token already cached and valid. If not, reach out to Nomad to // get a new one and cache it. - token, ok := Cache[cfjwt] + token, ok := tokenCache.Get(cfjwt) if !ok || time.Now().UTC().After(token.ExpirationTime) { var err error @@ -76,53 +79,10 @@ func (p *Plugin) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - Cache[cfjwt] = token + tokenCache.Store(cfjwt, token) } req.Header.Set(NOMAD_HEADER, token.SecretID) p.next.ServeHTTP(rw, req) } - -type Token struct { - AccessorID string `json:"AccessorID"` - SecretID string `json:"SecretID"` - ExpirationTime time.Time `json:"ExpirationTime"` -} - -type LoginRequestBody struct { - AuthMethodName string - LoginToken string -} - -// Login to Nomad with jwt and return a Token -func (p *Plugin) login(jwt string) (Token, error) { - req_body, err := json.Marshal(LoginRequestBody{p.config.AuthMethodName, jwt}) - if err != nil { - return Token{}, err - } - - url, err := url.JoinPath(p.config.NomadEndpoint, "v1", "acl/login") - if err != nil { - return Token{}, err - } - - resp, err := p.client.Post(url, "application/json", bytes.NewReader(req_body)) - if err != nil { - return Token{}, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return Token{}, fmt.Errorf("unexpected return code (%v) from nomad", resp.StatusCode) - } - - resp_body, err := io.ReadAll(resp.Body) - if err != nil { - return Token{}, err - } - var token Token - json.Unmarshal(resp_body, &token) - - return token, nil -}