diff --git a/internal/api/middleware/ticker/ticker.go b/internal/api/middleware/ticker/ticker.go index 3bd04aa6..e4f39bb7 100644 --- a/internal/api/middleware/ticker/ticker.go +++ b/internal/api/middleware/ticker/ticker.go @@ -8,7 +8,6 @@ import ( "github.com/systemli/ticker/internal/api/helper" "github.com/systemli/ticker/internal/api/response" "github.com/systemli/ticker/internal/storage" - "github.com/systemli/ticker/internal/util" "gorm.io/gorm" ) @@ -21,19 +20,7 @@ func PrefetchTicker(s storage.Storage, opts ...func(*gorm.DB) *gorm.DB) gin.Hand return } - if !user.IsSuperAdmin { - var tickerIDs []int - for _, t := range user.Tickers { - tickerIDs = append(tickerIDs, t.ID) - } - if !util.Contains(tickerIDs, tickerID) { - c.JSON(http.StatusForbidden, response.ErrorResponse(response.CodeInsufficientPermissions, response.InsufficientPermissions)) - return - } - } - - ticker, err := s.FindTickerByID(tickerID, opts...) - + ticker, err := s.FindTickerByUserAndID(user, tickerID, opts...) if err != nil { c.JSON(http.StatusNotFound, response.ErrorResponse(response.CodeNotFound, response.TickerNotFound)) return diff --git a/internal/api/middleware/ticker/ticker_test.go b/internal/api/middleware/ticker/ticker_test.go index aeabbca4..2bc0db04 100644 --- a/internal/api/middleware/ticker/ticker_test.go +++ b/internal/api/middleware/ticker/ticker_test.go @@ -28,26 +28,13 @@ func TestPrefetchTickerParamMissing(t *testing.T) { assert.Equal(t, http.StatusBadRequest, w.Code) } -func TestPrefetchTickerNoPermission(t *testing.T) { - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.AddParam("tickerID", "1") - c.Set("me", storage.User{IsSuperAdmin: false, Tickers: []storage.Ticker{{ID: 2}}}) - s := &storage.MockStorage{} - mw := PrefetchTicker(s) - - mw(c) - - assert.Equal(t, http.StatusForbidden, w.Code) -} - func TestPrefetchTickerStorageError(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) c.AddParam("tickerID", "1") c.Set("me", storage.User{IsSuperAdmin: true}) s := &storage.MockStorage{} - s.On("FindTickerByID", mock.Anything).Return(storage.Ticker{}, errors.New("storage error")) + s.On("FindTickerByUserAndID", mock.Anything, mock.Anything, mock.Anything).Return(storage.Ticker{}, errors.New("storage error")) mw := PrefetchTicker(s) mw(c) @@ -62,7 +49,7 @@ func TestPrefetchTicker(t *testing.T) { c.Set("me", storage.User{IsSuperAdmin: true}) s := &storage.MockStorage{} ticker := storage.Ticker{ID: 1} - s.On("FindTickerByID", mock.Anything).Return(ticker, nil) + s.On("FindTickerByUserAndID", mock.Anything, mock.Anything, mock.Anything).Return(ticker, nil) mw := PrefetchTicker(s) mw(c) diff --git a/internal/api/tickers.go b/internal/api/tickers.go index 80f14e32..0ca0da99 100644 --- a/internal/api/tickers.go +++ b/internal/api/tickers.go @@ -18,12 +18,7 @@ func (h *handler) GetTickers(c *gin.Context) { return } - var tickers []storage.Ticker - if me.IsSuperAdmin { - tickers, err = h.storage.FindTickers() - } else { - tickers = me.Tickers - } + tickers, err := h.storage.FindTickersByUser(me) if err != nil { c.JSON(http.StatusNotFound, response.ErrorResponse(response.CodeDefault, response.TickerNotFound)) return diff --git a/internal/api/tickers_test.go b/internal/api/tickers_test.go index a0e79b11..b12a34e5 100644 --- a/internal/api/tickers_test.go +++ b/internal/api/tickers_test.go @@ -40,7 +40,7 @@ func TestGetTickersStorageError(t *testing.T) { c, _ := gin.CreateTestContext(w) c.Set("me", storage.User{IsSuperAdmin: true}) s := &storage.MockStorage{} - s.On("FindTickers", mock.Anything).Return([]storage.Ticker{}, errors.New("storage error")) + s.On("FindTickersByUser", mock.Anything, mock.Anything).Return([]storage.Ticker{}, errors.New("storage error")) h := handler{ storage: s, config: config.NewConfig(), @@ -56,7 +56,7 @@ func TestGetTickers(t *testing.T) { c, _ := gin.CreateTestContext(w) c.Set("me", storage.User{IsSuperAdmin: false, Tickers: []storage.Ticker{{ID: 2}}}) s := &storage.MockStorage{} - s.On("FindTickersByIDs", mock.Anything, mock.Anything).Return([]storage.Ticker{}, nil) + s.On("FindTickersByUser", mock.Anything, mock.Anything).Return([]storage.Ticker{}, nil) h := handler{ storage: s, config: config.NewConfig(), diff --git a/internal/storage/mock_Storage.go b/internal/storage/mock_Storage.go index c3913946..7982d032 100644 --- a/internal/storage/mock_Storage.go +++ b/internal/storage/mock_Storage.go @@ -317,6 +317,37 @@ func (_m *MockStorage) FindTickerByID(id int, opts ...func(*gorm.DB) *gorm.DB) ( return r0, r1 } +// FindTickerByUserAndID provides a mock function with given fields: user, id, opts +func (_m *MockStorage) FindTickerByUserAndID(user User, id int, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, user, id) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 Ticker + var r1 error + if rf, ok := ret.Get(0).(func(User, int, ...func(*gorm.DB) *gorm.DB) (Ticker, error)); ok { + return rf(user, id, opts...) + } + if rf, ok := ret.Get(0).(func(User, int, ...func(*gorm.DB) *gorm.DB) Ticker); ok { + r0 = rf(user, id, opts...) + } else { + r0 = ret.Get(0).(Ticker) + } + + if rf, ok := ret.Get(1).(func(User, int, ...func(*gorm.DB) *gorm.DB) error); ok { + r1 = rf(user, id, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindTickers provides a mock function with given fields: opts func (_m *MockStorage) FindTickers(opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) { _va := make([]interface{}, len(opts)) @@ -382,6 +413,39 @@ func (_m *MockStorage) FindTickersByIDs(ids []int, opts ...func(*gorm.DB) *gorm. return r0, r1 } +// FindTickersByUser provides a mock function with given fields: user, opts +func (_m *MockStorage) FindTickersByUser(user User, opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) { + _va := make([]interface{}, len(opts)) + for _i := range opts { + _va[_i] = opts[_i] + } + var _ca []interface{} + _ca = append(_ca, user) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 []Ticker + var r1 error + if rf, ok := ret.Get(0).(func(User, ...func(*gorm.DB) *gorm.DB) ([]Ticker, error)); ok { + return rf(user, opts...) + } + if rf, ok := ret.Get(0).(func(User, ...func(*gorm.DB) *gorm.DB) []Ticker); ok { + r0 = rf(user, opts...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]Ticker) + } + } + + if rf, ok := ret.Get(1).(func(User, ...func(*gorm.DB) *gorm.DB) error); ok { + r1 = rf(user, opts...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // FindUploadByUUID provides a mock function with given fields: uuid func (_m *MockStorage) FindUploadByUUID(uuid string) (Upload, error) { ret := _m.Called(uuid) diff --git a/internal/storage/sql_storage.go b/internal/storage/sql_storage.go index bf693784..451168dc 100644 --- a/internal/storage/sql_storage.go +++ b/internal/storage/sql_storage.go @@ -105,6 +105,34 @@ func (s *SqlStorage) FindTickers(opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, err return tickers, err } +func (s *SqlStorage) FindTickersByUser(user User, opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) { + tickers := make([]Ticker, 0) + db := s.prepareDb(opts...) + + var err error + if user.IsSuperAdmin { + err = db.Find(&tickers).Error + } else { + err = db.Model(&user).Association("Tickers").Find(&tickers) + } + + return tickers, err +} + +func (s *SqlStorage) FindTickerByUserAndID(user User, id int, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error) { + db := s.prepareDb(opts...) + + var ticker Ticker + var err error + if user.IsSuperAdmin { + err = db.First(&ticker, id).Error + } else { + err = db.Model(&user).Association("Tickers").Find(&ticker, id) + } + + return ticker, err +} + func (s *SqlStorage) FindTickersByIDs(ids []int, opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) { tickers := make([]Ticker, 0) db := s.prepareDb(opts...) diff --git a/internal/storage/sql_storage_test.go b/internal/storage/sql_storage_test.go index d8c01f7b..331d9fb5 100644 --- a/internal/storage/sql_storage_test.go +++ b/internal/storage/sql_storage_test.go @@ -361,6 +361,82 @@ var _ = Describe("SqlStorage", func() { }) }) + Describe("FindTickersByUser", func() { + var user = User{ + Email: "user@systemli.org", + IsSuperAdmin: false, + } + var admin User = User{ + Email: "admin@systemli.org", + IsSuperAdmin: true, + } + var ticker Ticker = Ticker{ + Users: []User{user}, + } + + BeforeEach(func() { + Expect(db.Create(&user).Error).ToNot(HaveOccurred()) + Expect(db.Create(&admin).Error).ToNot(HaveOccurred()) + Expect(db.Create(&ticker).Error).ToNot(HaveOccurred()) + }) + + It("returns all tickers for admins", func() { + tickers, err := store.FindTickersByUser(admin) + Expect(err).ToNot(HaveOccurred()) + Expect(tickers).To(HaveLen(1)) + }) + + It("returns all tickers for users", func() { + tickers, err := store.FindTickersByUser(user) + Expect(err).ToNot(HaveOccurred()) + Expect(tickers).To(HaveLen(1)) + }) + + It("returns no tickers for users", func() { + tickers, err := store.FindTickersByUser(User{ID: 2}) + Expect(err).ToNot(HaveOccurred()) + Expect(tickers).To(BeEmpty()) + }) + }) + + Describe("FindTickerByUserAndID", func() { + var user = User{ + Email: "user@systemli.org", + IsSuperAdmin: false, + } + var admin = User{ + Email: "admin@systemli.org", + IsSuperAdmin: true, + } + var ticker = Ticker{ + Users: []User{user}, + } + + BeforeEach(func() { + Expect(db.Create(&user).Error).ToNot(HaveOccurred()) + Expect(db.Create(&admin).Error).ToNot(HaveOccurred()) + Expect(db.Create(&ticker).Error).ToNot(HaveOccurred()) + }) + + It("returns the ticker for admins", func() { + ticker, err := store.FindTickerByUserAndID(admin, ticker.ID) + Expect(err).ToNot(HaveOccurred()) + Expect(ticker).ToNot(BeZero()) + }) + + It("returns the ticker for users", func() { + ticker, err := store.FindTickerByUserAndID(user, ticker.ID) + Expect(err).ToNot(HaveOccurred()) + Expect(ticker).ToNot(BeZero()) + }) + + It("returns no ticker for users", func() { + ticker, err := store.FindTickerByUserAndID(User{ID: 2}, ticker.ID) + Expect(err).ToNot(HaveOccurred()) + Expect(ticker).To(BeZero()) + }) + }) + Describe("FindTickersByIDs", func() { It("returns the tickers with the given ids", func() { tickers, err := store.FindTickersByIDs([]int{1, 2}) diff --git a/internal/storage/storage.go b/internal/storage/storage.go index 26af4537..4dae872e 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -20,6 +20,8 @@ type Storage interface { DeleteTickerUser(ticker *Ticker, user *User) error AddTickerUser(ticker *Ticker, user *User) error FindTickers(opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) + FindTickersByUser(user User, opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) + FindTickerByUserAndID(user User, id int, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error) FindTickersByIDs(ids []int, opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) FindTickerByDomain(domain string, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error) FindTickerByID(id int, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error)