From 2564dcee7ca8f1770640e1d4f73424837d01acc3 Mon Sep 17 00:00:00 2001 From: louis Date: Fri, 20 Oct 2023 20:04:10 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=97=83=EF=B8=8F=20Make=20Queries=20config?= =?UTF-8?q?urable=20to=20preload=20associations?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/api/api.go | 52 ++-- internal/api/init.go | 3 +- internal/api/init_test.go | 4 +- internal/api/messages.go | 2 +- internal/api/messages_test.go | 6 +- internal/api/middleware/message/message.go | 2 +- .../api/middleware/message/message_test.go | 4 +- internal/api/middleware/ticker/ticker.go | 9 +- internal/api/tickers_test.go | 6 +- internal/storage/mock_Storage.go | 175 +++++++---- internal/storage/sql_storage.go | 65 +++- internal/storage/sql_storage_test.go | 278 +++++++++++++----- internal/storage/storage.go | 15 +- 13 files changed, 418 insertions(+), 203 deletions(-) diff --git a/internal/api/api.go b/internal/api/api.go index d0f9444b..5bc7b72b 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -44,11 +44,11 @@ type handler struct { // @host localhost:8080 // @BasePath /v1 -func API(config config.Config, storage storage.Storage, log *logrus.Logger) *gin.Engine { +func API(config config.Config, store storage.Storage, log *logrus.Logger) *gin.Engine { handler := handler{ config: config, - storage: storage, - bridges: bridge.RegisterBridges(config, storage), + storage: store, + bridges: bridge.RegisterBridges(config, store), } // TOOD: Make this configurable via config file @@ -65,11 +65,11 @@ func API(config config.Config, storage storage.Storage, log *logrus.Logger) *gin r.Use(limits.RequestSizeLimiter(1024 * 1024 * 10)) // the jwt middleware - authMiddleware := auth.AuthMiddleware(storage, config.Secret) + authMiddleware := auth.AuthMiddleware(store, config.Secret) admin := r.Group("/v1/admin") { - meMiddleware := me.MeMiddleware(storage) + meMiddleware := me.MeMiddleware(store) admin.Use(authMiddleware.MiddlewareFunc()) admin.Use(meMiddleware) @@ -78,32 +78,32 @@ func API(config config.Config, storage storage.Storage, log *logrus.Logger) *gin admin.GET("/features", handler.GetFeatures) admin.GET(`/tickers`, handler.GetTickers) - admin.GET(`/tickers/:tickerID`, ticker.PrefetchTicker(storage), handler.GetTicker) + admin.GET(`/tickers/:tickerID`, ticker.PrefetchTicker(store, storage.WithPreload()), handler.GetTicker) admin.POST(`/tickers`, user.NeedAdmin(), handler.PostTicker) - admin.PUT(`/tickers/:tickerID`, ticker.PrefetchTicker(storage), handler.PutTicker) - admin.PUT(`/tickers/:tickerID/telegram`, ticker.PrefetchTicker(storage), handler.PutTickerTelegram) - admin.DELETE(`/tickers/:tickerID/telegram`, ticker.PrefetchTicker(storage), handler.DeleteTickerTelegram) - admin.PUT(`/tickers/:tickerID/mastodon`, ticker.PrefetchTicker(storage), handler.PutTickerMastodon) - admin.DELETE(`/tickers/:tickerID/mastodon`, ticker.PrefetchTicker(storage), handler.DeleteTickerMastodon) - admin.DELETE(`/tickers/:tickerID`, user.NeedAdmin(), ticker.PrefetchTicker(storage), handler.DeleteTicker) - admin.PUT(`/tickers/:tickerID/reset`, ticker.PrefetchTicker(storage), ticker.PrefetchTicker(storage), handler.ResetTicker) - admin.GET(`/tickers/:tickerID/users`, ticker.PrefetchTicker(storage), handler.GetTickerUsers) - admin.PUT(`/tickers/:tickerID/users`, user.NeedAdmin(), ticker.PrefetchTicker(storage), handler.PutTickerUsers) - admin.DELETE(`/tickers/:tickerID/users/:userID`, user.NeedAdmin(), ticker.PrefetchTicker(storage), handler.DeleteTickerUser) - - admin.GET(`/tickers/:tickerID/messages`, ticker.PrefetchTicker(storage), handler.GetMessages) - admin.GET(`/tickers/:tickerID/messages/:messageID`, ticker.PrefetchTicker(storage), message.PrefetchMessage(storage), handler.GetMessage) - admin.POST(`/tickers/:tickerID/messages`, ticker.PrefetchTicker(storage), handler.PostMessage) - admin.DELETE(`/tickers/:tickerID/messages/:messageID`, ticker.PrefetchTicker(storage), message.PrefetchMessage(storage), handler.DeleteMessage) + admin.PUT(`/tickers/:tickerID`, ticker.PrefetchTicker(store, storage.WithPreload()), handler.PutTicker) + admin.PUT(`/tickers/:tickerID/telegram`, ticker.PrefetchTicker(store, storage.WithPreload()), handler.PutTickerTelegram) + admin.DELETE(`/tickers/:tickerID/telegram`, ticker.PrefetchTicker(store, storage.WithPreload()), handler.DeleteTickerTelegram) + admin.PUT(`/tickers/:tickerID/mastodon`, ticker.PrefetchTicker(store, storage.WithPreload()), handler.PutTickerMastodon) + admin.DELETE(`/tickers/:tickerID/mastodon`, ticker.PrefetchTicker(store, storage.WithPreload()), handler.DeleteTickerMastodon) + admin.DELETE(`/tickers/:tickerID`, user.NeedAdmin(), ticker.PrefetchTicker(store), handler.DeleteTicker) + admin.PUT(`/tickers/:tickerID/reset`, ticker.PrefetchTicker(store, storage.WithPreload()), ticker.PrefetchTicker(store), handler.ResetTicker) + admin.GET(`/tickers/:tickerID/users`, ticker.PrefetchTicker(store), handler.GetTickerUsers) + admin.PUT(`/tickers/:tickerID/users`, user.NeedAdmin(), ticker.PrefetchTicker(store), handler.PutTickerUsers) + admin.DELETE(`/tickers/:tickerID/users/:userID`, user.NeedAdmin(), ticker.PrefetchTicker(store), handler.DeleteTickerUser) + + admin.GET(`/tickers/:tickerID/messages`, ticker.PrefetchTicker(store, storage.WithPreload()), handler.GetMessages) + admin.GET(`/tickers/:tickerID/messages/:messageID`, ticker.PrefetchTicker(store, storage.WithPreload()), message.PrefetchMessage(store), handler.GetMessage) + admin.POST(`/tickers/:tickerID/messages`, ticker.PrefetchTicker(store), handler.PostMessage) + admin.DELETE(`/tickers/:tickerID/messages/:messageID`, ticker.PrefetchTicker(store), message.PrefetchMessage(store), handler.DeleteMessage) admin.POST(`/upload`, handler.PostUpload) admin.GET(`/users`, user.NeedAdmin(), handler.GetUsers) - admin.GET(`/users/:userID`, user.PrefetchUser(storage), handler.GetUser) + admin.GET(`/users/:userID`, user.PrefetchUser(store), handler.GetUser) admin.POST(`/users`, user.NeedAdmin(), handler.PostUser) admin.PUT(`/users/me`, handler.PutMe) - admin.PUT(`/users/:userID`, user.NeedAdmin(), user.PrefetchUser(storage), handler.PutUser) - admin.DELETE(`/users/:userID`, user.NeedAdmin(), user.PrefetchUser(storage), handler.DeleteUser) + admin.PUT(`/users/:userID`, user.NeedAdmin(), user.PrefetchUser(store), handler.PutUser) + admin.DELETE(`/users/:userID`, user.NeedAdmin(), user.PrefetchUser(store), handler.DeleteUser) admin.GET(`/settings/:name`, user.NeedAdmin(), handler.GetSetting) admin.PUT(`/settings/inactive_settings`, user.NeedAdmin(), handler.PutInactiveSettings) @@ -115,8 +115,8 @@ func API(config config.Config, storage storage.Storage, log *logrus.Logger) *gin public.POST(`/admin/login`, authMiddleware.LoginHandler) public.GET(`/init`, response_cache.CachePage(inMemoryCache, cacheTtl, handler.GetInit)) - public.GET(`/timeline`, ticker.PrefetchTickerFromRequest(storage), response_cache.CachePage(inMemoryCache, cacheTtl, handler.GetTimeline)) - public.GET(`/feed`, ticker.PrefetchTickerFromRequest(storage), response_cache.CachePage(inMemoryCache, cacheTtl, handler.GetFeed)) + public.GET(`/timeline`, ticker.PrefetchTickerFromRequest(store), response_cache.CachePage(inMemoryCache, cacheTtl, handler.GetTimeline)) + public.GET(`/feed`, ticker.PrefetchTickerFromRequest(store), response_cache.CachePage(inMemoryCache, cacheTtl, handler.GetFeed)) } r.GET(`/media/:fileName`, handler.GetMedia) diff --git a/internal/api/init.go b/internal/api/init.go index d52b13b5..8eb7364f 100644 --- a/internal/api/init.go +++ b/internal/api/init.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "github.com/systemli/ticker/internal/api/helper" "github.com/systemli/ticker/internal/api/response" + "github.com/systemli/ticker/internal/storage" ) // GetInit returns the basic settings for the ticker. @@ -30,7 +31,7 @@ func (h *handler) GetInit(c *gin.Context) { return } - ticker, err := h.storage.FindTickerByDomain(domain) + ticker, err := h.storage.FindTickerByDomain(domain, storage.WithInformation()) if err != nil || !ticker.Active { settings.InactiveSettings = h.storage.GetInactiveSettings() c.JSON(http.StatusOK, response.SuccessResponse(map[string]interface{}{"ticker": nil, "settings": settings})) diff --git a/internal/api/init_test.go b/internal/api/init_test.go index 30cf5282..23e56add 100644 --- a/internal/api/init_test.go +++ b/internal/api/init_test.go @@ -21,7 +21,7 @@ func TestGetInit(t *testing.T) { ticker.Active = true s := &storage.MockStorage{} s.On("GetRefreshIntervalSettings").Return(storage.DefaultRefreshIntervalSettings()) - s.On("FindTickerByDomain", mock.AnythingOfType("string")).Return(ticker, nil) + s.On("FindTickerByDomain", mock.AnythingOfType("string"), mock.Anything).Return(ticker, nil) h := handler{ storage: s, @@ -60,7 +60,7 @@ func TestGetInitInactiveTicker(t *testing.T) { s := &storage.MockStorage{} s.On("GetRefreshIntervalSettings").Return(storage.DefaultRefreshIntervalSettings()) s.On("GetInactiveSettings").Return(storage.DefaultInactiveSettings()) - s.On("FindTickerByDomain", mock.AnythingOfType("string")).Return(ticker, nil) + s.On("FindTickerByDomain", mock.AnythingOfType("string"), mock.Anything).Return(ticker, nil) h := handler{ storage: s, diff --git a/internal/api/messages.go b/internal/api/messages.go index 7a6d007e..371d1ac3 100644 --- a/internal/api/messages.go +++ b/internal/api/messages.go @@ -19,7 +19,7 @@ func (h *handler) GetMessages(c *gin.Context) { } pagination := pagination.NewPagination(c) - messages, err := h.storage.FindMessagesByTickerAndPagination(ticker, *pagination) + messages, err := h.storage.FindMessagesByTickerAndPagination(ticker, *pagination, storage.WithAttachments()) if err != nil { c.JSON(http.StatusNotFound, response.ErrorResponse(response.CodeDefault, response.StorageError)) return diff --git a/internal/api/messages_test.go b/internal/api/messages_test.go index 0646057b..a15e760e 100644 --- a/internal/api/messages_test.go +++ b/internal/api/messages_test.go @@ -39,7 +39,7 @@ func TestGetMessagesStorageError(t *testing.T) { c, _ := gin.CreateTestContext(w) c.Set("ticker", storage.Ticker{}) s := &storage.MockStorage{} - s.On("FindMessagesByTickerAndPagination", mock.Anything, mock.Anything).Return([]storage.Message{}, errors.New("storage error")) + s.On("FindMessagesByTickerAndPagination", mock.Anything, mock.Anything, mock.Anything).Return([]storage.Message{}, errors.New("storage error")) h := handler{ storage: s, config: config.NewConfig(), @@ -55,7 +55,7 @@ func TestGetMessagesEmptyResult(t *testing.T) { c, _ := gin.CreateTestContext(w) c.Set("ticker", storage.Ticker{}) s := &storage.MockStorage{} - s.On("FindMessagesByTickerAndPagination", mock.Anything, mock.Anything).Return([]storage.Message{}, errors.New("not found")) + s.On("FindMessagesByTickerAndPagination", mock.Anything, mock.Anything, mock.Anything).Return([]storage.Message{}, errors.New("not found")) h := handler{ storage: s, config: config.NewConfig(), @@ -71,7 +71,7 @@ func TestGetMessages(t *testing.T) { c, _ := gin.CreateTestContext(w) c.Set("ticker", storage.Ticker{}) s := &storage.MockStorage{} - s.On("FindMessagesByTickerAndPagination", mock.Anything, mock.Anything).Return([]storage.Message{}, nil) + s.On("FindMessagesByTickerAndPagination", mock.Anything, mock.Anything, mock.Anything).Return([]storage.Message{}, nil) h := handler{ storage: s, config: config.NewConfig(), diff --git a/internal/api/middleware/message/message.go b/internal/api/middleware/message/message.go index 32bf6acd..8aeef79b 100644 --- a/internal/api/middleware/message/message.go +++ b/internal/api/middleware/message/message.go @@ -20,7 +20,7 @@ func PrefetchMessage(s storage.Storage) gin.HandlerFunc { return } - message, err := s.FindMessage(ticker.ID, messageID) + message, err := s.FindMessage(ticker.ID, messageID, storage.WithAttachments()) if err != nil { c.JSON(http.StatusNotFound, response.ErrorResponse(response.CodeNotFound, response.MessageNotFound)) return diff --git a/internal/api/middleware/message/message_test.go b/internal/api/middleware/message/message_test.go index f0f48464..320459ea 100644 --- a/internal/api/middleware/message/message_test.go +++ b/internal/api/middleware/message/message_test.go @@ -34,7 +34,7 @@ func TestPrefetchMessageStorageError(t *testing.T) { c.AddParam("messageID", "1") c.Set("ticker", storage.Ticker{}) s := &storage.MockStorage{} - s.On("FindMessage", mock.Anything, mock.Anything).Return(storage.Message{}, errors.New("storage error")) + s.On("FindMessage", mock.Anything, mock.Anything, mock.Anything).Return(storage.Message{}, errors.New("storage error")) mw := PrefetchMessage(s) mw(c) @@ -49,7 +49,7 @@ func TestPrefetchMessage(t *testing.T) { c.Set("ticker", storage.Ticker{}) s := &storage.MockStorage{} message := storage.Message{ID: 1} - s.On("FindMessage", mock.Anything, mock.Anything).Return(message, nil) + s.On("FindMessage", mock.Anything, mock.Anything, mock.Anything).Return(message, nil) mw := PrefetchMessage(s) mw(c) diff --git a/internal/api/middleware/ticker/ticker.go b/internal/api/middleware/ticker/ticker.go index a087f70e..3bd04aa6 100644 --- a/internal/api/middleware/ticker/ticker.go +++ b/internal/api/middleware/ticker/ticker.go @@ -9,9 +9,10 @@ import ( "github.com/systemli/ticker/internal/api/response" "github.com/systemli/ticker/internal/storage" "github.com/systemli/ticker/internal/util" + "gorm.io/gorm" ) -func PrefetchTicker(s storage.Storage) gin.HandlerFunc { +func PrefetchTicker(s storage.Storage, opts ...func(*gorm.DB) *gorm.DB) gin.HandlerFunc { return func(c *gin.Context) { user, _ := helper.Me(c) tickerID, err := strconv.Atoi(c.Param("tickerID")) @@ -31,7 +32,7 @@ func PrefetchTicker(s storage.Storage) gin.HandlerFunc { } } - ticker, err := s.FindTickerByID(tickerID) + ticker, err := s.FindTickerByID(tickerID, opts...) if err != nil { c.JSON(http.StatusNotFound, response.ErrorResponse(response.CodeNotFound, response.TickerNotFound)) @@ -42,7 +43,7 @@ func PrefetchTicker(s storage.Storage) gin.HandlerFunc { } } -func PrefetchTickerFromRequest(s storage.Storage) gin.HandlerFunc { +func PrefetchTickerFromRequest(s storage.Storage, opts ...func(*gorm.DB) *gorm.DB) gin.HandlerFunc { return func(c *gin.Context) { domain, err := helper.GetDomain(c) if err != nil { @@ -50,7 +51,7 @@ func PrefetchTickerFromRequest(s storage.Storage) gin.HandlerFunc { return } - ticker, err := s.FindTickerByDomain(domain) + ticker, err := s.FindTickerByDomain(domain, opts...) if err != nil { c.JSON(http.StatusOK, response.ErrorResponse(response.CodeDefault, response.TickerNotFound)) return diff --git a/internal/api/tickers_test.go b/internal/api/tickers_test.go index 3b4b025d..91f97824 100644 --- a/internal/api/tickers_test.go +++ b/internal/api/tickers_test.go @@ -40,7 +40,7 @@ func TestGetTickersStorageError(t *testing.T) { c, _ := gin.CreateTestContext(w) c.Set("me", storage.User{IsSuperAdmin: true}) s := &storage.MockStorage{} - s.On("FindTickers").Return([]storage.Ticker{}, errors.New("storage error")) + s.On("FindTickers", mock.Anything).Return([]storage.Ticker{}, errors.New("storage error")) h := handler{ storage: s, config: config.NewConfig(), @@ -56,7 +56,7 @@ func TestGetTickers(t *testing.T) { c, _ := gin.CreateTestContext(w) c.Set("me", storage.User{IsSuperAdmin: false, Tickers: []storage.Ticker{{ID: 2}}}) s := &storage.MockStorage{} - s.On("FindTickersByIDs", mock.Anything).Return([]storage.Ticker{}, nil) + s.On("FindTickersByIDs", mock.Anything, mock.Anything).Return([]storage.Ticker{}, nil) h := handler{ storage: s, config: config.NewConfig(), @@ -115,7 +115,7 @@ func TestGetTickerUsers(t *testing.T) { c, _ := gin.CreateTestContext(w) c.Set("ticker", storage.Ticker{}) s := &storage.MockStorage{} - s.On("FindUsersByTicker", mock.Anything).Return([]storage.User{}, nil) + s.On("FindUsersByTicker", mock.Anything, mock.Anything).Return([]storage.User{}, nil) h := handler{ storage: s, config: config.NewConfig(), diff --git a/internal/storage/mock_Storage.go b/internal/storage/mock_Storage.go index aaf64356..34eccccb 100644 --- a/internal/storage/mock_Storage.go +++ b/internal/storage/mock_Storage.go @@ -5,6 +5,7 @@ package storage import ( mock "github.com/stretchr/testify/mock" pagination "github.com/systemli/ticker/internal/api/pagination" + gorm "gorm.io/gorm" ) // MockStorage is an autogenerated mock type for the Storage type @@ -167,23 +168,30 @@ func (_m *MockStorage) DeleteUser(user User) error { return r0 } -// FindMessage provides a mock function with given fields: tickerID, messageID -func (_m *MockStorage) FindMessage(tickerID int, messageID int) (Message, error) { - ret := _m.Called(tickerID, messageID) +// FindMessage provides a mock function with given fields: tickerID, messageID, opts +func (_m *MockStorage) FindMessage(tickerID int, messageID int, opts ...func(*gorm.DB) *gorm.DB) (Message, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, tickerID, messageID) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 Message var r1 error - if rf, ok := ret.Get(0).(func(int, int) (Message, error)); ok { - return rf(tickerID, messageID) + if rf, ok := ret.Get(0).(func(int, int, ...func(*gorm.DB) *gorm.DB) (Message, error)); ok { + return rf(tickerID, messageID, opts...) } - if rf, ok := ret.Get(0).(func(int, int) Message); ok { - r0 = rf(tickerID, messageID) + if rf, ok := ret.Get(0).(func(int, int, ...func(*gorm.DB) *gorm.DB) Message); ok { + r0 = rf(tickerID, messageID, opts...) } else { r0 = ret.Get(0).(Message) } - if rf, ok := ret.Get(1).(func(int, int) error); ok { - r1 = rf(tickerID, messageID) + if rf, ok := ret.Get(1).(func(int, int, ...func(*gorm.DB) *gorm.DB) error); ok { + r1 = rf(tickerID, messageID, opts...) } else { r1 = ret.Error(1) } @@ -191,25 +199,32 @@ func (_m *MockStorage) FindMessage(tickerID int, messageID int) (Message, error) return r0, r1 } -// FindMessagesByTicker provides a mock function with given fields: ticker -func (_m *MockStorage) FindMessagesByTicker(ticker Ticker) ([]Message, error) { - ret := _m.Called(ticker) +// FindMessagesByTicker provides a mock function with given fields: ticker, opts +func (_m *MockStorage) FindMessagesByTicker(ticker Ticker, opts ...func(*gorm.DB) *gorm.DB) ([]Message, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ticker) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 []Message var r1 error - if rf, ok := ret.Get(0).(func(Ticker) ([]Message, error)); ok { - return rf(ticker) + if rf, ok := ret.Get(0).(func(Ticker, ...func(*gorm.DB) *gorm.DB) ([]Message, error)); ok { + return rf(ticker, opts...) } - if rf, ok := ret.Get(0).(func(Ticker) []Message); ok { - r0 = rf(ticker) + if rf, ok := ret.Get(0).(func(Ticker, ...func(*gorm.DB) *gorm.DB) []Message); ok { + r0 = rf(ticker, opts...) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]Message) } } - if rf, ok := ret.Get(1).(func(Ticker) error); ok { - r1 = rf(ticker) + if rf, ok := ret.Get(1).(func(Ticker, ...func(*gorm.DB) *gorm.DB) error); ok { + r1 = rf(ticker, opts...) } else { r1 = ret.Error(1) } @@ -217,25 +232,32 @@ func (_m *MockStorage) FindMessagesByTicker(ticker Ticker) ([]Message, error) { return r0, r1 } -// FindMessagesByTickerAndPagination provides a mock function with given fields: ticker, _a1 -func (_m *MockStorage) FindMessagesByTickerAndPagination(ticker Ticker, _a1 pagination.Pagination) ([]Message, error) { - ret := _m.Called(ticker, _a1) +// FindMessagesByTickerAndPagination provides a mock function with given fields: ticker, _a1, opts +func (_m *MockStorage) FindMessagesByTickerAndPagination(ticker Ticker, _a1 pagination.Pagination, opts ...func(*gorm.DB) *gorm.DB) ([]Message, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ticker, _a1) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 []Message var r1 error - if rf, ok := ret.Get(0).(func(Ticker, pagination.Pagination) ([]Message, error)); ok { - return rf(ticker, _a1) + if rf, ok := ret.Get(0).(func(Ticker, pagination.Pagination, ...func(*gorm.DB) *gorm.DB) ([]Message, error)); ok { + return rf(ticker, _a1, opts...) } - if rf, ok := ret.Get(0).(func(Ticker, pagination.Pagination) []Message); ok { - r0 = rf(ticker, _a1) + if rf, ok := ret.Get(0).(func(Ticker, pagination.Pagination, ...func(*gorm.DB) *gorm.DB) []Message); ok { + r0 = rf(ticker, _a1, opts...) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]Message) } } - if rf, ok := ret.Get(1).(func(Ticker, pagination.Pagination) error); ok { - r1 = rf(ticker, _a1) + if rf, ok := ret.Get(1).(func(Ticker, pagination.Pagination, ...func(*gorm.DB) *gorm.DB) error); ok { + r1 = rf(ticker, _a1, opts...) } else { r1 = ret.Error(1) } @@ -243,23 +265,30 @@ func (_m *MockStorage) FindMessagesByTickerAndPagination(ticker Ticker, _a1 pagi return r0, r1 } -// FindTickerByDomain provides a mock function with given fields: domain -func (_m *MockStorage) FindTickerByDomain(domain string) (Ticker, error) { - ret := _m.Called(domain) +// FindTickerByDomain provides a mock function with given fields: domain, opts +func (_m *MockStorage) FindTickerByDomain(domain string, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, domain) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 Ticker var r1 error - if rf, ok := ret.Get(0).(func(string) (Ticker, error)); ok { - return rf(domain) + if rf, ok := ret.Get(0).(func(string, ...func(*gorm.DB) *gorm.DB) (Ticker, error)); ok { + return rf(domain, opts...) } - if rf, ok := ret.Get(0).(func(string) Ticker); ok { - r0 = rf(domain) + if rf, ok := ret.Get(0).(func(string, ...func(*gorm.DB) *gorm.DB) Ticker); ok { + r0 = rf(domain, opts...) } else { r0 = ret.Get(0).(Ticker) } - if rf, ok := ret.Get(1).(func(string) error); ok { - r1 = rf(domain) + if rf, ok := ret.Get(1).(func(string, ...func(*gorm.DB) *gorm.DB) error); ok { + r1 = rf(domain, opts...) } else { r1 = ret.Error(1) } @@ -267,23 +296,30 @@ func (_m *MockStorage) FindTickerByDomain(domain string) (Ticker, error) { return r0, r1 } -// FindTickerByID provides a mock function with given fields: id -func (_m *MockStorage) FindTickerByID(id int) (Ticker, error) { - ret := _m.Called(id) +// FindTickerByID provides a mock function with given fields: id, opts +func (_m *MockStorage) FindTickerByID(id int, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, id) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 Ticker var r1 error - if rf, ok := ret.Get(0).(func(int) (Ticker, error)); ok { - return rf(id) + if rf, ok := ret.Get(0).(func(int, ...func(*gorm.DB) *gorm.DB) (Ticker, error)); ok { + return rf(id, opts...) } - if rf, ok := ret.Get(0).(func(int) Ticker); ok { - r0 = rf(id) + if rf, ok := ret.Get(0).(func(int, ...func(*gorm.DB) *gorm.DB) Ticker); ok { + r0 = rf(id, opts...) } else { r0 = ret.Get(0).(Ticker) } - if rf, ok := ret.Get(1).(func(int) error); ok { - r1 = rf(id) + if rf, ok := ret.Get(1).(func(int, ...func(*gorm.DB) *gorm.DB) error); ok { + r1 = rf(id, opts...) } else { r1 = ret.Error(1) } @@ -291,25 +327,31 @@ func (_m *MockStorage) FindTickerByID(id int) (Ticker, error) { return r0, r1 } -// FindTickers provides a mock function with given fields: -func (_m *MockStorage) FindTickers() ([]Ticker, error) { - ret := _m.Called() +// FindTickers provides a mock function with given fields: opts +func (_m *MockStorage) FindTickers(opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 []Ticker var r1 error - if rf, ok := ret.Get(0).(func() ([]Ticker, error)); ok { - return rf() + if rf, ok := ret.Get(0).(func(...func(*gorm.DB) *gorm.DB) ([]Ticker, error)); ok { + return rf(opts...) } - if rf, ok := ret.Get(0).(func() []Ticker); ok { - r0 = rf() + if rf, ok := ret.Get(0).(func(...func(*gorm.DB) *gorm.DB) []Ticker); ok { + r0 = rf(opts...) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]Ticker) } } - if rf, ok := ret.Get(1).(func() error); ok { - r1 = rf() + if rf, ok := ret.Get(1).(func(...func(*gorm.DB) *gorm.DB) error); ok { + r1 = rf(opts...) } else { r1 = ret.Error(1) } @@ -317,25 +359,32 @@ func (_m *MockStorage) FindTickers() ([]Ticker, error) { return r0, r1 } -// FindTickersByIDs provides a mock function with given fields: ids -func (_m *MockStorage) FindTickersByIDs(ids []int) ([]Ticker, error) { - ret := _m.Called(ids) +// FindTickersByIDs provides a mock function with given fields: ids, opts +func (_m *MockStorage) FindTickersByIDs(ids []int, opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, ids) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) var r0 []Ticker var r1 error - if rf, ok := ret.Get(0).(func([]int) ([]Ticker, error)); ok { - return rf(ids) + if rf, ok := ret.Get(0).(func([]int, ...func(*gorm.DB) *gorm.DB) ([]Ticker, error)); ok { + return rf(ids, opts...) } - if rf, ok := ret.Get(0).(func([]int) []Ticker); ok { - r0 = rf(ids) + if rf, ok := ret.Get(0).(func([]int, ...func(*gorm.DB) *gorm.DB) []Ticker); ok { + r0 = rf(ids, opts...) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]Ticker) } } - if rf, ok := ret.Get(1).(func([]int) error); ok { - r1 = rf(ids) + if rf, ok := ret.Get(1).(func([]int, ...func(*gorm.DB) *gorm.DB) error); ok { + r1 = rf(ids, opts...) } else { r1 = ret.Error(1) } diff --git a/internal/storage/sql_storage.go b/internal/storage/sql_storage.go index 9191d26e..02660aa3 100644 --- a/internal/storage/sql_storage.go +++ b/internal/storage/sql_storage.go @@ -91,32 +91,36 @@ func (s *SqlStorage) AddTickerUser(ticker *Ticker, user *User) error { return err } -func (s *SqlStorage) FindTickers() ([]Ticker, error) { +func (s *SqlStorage) FindTickers(opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) { tickers := make([]Ticker, 0) - err := s.db.Preload(clause.Associations).Find(&tickers).Error + db := s.prepareDb(opts...) + err := db.Find(&tickers).Error return tickers, err } -func (s *SqlStorage) FindTickersByIDs(ids []int) ([]Ticker, error) { +func (s *SqlStorage) FindTickersByIDs(ids []int, opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) { tickers := make([]Ticker, 0) - err := s.db.Preload(clause.Associations).Find(&tickers, ids).Error + db := s.prepareDb(opts...) + err := db.Find(&tickers, ids).Error return tickers, err } -func (s *SqlStorage) FindTickerByDomain(domain string) (Ticker, error) { +func (s *SqlStorage) FindTickerByDomain(domain string, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error) { var ticker Ticker + db := s.prepareDb(opts...) - err := s.db.Preload(clause.Associations).First(&ticker, "domain = ?", domain).Error + err := db.First(&ticker, "domain = ?", domain).Error return ticker, err } -func (s *SqlStorage) FindTickerByID(id int) (Ticker, error) { +func (s *SqlStorage) FindTickerByID(id int, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error) { var ticker Ticker + db := s.prepareDb(opts...) - err := s.db.Preload(clause.Associations).First(&ticker, id).Error + err := db.First(&ticker, id).Error return ticker, err } @@ -191,24 +195,28 @@ func (s *SqlStorage) DeleteUploadsByTicker(ticker Ticker) error { return nil } -func (s *SqlStorage) FindMessage(tickerID, messageID int) (Message, error) { +func (s *SqlStorage) FindMessage(tickerID, messageID int, opts ...func(*gorm.DB) *gorm.DB) (Message, error) { var message Message + db := s.prepareDb(opts...) - err := s.db.Preload(clause.Associations).First(&message, "ticker_id = ? AND id = ?", tickerID, messageID).Error + err := db.First(&message, "ticker_id = ? AND id = ?", tickerID, messageID).Error return message, err } -func (s *SqlStorage) FindMessagesByTicker(ticker Ticker) ([]Message, error) { +func (s *SqlStorage) FindMessagesByTicker(ticker Ticker, opts ...func(*gorm.DB) *gorm.DB) ([]Message, error) { messages := make([]Message, 0) - err := s.db.Preload(clause.Associations).Model(&Message{}).Where("ticker_id = ?", ticker.ID).Find(&messages).Error + db := s.prepareDb(opts...) + + err := db.Model(&Message{}).Where("ticker_id = ?", ticker.ID).Find(&messages).Error return messages, err } -func (s *SqlStorage) FindMessagesByTickerAndPagination(ticker Ticker, pagination pagination.Pagination) ([]Message, error) { +func (s *SqlStorage) FindMessagesByTickerAndPagination(ticker Ticker, pagination pagination.Pagination, opts ...func(*gorm.DB) *gorm.DB) ([]Message, error) { messages := make([]Message, 0) - query := s.db.Preload(clause.Associations).Where("ticker_id = ?", ticker.ID) + db := s.prepareDb(opts...) + query := db.Where("ticker_id = ?", ticker.ID) if pagination.GetBefore() > 0 { query = query.Where("id < ?", pagination.GetBefore()) @@ -297,3 +305,32 @@ func (s *SqlStorage) SaveRefreshIntervalSettings(refreshInterval RefreshInterval return s.db.Save(&setting).Error } + +func (s *SqlStorage) prepareDb(opts ...func(*gorm.DB) *gorm.DB) *gorm.DB { + db := s.db + for _, opt := range opts { + db = opt(db) + } + + return db +} + +// WithPreload is a helper function to preload all associations. +func WithPreload() func(*gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + return db.Preload(clause.Associations) + } +} + +// WithAttachments is a helper function to preload the attachments association. +func WithAttachments() func(*gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + return db.Preload("Attachments") + } +} + +func WithInformation() func(*gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + return db.Preload("Information") + } +} diff --git a/internal/storage/sql_storage_test.go b/internal/storage/sql_storage_test.go index be943d04..8b3cd62e 100644 --- a/internal/storage/sql_storage_test.go +++ b/internal/storage/sql_storage_test.go @@ -22,7 +22,7 @@ var _ = Describe("SqlStorage", func() { db, err := gorm.Open(sqlite.Open("file:testdatabase?mode=memory&cache=shared"), &gorm.Config{}) Expect(err).ToNot(HaveOccurred()) - var storage = NewSqlStorage(db, "/uploads") + var store = NewSqlStorage(db, "/uploads") err = db.AutoMigrate( &Ticker{}, @@ -48,24 +48,24 @@ var _ = Describe("SqlStorage", func() { Describe("CountUser", func() { It("returns the number of users", func() { - Expect(storage.CountUser()).To(Equal(0)) + Expect(store.CountUser()).To(Equal(0)) err := db.Create(&User{}).Error Expect(err).ToNot(HaveOccurred()) - Expect(storage.CountUser()).To(Equal(1)) + Expect(store.CountUser()).To(Equal(1)) }) }) Describe("FindUsers", func() { It("returns all users", func() { - users, err := storage.FindUsers() + users, err := store.FindUsers() Expect(err).ToNot(HaveOccurred()) Expect(users).To(BeEmpty()) err = db.Create(&User{}).Error Expect(err).ToNot(HaveOccurred()) - users, err = storage.FindUsers() + users, err = store.FindUsers() Expect(err).ToNot(HaveOccurred()) Expect(users).To(HaveLen(1)) }) @@ -73,14 +73,14 @@ var _ = Describe("SqlStorage", func() { Describe("FindUserByID", func() { It("returns the user with the given id", func() { - user, err := storage.FindUserByID(1) + user, err := store.FindUserByID(1) Expect(err).To(HaveOccurred()) Expect(user).To(BeZero()) err = db.Create(&User{}).Error Expect(err).ToNot(HaveOccurred()) - user, err = storage.FindUserByID(1) + user, err = store.FindUserByID(1) Expect(err).ToNot(HaveOccurred()) Expect(user).ToNot(BeZero()) }) @@ -88,14 +88,14 @@ var _ = Describe("SqlStorage", func() { Describe("FindUsersByIDs", func() { It("returns the users with the given ids", func() { - users, err := storage.FindUsersByIDs([]int{1, 2}) + users, err := store.FindUsersByIDs([]int{1, 2}) Expect(err).ToNot(HaveOccurred()) Expect(users).To(BeEmpty()) err = db.Create(&User{}).Error Expect(err).ToNot(HaveOccurred()) - users, err = storage.FindUsersByIDs([]int{1, 2}) + users, err = store.FindUsersByIDs([]int{1, 2}) Expect(err).ToNot(HaveOccurred()) Expect(users).To(HaveLen(1)) }) @@ -103,14 +103,14 @@ var _ = Describe("SqlStorage", func() { Describe("FindUserByEmail", func() { It("returns the user with the given email", func() { - user, err := storage.FindUserByEmail("user@systemli.org") + user, err := store.FindUserByEmail("user@systemli.org") Expect(err).To(HaveOccurred()) Expect(user).To(BeZero()) err = db.Create(&User{Email: "user@systemli.org"}).Error Expect(err).ToNot(HaveOccurred()) - user, err = storage.FindUserByEmail("user@systemli.org") + user, err = store.FindUserByEmail("user@systemli.org") Expect(err).ToNot(HaveOccurred()) Expect(user).ToNot(BeZero()) }) @@ -119,23 +119,23 @@ var _ = Describe("SqlStorage", func() { Describe("FindUsersByTicker", func() { It("returns the users with the given ticker", func() { ticker := NewTicker() - err := storage.SaveTicker(&ticker) + err := store.SaveTicker(&ticker) Expect(err).ToNot(HaveOccurred()) - users, err := storage.FindUsersByTicker(ticker) + users, err := store.FindUsersByTicker(ticker) Expect(err).ToNot(HaveOccurred()) Expect(users).To(BeEmpty()) user, err := NewUser("user@systemli.org", "password") Expect(err).ToNot(HaveOccurred()) - err = storage.SaveUser(&user) + err = store.SaveUser(&user) Expect(err).ToNot(HaveOccurred()) ticker.Users = append(ticker.Users, user) - err = storage.SaveTicker(&ticker) + err = store.SaveTicker(&ticker) Expect(err).ToNot(HaveOccurred()) - users, err = storage.FindUsersByTicker(ticker) + users, err = store.FindUsersByTicker(ticker) Expect(err).ToNot(HaveOccurred()) Expect(users).To(HaveLen(1)) }) @@ -146,7 +146,7 @@ var _ = Describe("SqlStorage", func() { user, err := NewUser("user@systemli.org", "password") Expect(err).ToNot(HaveOccurred()) - err = storage.SaveUser(&user) + err = store.SaveUser(&user) Expect(err).ToNot(HaveOccurred()) var count int64 @@ -161,7 +161,7 @@ var _ = Describe("SqlStorage", func() { user, err := NewUser("user@systemli.org", "password") Expect(err).ToNot(HaveOccurred()) - err = storage.SaveUser(&user) + err = store.SaveUser(&user) Expect(err).ToNot(HaveOccurred()) var count int64 @@ -169,7 +169,7 @@ var _ = Describe("SqlStorage", func() { Expect(err).ToNot(HaveOccurred()) Expect(count).To(Equal(int64(1))) - err = storage.DeleteUser(user) + err = store.DeleteUser(user) Expect(err).ToNot(HaveOccurred()) err = db.Model(&User{}).Count(&count).Error @@ -181,16 +181,16 @@ var _ = Describe("SqlStorage", func() { Describe("DeleteTickerUsers", func() { It("deletes the users from the ticker", func() { ticker := NewTicker() - err := storage.SaveTicker(&ticker) + err := store.SaveTicker(&ticker) Expect(err).ToNot(HaveOccurred()) user, err := NewUser("user@systemli.org", "password") Expect(err).ToNot(HaveOccurred()) - err = storage.SaveUser(&user) + err = store.SaveUser(&user) Expect(err).ToNot(HaveOccurred()) ticker.Users = append(ticker.Users, user) - err = storage.SaveTicker(&ticker) + err = store.SaveTicker(&ticker) Expect(err).ToNot(HaveOccurred()) var count int64 @@ -198,7 +198,7 @@ var _ = Describe("SqlStorage", func() { Expect(err).ToNot(HaveOccurred()) Expect(count).To(Equal(int64(1))) - err = storage.DeleteTickerUsers(&ticker) + err = store.DeleteTickerUsers(&ticker) Expect(err).ToNot(HaveOccurred()) Expect(ticker.Users).To(BeEmpty()) @@ -210,16 +210,16 @@ var _ = Describe("SqlStorage", func() { Describe("DeleteTickerUser", func() { It("deletes the user from the ticker", func() { ticker := NewTicker() - err := storage.SaveTicker(&ticker) + err := store.SaveTicker(&ticker) Expect(err).ToNot(HaveOccurred()) user, err := NewUser("user@systemli.org", "password") Expect(err).ToNot(HaveOccurred()) - err = storage.SaveUser(&user) + err = store.SaveUser(&user) Expect(err).ToNot(HaveOccurred()) ticker.Users = append(ticker.Users, user) - err = storage.SaveTicker(&ticker) + err = store.SaveTicker(&ticker) Expect(err).ToNot(HaveOccurred()) var count int64 @@ -227,7 +227,7 @@ var _ = Describe("SqlStorage", func() { Expect(err).ToNot(HaveOccurred()) Expect(count).To(Equal(int64(1))) - err = storage.DeleteTickerUser(&ticker, &user) + err = store.DeleteTickerUser(&ticker, &user) Expect(err).ToNot(HaveOccurred()) Expect(ticker.Users).To(BeEmpty()) }) @@ -236,15 +236,15 @@ var _ = Describe("SqlStorage", func() { Describe("AddTickerUser", func() { It("adds the user to the ticker", func() { ticker := NewTicker() - err := storage.SaveTicker(&ticker) + err := store.SaveTicker(&ticker) Expect(err).ToNot(HaveOccurred()) user, err := NewUser("user@systemli.org", "password") Expect(err).ToNot(HaveOccurred()) - err = storage.SaveUser(&user) + err = store.SaveUser(&user) Expect(err).ToNot(HaveOccurred()) - err = storage.AddTickerUser(&ticker, &user) + err = store.AddTickerUser(&ticker, &user) Expect(err).ToNot(HaveOccurred()) Expect(ticker.Users).To(HaveLen(1)) }) @@ -252,69 +252,130 @@ var _ = Describe("SqlStorage", func() { Describe("FindTickers", func() { It("returns all tickers", func() { - tickers, err := storage.FindTickers() + tickers, err := store.FindTickers() Expect(err).ToNot(HaveOccurred()) Expect(tickers).To(BeEmpty()) err = db.Create(&Ticker{}).Error Expect(err).ToNot(HaveOccurred()) - tickers, err = storage.FindTickers() + tickers, err = store.FindTickers() Expect(err).ToNot(HaveOccurred()) Expect(tickers).To(HaveLen(1)) }) + + It("returns all tickers with preload", func() { + err = db.Create(&Ticker{ + Information: TickerInformation{ + Author: "Author", + }, + }).Error + Expect(err).ToNot(HaveOccurred()) + + tickers, err := store.FindTickers(WithPreload()) + Expect(err).ToNot(HaveOccurred()) + Expect(tickers).To(HaveLen(1)) + + Expect(tickers[0].Information.Author).To(Equal("Author")) + }) }) Describe("FindTickersByIDs", func() { It("returns the tickers with the given ids", func() { - tickers, err := storage.FindTickersByIDs([]int{1, 2}) + tickers, err := store.FindTickersByIDs([]int{1, 2}) Expect(err).ToNot(HaveOccurred()) Expect(tickers).To(BeEmpty()) err = db.Create(&Ticker{}).Error Expect(err).ToNot(HaveOccurred()) - tickers, err = storage.FindTickersByIDs([]int{1, 2}) + tickers, err = store.FindTickersByIDs([]int{1, 2}) Expect(err).ToNot(HaveOccurred()) Expect(tickers).To(HaveLen(1)) }) + + It("returns the tickers with the given ids and preload", func() { + err = db.Create(&Ticker{ + Information: TickerInformation{ + Author: "Author", + }, + }).Error + Expect(err).ToNot(HaveOccurred()) + + tickers, err := store.FindTickersByIDs([]int{1, 2}, WithPreload()) + Expect(err).ToNot(HaveOccurred()) + Expect(tickers).To(HaveLen(1)) + + Expect(tickers[0].Information.Author).To(Equal("Author")) + }) }) Describe("FindTickerByID", func() { It("returns the ticker with the given id", func() { - ticker, err := storage.FindTickerByID(1) + ticker, err := store.FindTickerByID(1) Expect(err).To(HaveOccurred()) Expect(ticker).To(BeZero()) err = db.Create(&Ticker{}).Error Expect(err).ToNot(HaveOccurred()) - ticker, err = storage.FindTickerByID(1) + ticker, err = store.FindTickerByID(1) Expect(err).ToNot(HaveOccurred()) Expect(ticker).ToNot(BeZero()) }) + + It("returns the ticker with the given id and preload", func() { + err = db.Create(&Ticker{ + Information: TickerInformation{ + Author: "Author", + }, + }).Error + Expect(err).ToNot(HaveOccurred()) + + ticker, err := store.FindTickerByID(1, WithPreload()) + Expect(err).ToNot(HaveOccurred()) + Expect(ticker).ToNot(BeZero()) + + Expect(ticker.Information.Author).To(Equal("Author")) + }) }) Describe("FindTickerByDomain", func() { It("returns the ticker with the given domain", func() { - ticker, err := storage.FindTickerByDomain("systemli.org") + ticker, err := store.FindTickerByDomain("systemli.org") Expect(err).To(HaveOccurred()) Expect(ticker).To(BeZero()) err = db.Create(&Ticker{Domain: "systemli.org"}).Error Expect(err).ToNot(HaveOccurred()) - ticker, err = storage.FindTickerByDomain("systemli.org") + ticker, err = store.FindTickerByDomain("systemli.org") Expect(err).ToNot(HaveOccurred()) Expect(ticker).ToNot(BeZero()) }) + + It("returns the ticker with the given domain and preload information", func() { + err = db.Create(&Ticker{ + Domain: "systemli.org", + Information: TickerInformation{ + Author: "Author", + }, + }).Error + Expect(err).ToNot(HaveOccurred()) + + ticker, err := store.FindTickerByDomain("systemli.org", WithInformation()) + Expect(err).ToNot(HaveOccurred()) + Expect(ticker).ToNot(BeZero()) + + Expect(ticker.Information.Author).To(Equal("Author")) + }) }) Describe("SaveTicker", func() { It("persists the ticker", func() { ticker := NewTicker() - err = storage.SaveTicker(&ticker) + err = store.SaveTicker(&ticker) Expect(err).ToNot(HaveOccurred()) var count int64 @@ -328,7 +389,7 @@ var _ = Describe("SqlStorage", func() { It("deletes the ticker", func() { ticker := NewTicker() - err = storage.SaveTicker(&ticker) + err = store.SaveTicker(&ticker) Expect(err).ToNot(HaveOccurred()) var count int64 @@ -336,7 +397,7 @@ var _ = Describe("SqlStorage", func() { Expect(err).ToNot(HaveOccurred()) Expect(count).To(Equal(int64(1))) - err = storage.DeleteTicker(ticker) + err = store.DeleteTicker(ticker) Expect(err).ToNot(HaveOccurred()) err = db.Model(&Ticker{}).Count(&count).Error @@ -347,14 +408,14 @@ var _ = Describe("SqlStorage", func() { Describe("FindUploadByUUID", func() { It("returns the upload with the given uuid", func() { - upload, err := storage.FindUploadByUUID("uuid") + upload, err := store.FindUploadByUUID("uuid") Expect(err).To(HaveOccurred()) Expect(upload).To(BeZero()) err = db.Create(&Upload{UUID: "uuid"}).Error Expect(err).ToNot(HaveOccurred()) - upload, err = storage.FindUploadByUUID("uuid") + upload, err = store.FindUploadByUUID("uuid") Expect(err).ToNot(HaveOccurred()) Expect(upload).ToNot(BeZero()) }) @@ -362,14 +423,14 @@ var _ = Describe("SqlStorage", func() { Describe("FindUploadsByIDs", func() { It("returns the uploads with the given ids", func() { - uploads, err := storage.FindUploadsByIDs([]int{1, 2}) + uploads, err := store.FindUploadsByIDs([]int{1, 2}) Expect(err).ToNot(HaveOccurred()) Expect(uploads).To(BeEmpty()) err = db.Create(&Upload{}).Error Expect(err).ToNot(HaveOccurred()) - uploads, err = storage.FindUploadsByIDs([]int{1, 2}) + uploads, err = store.FindUploadsByIDs([]int{1, 2}) Expect(err).ToNot(HaveOccurred()) Expect(uploads).To(HaveLen(1)) }) @@ -379,7 +440,7 @@ var _ = Describe("SqlStorage", func() { It("persists the upload", func() { upload := NewUpload("image.jpg", "content-type", 1) - err = storage.SaveUpload(&upload) + err = store.SaveUpload(&upload) Expect(err).ToNot(HaveOccurred()) var count int64 @@ -393,7 +454,7 @@ var _ = Describe("SqlStorage", func() { It("deletes the upload", func() { upload := NewUpload("image.jpg", "content-type", 1) - err = storage.SaveUpload(&upload) + err = store.SaveUpload(&upload) Expect(err).ToNot(HaveOccurred()) var count int64 @@ -401,7 +462,7 @@ var _ = Describe("SqlStorage", func() { Expect(err).ToNot(HaveOccurred()) Expect(count).To(Equal(int64(1))) - err = storage.DeleteUpload(upload) + err = store.DeleteUpload(upload) Expect(err).ToNot(HaveOccurred()) err = db.Model(&Upload{}).Count(&count).Error @@ -414,7 +475,7 @@ var _ = Describe("SqlStorage", func() { It("deletes the uploads", func() { upload := NewUpload("image.jpg", "content-type", 1) - err = storage.SaveUpload(&upload) + err = store.SaveUpload(&upload) Expect(err).ToNot(HaveOccurred()) var count int64 @@ -423,7 +484,7 @@ var _ = Describe("SqlStorage", func() { Expect(count).To(Equal(int64(1))) uploads := []Upload{upload} - storage.DeleteUploads(uploads) + store.DeleteUploads(uploads) err = db.Model(&Upload{}).Count(&count).Error Expect(err).ToNot(HaveOccurred()) @@ -434,11 +495,11 @@ var _ = Describe("SqlStorage", func() { Describe("DeleteUploadsByTicker", func() { It("deletes the uploads", func() { ticker := NewTicker() - err := storage.SaveTicker(&ticker) + err := store.SaveTicker(&ticker) Expect(err).ToNot(HaveOccurred()) upload := NewUpload("image.jpg", "content-type", ticker.ID) - err = storage.SaveUpload(&upload) + err = store.SaveUpload(&upload) Expect(err).ToNot(HaveOccurred()) var count int64 @@ -446,7 +507,7 @@ var _ = Describe("SqlStorage", func() { Expect(err).ToNot(HaveOccurred()) Expect(count).To(Equal(int64(1))) - err = storage.DeleteUploadsByTicker(ticker) + err = store.DeleteUploadsByTicker(ticker) Expect(err).ToNot(HaveOccurred()) err = db.Model(&Upload{}).Count(&count).Error @@ -457,54 +518,95 @@ var _ = Describe("SqlStorage", func() { Describe("FindMessage", func() { It("returns the message with the given id", func() { - message, err := storage.FindMessage(1, 1) + message, err := store.FindMessage(1, 1) Expect(err).To(HaveOccurred()) Expect(message).To(BeZero()) err = db.Create(&Message{ID: 1, TickerID: 1}).Error Expect(err).ToNot(HaveOccurred()) - message, err = storage.FindMessage(1, 1) + message, err = store.FindMessage(1, 1) Expect(err).ToNot(HaveOccurred()) Expect(message).ToNot(BeZero()) }) + + It("returns the message with the given id and attachments", func() { + err = db.Create(&Message{ + ID: 1, + TickerID: 1, + Text: "Text", + Attachments: []Attachment{ + {ID: 1, MessageID: 1, UUID: "uuid", ContentType: "image/jpg", Extension: "jpg"}, + }, + }).Error + Expect(err).ToNot(HaveOccurred()) + + message, err := store.FindMessage(1, 1, WithAttachments()) + Expect(err).ToNot(HaveOccurred()) + Expect(message).ToNot(BeZero()) + + Expect(message.Attachments).To(HaveLen(1)) + Expect(message.Attachments[0].UUID).To(Equal("uuid")) + }) }) Describe("FindMessagesByTicker", func() { It("returns the messages with the given ticker", func() { ticker := NewTicker() - err := storage.SaveTicker(&ticker) + err := store.SaveTicker(&ticker) Expect(err).ToNot(HaveOccurred()) - messages, err := storage.FindMessagesByTicker(ticker) + messages, err := store.FindMessagesByTicker(ticker) Expect(err).ToNot(HaveOccurred()) Expect(messages).To(BeEmpty()) err = db.Create(&Message{TickerID: ticker.ID}).Error Expect(err).ToNot(HaveOccurred()) - messages, err = storage.FindMessagesByTicker(ticker) + messages, err = store.FindMessagesByTicker(ticker) + Expect(err).ToNot(HaveOccurred()) + Expect(messages).To(HaveLen(1)) + }) + + It("returns the messages with the given ticker and attachments", func() { + ticker := NewTicker() + err := store.SaveTicker(&ticker) + Expect(err).ToNot(HaveOccurred()) + + err = db.Create(&Message{ + TickerID: ticker.ID, + Text: "Text", + Attachments: []Attachment{ + {ID: 1, MessageID: 1, UUID: "uuid", ContentType: "image/jpg", Extension: "jpg"}, + }, + }).Error + Expect(err).ToNot(HaveOccurred()) + + messages, err := store.FindMessagesByTicker(ticker, WithAttachments()) Expect(err).ToNot(HaveOccurred()) Expect(messages).To(HaveLen(1)) + + Expect(messages[0].Attachments).To(HaveLen(1)) + Expect(messages[0].Attachments[0].UUID).To(Equal("uuid")) }) }) Describe("FindMessagesByTickerAndPagination", func() { It("returns the messages with the given ticker and pagination", func() { ticker := NewTicker() - err := storage.SaveTicker(&ticker) + err := store.SaveTicker(&ticker) Expect(err).ToNot(HaveOccurred()) c := &gin.Context{} p := pagination.NewPagination(c) - messages, err := storage.FindMessagesByTickerAndPagination(ticker, *p) + messages, err := store.FindMessagesByTickerAndPagination(ticker, *p) Expect(err).ToNot(HaveOccurred()) Expect(messages).To(BeEmpty()) err = db.Create(&Message{TickerID: ticker.ID}).Error Expect(err).ToNot(HaveOccurred()) - messages, err = storage.FindMessagesByTickerAndPagination(ticker, *p) + messages, err = store.FindMessagesByTickerAndPagination(ticker, *p) Expect(err).ToNot(HaveOccurred()) Expect(messages).To(HaveLen(1)) @@ -518,31 +620,55 @@ var _ = Describe("SqlStorage", func() { c = &gin.Context{} c.Request = &http.Request{URL: &url.URL{RawQuery: "limit=2"}} p = pagination.NewPagination(c) - messages, err = storage.FindMessagesByTickerAndPagination(ticker, *p) + messages, err = store.FindMessagesByTickerAndPagination(ticker, *p) Expect(err).ToNot(HaveOccurred()) Expect(messages).To(HaveLen(2)) c = &gin.Context{} c.Request = &http.Request{URL: &url.URL{RawQuery: "limit=2&after=2"}} p = pagination.NewPagination(c) - messages, err = storage.FindMessagesByTickerAndPagination(ticker, *p) + messages, err = store.FindMessagesByTickerAndPagination(ticker, *p) Expect(err).ToNot(HaveOccurred()) Expect(messages).To(HaveLen(2)) c = &gin.Context{} c.Request = &http.Request{URL: &url.URL{RawQuery: "limit=2&before=4"}} p = pagination.NewPagination(c) - messages, err = storage.FindMessagesByTickerAndPagination(ticker, *p) + messages, err = store.FindMessagesByTickerAndPagination(ticker, *p) Expect(err).ToNot(HaveOccurred()) Expect(messages).To(HaveLen(2)) }) + + It("returns the messages with the given ticker, pagination and attachments", func() { + ticker := NewTicker() + err := store.SaveTicker(&ticker) + Expect(err).ToNot(HaveOccurred()) + + err = db.Create(&Message{ + TickerID: ticker.ID, + Text: "Text", + Attachments: []Attachment{ + {ID: 1, MessageID: 1, UUID: "uuid", ContentType: "image/jpg", Extension: "jpg"}, + }, + }).Error + Expect(err).ToNot(HaveOccurred()) + + c := &gin.Context{} + p := pagination.NewPagination(c) + messages, err := store.FindMessagesByTickerAndPagination(ticker, *p, WithAttachments()) + Expect(err).ToNot(HaveOccurred()) + Expect(messages).To(HaveLen(1)) + + Expect(messages[0].Attachments).To(HaveLen(1)) + Expect(messages[0].Attachments[0].UUID).To(Equal("uuid")) + }) }) Describe("SaveMessage", func() { It("persists the message", func() { message := NewMessage() - err = storage.SaveMessage(&message) + err = store.SaveMessage(&message) Expect(err).ToNot(HaveOccurred()) var count int64 @@ -556,7 +682,7 @@ var _ = Describe("SqlStorage", func() { It("deletes the message", func() { message := NewMessage() - err = storage.SaveMessage(&message) + err = store.SaveMessage(&message) Expect(err).ToNot(HaveOccurred()) var count int64 @@ -564,7 +690,7 @@ var _ = Describe("SqlStorage", func() { Expect(err).ToNot(HaveOccurred()) Expect(count).To(Equal(int64(1))) - err = storage.DeleteMessage(message) + err = store.DeleteMessage(message) Expect(err).ToNot(HaveOccurred()) err = db.Model(&Message{}).Count(&count).Error @@ -576,12 +702,12 @@ var _ = Describe("SqlStorage", func() { Describe("DeleteMessages", func() { It("deletes the messages", func() { ticker := NewTicker() - err := storage.SaveTicker(&ticker) + err := store.SaveTicker(&ticker) Expect(err).ToNot(HaveOccurred()) message := NewMessage() message.TickerID = ticker.ID - err = storage.SaveMessage(&message) + err = store.SaveMessage(&message) Expect(err).ToNot(HaveOccurred()) var count int64 @@ -589,7 +715,7 @@ var _ = Describe("SqlStorage", func() { Expect(err).ToNot(HaveOccurred()) Expect(count).To(Equal(int64(1))) - err = storage.DeleteMessages(ticker) + err = store.DeleteMessages(ticker) Expect(err).ToNot(HaveOccurred()) err = db.Model(&Message{}).Count(&count).Error @@ -600,7 +726,7 @@ var _ = Describe("SqlStorage", func() { Describe("GetInactiveSettings", func() { It("returns the default inactive setting", func() { - setting := storage.GetInactiveSettings() + setting := store.GetInactiveSettings() Expect(setting.Author).To(Equal(DefaultInactiveSettings().Author)) }) @@ -609,17 +735,17 @@ var _ = Describe("SqlStorage", func() { Author: "author", } - err = storage.SaveInactiveSettings(settings) + err = store.SaveInactiveSettings(settings) Expect(err).ToNot(HaveOccurred()) - setting := storage.GetInactiveSettings() + setting := store.GetInactiveSettings() Expect(setting.Author).To(Equal(settings.Author)) }) }) Describe("GetRefreshIntervalSetting", func() { It("returns the default refresh interval setting", func() { - setting := storage.GetRefreshIntervalSettings() + setting := store.GetRefreshIntervalSettings() Expect(setting.RefreshInterval).To(Equal(DefaultRefreshIntervalSettings().RefreshInterval)) }) @@ -628,10 +754,10 @@ var _ = Describe("SqlStorage", func() { RefreshInterval: 1000, } - err = storage.SaveRefreshIntervalSettings(settings) + err = store.SaveRefreshIntervalSettings(settings) Expect(err).ToNot(HaveOccurred()) - setting := storage.GetRefreshIntervalSettings() + setting := store.GetRefreshIntervalSettings() Expect(setting.RefreshInterval).To(Equal(settings.RefreshInterval)) }) }) diff --git a/internal/storage/storage.go b/internal/storage/storage.go index fbb0d4ea..bbcece04 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -3,6 +3,7 @@ package storage import ( "github.com/sirupsen/logrus" "github.com/systemli/ticker/internal/api/pagination" + "gorm.io/gorm" ) var log = logrus.WithField("package", "storage") @@ -19,19 +20,19 @@ type Storage interface { DeleteTickerUsers(ticker *Ticker) error DeleteTickerUser(ticker *Ticker, user *User) error AddTickerUser(ticker *Ticker, user *User) error - FindTickers() ([]Ticker, error) - FindTickersByIDs(ids []int) ([]Ticker, error) - FindTickerByDomain(domain string) (Ticker, error) - FindTickerByID(id int) (Ticker, error) + FindTickers(opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) + FindTickersByIDs(ids []int, opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) + FindTickerByDomain(domain string, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error) + FindTickerByID(id int, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error) SaveTicker(ticker *Ticker) error DeleteTicker(ticker Ticker) error SaveUpload(upload *Upload) error DeleteUpload(upload Upload) error DeleteUploads(uploads []Upload) DeleteUploadsByTicker(ticker Ticker) error - FindMessage(tickerID, messageID int) (Message, error) - FindMessagesByTicker(ticker Ticker) ([]Message, error) - FindMessagesByTickerAndPagination(ticker Ticker, pagination pagination.Pagination) ([]Message, error) + FindMessage(tickerID, messageID int, opts ...func(*gorm.DB) *gorm.DB) (Message, error) + FindMessagesByTicker(ticker Ticker, opts ...func(*gorm.DB) *gorm.DB) ([]Message, error) + FindMessagesByTickerAndPagination(ticker Ticker, pagination pagination.Pagination, opts ...func(*gorm.DB) *gorm.DB) ([]Message, error) SaveMessage(message *Message) error DeleteMessage(message Message) error DeleteMessages(ticker Ticker) error