diff --git a/internal/api/api.go b/internal/api/api.go
index 53949081..d0f9444b 100644
--- a/internal/api/api.go
+++ b/internal/api/api.go
@@ -2,6 +2,7 @@ package api
 
 import (
 	"net/http"
+	"time"
 
 	limits "github.com/gin-contrib/size"
 	"github.com/gin-gonic/gin"
@@ -12,9 +13,11 @@ import (
 	"github.com/systemli/ticker/internal/api/middleware/me"
 	"github.com/systemli/ticker/internal/api/middleware/message"
 	"github.com/systemli/ticker/internal/api/middleware/prometheus"
+	"github.com/systemli/ticker/internal/api/middleware/response_cache"
 	"github.com/systemli/ticker/internal/api/middleware/ticker"
 	"github.com/systemli/ticker/internal/api/middleware/user"
 	"github.com/systemli/ticker/internal/bridge"
+	"github.com/systemli/ticker/internal/cache"
 	"github.com/systemli/ticker/internal/config"
 	"github.com/systemli/ticker/internal/storage"
 )
@@ -48,6 +51,10 @@ func API(config config.Config, storage storage.Storage, log *logrus.Logger) *gin
 		bridges: bridge.RegisterBridges(config, storage),
 	}
 
+	// TODO: Make this configurable via config file
+	cacheTtl := 30 * time.Second
+	inMemoryCache := cache.NewCache(5 * time.Minute)
+
 	gin.SetMode(gin.ReleaseMode)
 
 	r := gin.New()
@@ -107,9 +114,9 @@
 	{
 		public.POST(`/admin/login`, authMiddleware.LoginHandler)
 
-		public.GET(`/init`, handler.GetInit)
-		public.GET(`/timeline`, ticker.PrefetchTickerFromRequest(storage), handler.GetTimeline)
-		public.GET(`/feed`, ticker.PrefetchTickerFromRequest(storage), handler.GetFeed)
+		public.GET(`/init`, response_cache.CachePage(inMemoryCache, cacheTtl, handler.GetInit))
+		public.GET(`/timeline`, ticker.PrefetchTickerFromRequest(storage), response_cache.CachePage(inMemoryCache, cacheTtl, handler.GetTimeline))
+		public.GET(`/feed`, ticker.PrefetchTickerFromRequest(storage), response_cache.CachePage(inMemoryCache, cacheTtl, handler.GetFeed))
 	}
 
 	r.GET(`/media/:fileName`, handler.GetMedia)
diff --git a/internal/api/middleware/response_cache/response_cache.go b/internal/api/middleware/response_cache/response_cache.go
new file mode 100644
index 00000000..fd4aae61
--- /dev/null
+++ b/internal/api/middleware/response_cache/response_cache.go
@@ -0,0 +1,127 @@
+package response_cache
+
+import (
+	"fmt"
+	"net/http"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/systemli/ticker/internal/api/helper"
+	"github.com/systemli/ticker/internal/cache"
+)
+
+// responseCache is a struct to cache the response
+type responseCache struct {
+	Status int
+	Header http.Header
+	Body   []byte
+}
+
+// compile-time check that cachedWriter implements gin.ResponseWriter
+var _ gin.ResponseWriter = &cachedWriter{}
+
+// cachedWriter is a wrapper around the gin.ResponseWriter
+type cachedWriter struct {
+	gin.ResponseWriter
+	status  int
+	written bool
+	key     string
+	expires time.Duration
+	cache   *cache.Cache
+}
+
+// WriteHeader is a wrapper around the gin.ResponseWriter.WriteHeader
+func (w *cachedWriter) WriteHeader(code int) {
+	w.status = code
+	w.written = true
+	w.ResponseWriter.WriteHeader(code)
+}
+
+// Status is a wrapper around the gin.ResponseWriter.Status
+func (w *cachedWriter) Status() int {
+	return w.ResponseWriter.Status()
+}
+
+// Written is a wrapper around the gin.ResponseWriter.Written
+func (w *cachedWriter) Written() bool {
+	return w.ResponseWriter.Written()
+}
+
+// Write is a wrapper around the gin.ResponseWriter.Write
+// It will cache the response if the status code is below 300
+func (w *cachedWriter) Write(data []byte) (int, error) {
+	ret, err := w.ResponseWriter.Write(data)
+	if err == nil && w.Status() < 300 {
+		value := responseCache{
+			Status: w.Status(),
+			Header: w.Header(),
+			Body:   data,
+		}
+		w.cache.Set(w.key, value, w.expires)
+	}
+
+	return ret, err
+}
+
+// WriteString is a wrapper around the gin.ResponseWriter.WriteString
+// It will cache the response if the status code is below 300
+func (w *cachedWriter) WriteString(s string) (int, error) {
+	ret, err := w.ResponseWriter.WriteString(s)
+	if err == nil && w.Status() < 300 {
+		value := responseCache{
+			Status: w.Status(),
+			Header: w.Header(),
+			Body:   []byte(s),
+		}
+		w.cache.Set(w.key, value, w.expires)
+	}
+
+	return ret, err
+}
+
+func newCachedWriter(w gin.ResponseWriter, cache *cache.Cache, key string, expires time.Duration) *cachedWriter {
+	return &cachedWriter{
+		ResponseWriter: w,
+		cache:          cache,
+		key:            key,
+		expires:        expires,
+	}
+}
+
+// CachePage is a middleware to cache the response of a request
+func CachePage(cache *cache.Cache, expires time.Duration, handle gin.HandlerFunc) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		key := CreateKey(c)
+		if value, exists := cache.Get(key); exists {
+			v := value.(responseCache)
+			for k, values := range v.Header {
+				for _, value := range values {
+					c.Writer.Header().Add(k, value)
+				}
+			}
+			c.Writer.WriteHeader(v.Status)
+			_, _ = c.Writer.Write(v.Body)
+
+			return
+		} else {
+			writer := newCachedWriter(c.Writer, cache, key, expires)
+			c.Writer = writer
+			handle(c)
+
+			if c.IsAborted() {
+				cache.Delete(key)
+			}
+		}
+	}
+}
+
+func CreateKey(c *gin.Context) string {
+	domain, err := helper.GetDomain(c)
+	if err != nil {
+		domain = "unknown"
+	}
+	name := c.HandlerName()
+	query := c.Request.URL.Query().Encode()
+
+	return fmt.Sprintf("response:%s:%s:%s", domain, name, query)
+}
diff --git a/internal/api/middleware/response_cache/response_cache_test.go b/internal/api/middleware/response_cache/response_cache_test.go
new file mode 100644
index 00000000..1df7e9d3
--- /dev/null
+++ b/internal/api/middleware/response_cache/response_cache_test.go
@@ -0,0 +1,53 @@
+package response_cache
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"net/url"
+	"testing"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/assert"
+	"github.com/systemli/ticker/internal/cache"
+)
+
+func TestCreateKey(t *testing.T) {
+	c := gin.Context{
+		Request: &http.Request{
+			Method: "GET",
+			URL:    &url.URL{Path: "/api/v1/settings", RawQuery: "origin=localhost"},
+		},
+	}
+
+	key := CreateKey(&c)
+	assert.Equal(t, "response:localhost::origin=localhost", key)
+
+	c.Request.URL.RawQuery = ""
+
+	key = CreateKey(&c)
+	assert.Equal(t, "response:unknown::", key)
+}
+
+func TestCachePage(t *testing.T) {
+	w := httptest.NewRecorder()
+	c, _ := gin.CreateTestContext(w)
+	c.Request = &http.Request{
+		Method: "GET",
+		URL:    &url.URL{Path: "/ping", RawQuery: "origin=localhost"},
+	}
+
+	inMemoryCache := cache.NewCache(time.Minute)
+	defer inMemoryCache.Close()
+	CachePage(inMemoryCache, time.Minute, func(c *gin.Context) {
+		c.String(http.StatusOK, "pong")
+	})(c)
+
+	assert.Equal(t, http.StatusOK, w.Code)
+
+	CachePage(inMemoryCache, time.Minute, func(c *gin.Context) {
+		c.String(http.StatusOK, "pong")
+	})(c)
+
+	assert.Equal(t, http.StatusOK, w.Code)
+}
diff --git a/internal/cache/cache.go b/internal/cache/cache.go
new file mode 100644
index 00000000..b149c2a5
--- /dev/null
+++ b/internal/cache/cache.go
@@ -0,0 +1,117 @@
+package cache
+
+import (
+	"sync"
+	"time"
+
+	"github.com/sirupsen/logrus"
+)
+
+var log = logrus.WithField("package", "cache")
"cache") + +// Cache is a simple in-memory cache with expiration. +type Cache struct { + items sync.Map + close chan struct{} +} + +type item struct { + data interface{} + expires int64 +} + +// NewCache creates a new cache with a cleaning interval. +func NewCache(cleaningInterval time.Duration) *Cache { + cache := &Cache{ + close: make(chan struct{}), + } + + go func() { + ticker := time.NewTicker(cleaningInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + now := time.Now().UnixNano() + + cache.items.Range(func(key, value interface{}) bool { + item := value.(item) + + if item.expires > 0 && now > item.expires { + cache.items.Delete(key) + } + + return true + }) + + case <-cache.close: + return + } + } + }() + + return cache +} + +// Get returns a value from the cache. +func (cache *Cache) Get(key interface{}) (interface{}, bool) { + obj, exists := cache.items.Load(key) + + if !exists { + log.WithField("key", key).Debug("cache miss") + return nil, false + } + + item := obj.(item) + + if item.expires > 0 && time.Now().UnixNano() > item.expires { + log.WithField("key", key).Debug("cache expired") + return nil, false + } + + log.WithField("key", key).Debug("cache hit") + return item.data, true +} + +// Set stores a value in the cache. +func (cache *Cache) Set(key interface{}, value interface{}, duration time.Duration) { + var expires int64 + + if duration > 0 { + expires = time.Now().Add(duration).UnixNano() + } + + cache.items.Store(key, item{ + data: value, + expires: expires, + }) +} + +// Range loops over all items in the cache. +func (cache *Cache) Range(f func(key, value interface{}) bool) { + now := time.Now().UnixNano() + + fn := func(key, value interface{}) bool { + item := value.(item) + + if item.expires > 0 && now > item.expires { + return true + } + + return f(key, item.data) + } + + cache.items.Range(fn) +} + +// Delete removes a value from the cache. +func (cache *Cache) Delete(key interface{}) { + cache.items.Delete(key) +} + +// Close stops the cleaning interval and clears the cache. 
+func (cache *Cache) Close() {
+	cache.close <- struct{}{}
+	cache.items = sync.Map{}
+}
diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go
new file mode 100644
index 00000000..1d758828
--- /dev/null
+++ b/internal/cache/cache_test.go
@@ -0,0 +1,128 @@
+package cache
+
+import (
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestCache(t *testing.T) {
+	interval := 100 * time.Microsecond
+	c := NewCache(interval)
+	defer c.Close()
+
+	c.Set("foo", "bar", 0)
+	c.Set("baz", "qux", interval/2)
+
+	baz, found := c.Get("baz")
+	assert.True(t, found)
+	assert.Equal(t, "qux", baz)
+
+	time.Sleep(interval / 2)
+
+	_, found = c.Get("baz")
+	assert.False(t, found)
+
+	time.Sleep(interval)
+
+	_, found = c.Get("404")
+	assert.False(t, found)
+
+	foo, found := c.Get("foo")
+	assert.True(t, found)
+	assert.Equal(t, "bar", foo)
+}
+
+func TestDelete(t *testing.T) {
+	c := NewCache(time.Minute)
+	c.Set("foo", "bar", time.Hour)
+
+	_, found := c.Get("foo")
+	assert.True(t, found)
+
+	c.Delete("foo")
+
+	_, found = c.Get("foo")
+	assert.False(t, found)
+}
+
+func TestRange(t *testing.T) {
+	c := NewCache(time.Minute)
+	c.Set("foo", "bar", time.Hour)
+	c.Set("baz", "qux", time.Hour)
+
+	count := 0
+	c.Range(func(key, value interface{}) bool {
+		count++
+		return true
+	})
+	assert.Equal(t, 2, count)
+}
+
+func TestRangeTimer(t *testing.T) {
+	c := NewCache(time.Minute)
+	c.Set("foo", "bar", time.Nanosecond)
+	c.Set("baz", "qux", time.Nanosecond)
+
+	time.Sleep(time.Microsecond)
+
+	c.Range(func(key, value interface{}) bool {
+		assert.Fail(t, "should not be called")
+		return true
+	})
+}
+
+func BenchmarkNew(b *testing.B) {
+	b.ReportAllocs()
+
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			NewCache(5 * time.Second).Close()
+		}
+	})
+}
+
+func BenchmarkGet(b *testing.B) {
+	c := NewCache(5 * time.Second)
+	defer c.Close()
+
+	c.Set("foo", "bar", 0)
+
+	b.ReportAllocs()
+	b.ResetTimer()
+
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			c.Get("foo")
+		}
+	})
+}
+
+func BenchmarkSet(b *testing.B) {
+	c := NewCache(5 * time.Second)
+	defer c.Close()
+
+	b.ReportAllocs()
+	b.ResetTimer()
+
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			c.Set("foo", "bar", 0)
+		}
+	})
+}
+
+func BenchmarkDelete(b *testing.B) {
+	c := NewCache(5 * time.Second)
+	defer c.Close()
+
+	b.ReportAllocs()
+	b.ResetTimer()
+
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			c.Delete("foo")
+		}
+	})
+}
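
Usage sketch (not part of the diff): the snippet below shows how the new pieces fit together, mirroring the wiring in api.go. cache.NewCache starts a cleanup goroutine with the given eviction interval, and response_cache.CachePage wraps a gin handler so its response is served from memory for the given TTL. The route, port and handler body here are illustrative only.

package main

import (
	"net/http"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/systemli/ticker/internal/api/middleware/response_cache"
	"github.com/systemli/ticker/internal/cache"
)

func main() {
	// Expired entries are evicted every 5 minutes; each cached response lives for 30 seconds.
	inMemoryCache := cache.NewCache(5 * time.Minute)
	defer inMemoryCache.Close()

	r := gin.New()
	// CachePage builds the cache key from the request domain, handler name and query string,
	// so repeated requests within the TTL are answered without invoking the handler again.
	r.GET("/ping", response_cache.CachePage(inMemoryCache, 30*time.Second, func(c *gin.Context) {
		c.String(http.StatusOK, "pong")
	}))

	_ = r.Run(":8080")
}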