Skip to content

Commit

Permalink
🐛 Fix fetching Tickers for non-admin Users
Browse files Browse the repository at this point in the history
  • Loading branch information
0x46616c6b committed Oct 24, 2023
1 parent c759d2f commit 2479c07
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 37 deletions.
15 changes: 1 addition & 14 deletions internal/api/middleware/ticker/ticker.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/systemli/ticker/internal/api/helper"
"github.com/systemli/ticker/internal/api/response"
"github.com/systemli/ticker/internal/storage"
"github.com/systemli/ticker/internal/util"
"gorm.io/gorm"
)

Expand All @@ -21,19 +20,7 @@ func PrefetchTicker(s storage.Storage, opts ...func(*gorm.DB) *gorm.DB) gin.Hand
return
}

if !user.IsSuperAdmin {
var tickerIDs []int
for _, t := range user.Tickers {
tickerIDs = append(tickerIDs, t.ID)
}
if !util.Contains(tickerIDs, tickerID) {
c.JSON(http.StatusForbidden, response.ErrorResponse(response.CodeInsufficientPermissions, response.InsufficientPermissions))
return
}
}

ticker, err := s.FindTickerByID(tickerID, opts...)

ticker, err := s.FindTickerByUserAndID(user, tickerID, opts...)
if err != nil {
c.JSON(http.StatusNotFound, response.ErrorResponse(response.CodeNotFound, response.TickerNotFound))
return
Expand Down
17 changes: 2 additions & 15 deletions internal/api/middleware/ticker/ticker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,13 @@ func TestPrefetchTickerParamMissing(t *testing.T) {
assert.Equal(t, http.StatusBadRequest, w.Code)
}

func TestPrefetchTickerNoPermission(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.AddParam("tickerID", "1")
c.Set("me", storage.User{IsSuperAdmin: false, Tickers: []storage.Ticker{{ID: 2}}})
s := &storage.MockStorage{}
mw := PrefetchTicker(s)

mw(c)

assert.Equal(t, http.StatusForbidden, w.Code)
}

func TestPrefetchTickerStorageError(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.AddParam("tickerID", "1")
c.Set("me", storage.User{IsSuperAdmin: true})
s := &storage.MockStorage{}
s.On("FindTickerByID", mock.Anything).Return(storage.Ticker{}, errors.New("storage error"))
s.On("FindTickerByUserAndID", mock.Anything, mock.Anything, mock.Anything).Return(storage.Ticker{}, errors.New("storage error"))
mw := PrefetchTicker(s)

mw(c)
Expand All @@ -62,7 +49,7 @@ func TestPrefetchTicker(t *testing.T) {
c.Set("me", storage.User{IsSuperAdmin: true})
s := &storage.MockStorage{}
ticker := storage.Ticker{ID: 1}
s.On("FindTickerByID", mock.Anything).Return(ticker, nil)
s.On("FindTickerByUserAndID", mock.Anything, mock.Anything, mock.Anything).Return(ticker, nil)
mw := PrefetchTicker(s)

mw(c)
Expand Down
7 changes: 1 addition & 6 deletions internal/api/tickers.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@ func (h *handler) GetTickers(c *gin.Context) {
return
}

var tickers []storage.Ticker
if me.IsSuperAdmin {
tickers, err = h.storage.FindTickers()
} else {
tickers = me.Tickers
}
tickers, err := h.storage.FindTickersByUser(me)
if err != nil {
c.JSON(http.StatusNotFound, response.ErrorResponse(response.CodeDefault, response.TickerNotFound))
return
Expand Down
4 changes: 2 additions & 2 deletions internal/api/tickers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestGetTickersStorageError(t *testing.T) {
c, _ := gin.CreateTestContext(w)
c.Set("me", storage.User{IsSuperAdmin: true})
s := &storage.MockStorage{}
s.On("FindTickers", mock.Anything).Return([]storage.Ticker{}, errors.New("storage error"))
s.On("FindTickersByUser", mock.Anything, mock.Anything).Return([]storage.Ticker{}, errors.New("storage error"))
h := handler{
storage: s,
config: config.NewConfig(),
Expand All @@ -56,7 +56,7 @@ func TestGetTickers(t *testing.T) {
c, _ := gin.CreateTestContext(w)
c.Set("me", storage.User{IsSuperAdmin: false, Tickers: []storage.Ticker{{ID: 2}}})
s := &storage.MockStorage{}
s.On("FindTickersByIDs", mock.Anything, mock.Anything).Return([]storage.Ticker{}, nil)
s.On("FindTickersByUser", mock.Anything, mock.Anything).Return([]storage.Ticker{}, nil)
h := handler{
storage: s,
config: config.NewConfig(),
Expand Down
64 changes: 64 additions & 0 deletions internal/storage/mock_Storage.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions internal/storage/sql_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,34 @@ func (s *SqlStorage) FindTickers(opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, err
return tickers, err
}

func (s *SqlStorage) FindTickersByUser(user User, opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) {
tickers := make([]Ticker, 0)
db := s.prepareDb(opts...)

var err error
if user.IsSuperAdmin {
err = db.Find(&tickers).Error
} else {
err = db.Model(&user).Association("Tickers").Find(&tickers)
}

return tickers, err
}

func (s *SqlStorage) FindTickerByUserAndID(user User, id int, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error) {
db := s.prepareDb(opts...)

var ticker Ticker
var err error
if user.IsSuperAdmin {
err = db.First(&ticker, id).Error
} else {
err = db.Model(&user).Association("Tickers").Find(&ticker, id)
}

return ticker, err
}

func (s *SqlStorage) FindTickersByIDs(ids []int, opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error) {
tickers := make([]Ticker, 0)
db := s.prepareDb(opts...)
Expand Down
76 changes: 76 additions & 0 deletions internal/storage/sql_storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,82 @@ var _ = Describe("SqlStorage", func() {
})
})

Describe("FindTickersByUser", func() {
var user = User{
Email: "[email protected]",
IsSuperAdmin: false,
}
var admin User = User{
Email: "[email protected]",
IsSuperAdmin: true,
}
var ticker Ticker = Ticker{
Users: []User{user},
}

BeforeEach(func() {
Expect(db.Create(&user).Error).ToNot(HaveOccurred())
Expect(db.Create(&admin).Error).ToNot(HaveOccurred())
Expect(db.Create(&ticker).Error).ToNot(HaveOccurred())
})

It("returns all tickers for admins", func() {
tickers, err := store.FindTickersByUser(admin)
Expect(err).ToNot(HaveOccurred())
Expect(tickers).To(HaveLen(1))
})

It("returns all tickers for users", func() {
tickers, err := store.FindTickersByUser(user)
Expect(err).ToNot(HaveOccurred())
Expect(tickers).To(HaveLen(1))
})

It("returns no tickers for users", func() {
tickers, err := store.FindTickersByUser(User{ID: 2})
Expect(err).ToNot(HaveOccurred())
Expect(tickers).To(BeEmpty())
})
})

Describe("FindTickerByUserAndID", func() {
var user = User{
Email: "[email protected]",
IsSuperAdmin: false,
}
var admin = User{
Email: "[email protected]",
IsSuperAdmin: true,
}
var ticker = Ticker{
Users: []User{user},
}

BeforeEach(func() {
Expect(db.Create(&user).Error).ToNot(HaveOccurred())
Expect(db.Create(&admin).Error).ToNot(HaveOccurred())
Expect(db.Create(&ticker).Error).ToNot(HaveOccurred())
})

It("returns the ticker for admins", func() {
ticker, err := store.FindTickerByUserAndID(admin, ticker.ID)
Expect(err).ToNot(HaveOccurred())
Expect(ticker).ToNot(BeZero())
})

It("returns the ticker for users", func() {
ticker, err := store.FindTickerByUserAndID(user, ticker.ID)
Expect(err).ToNot(HaveOccurred())
Expect(ticker).ToNot(BeZero())
})

It("returns no ticker for users", func() {
ticker, err := store.FindTickerByUserAndID(User{ID: 2}, ticker.ID)
Expect(err).ToNot(HaveOccurred())
Expect(ticker).To(BeZero())
})
})

Describe("FindTickersByIDs", func() {
It("returns the tickers with the given ids", func() {
tickers, err := store.FindTickersByIDs([]int{1, 2})
Expand Down
2 changes: 2 additions & 0 deletions internal/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ type Storage interface {
DeleteTickerUser(ticker *Ticker, user *User) error
AddTickerUser(ticker *Ticker, user *User) error
FindTickers(opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error)
FindTickersByUser(user User, opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error)
FindTickerByUserAndID(user User, id int, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error)
FindTickersByIDs(ids []int, opts ...func(*gorm.DB) *gorm.DB) ([]Ticker, error)
FindTickerByDomain(domain string, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error)
FindTickerByID(id int, opts ...func(*gorm.DB) *gorm.DB) (Ticker, error)
Expand Down

0 comments on commit 2479c07

Please sign in to comment.