diff --git a/internal/api/tickers.go b/internal/api/tickers.go index 527b404f..a7e2bc1e 100644 --- a/internal/api/tickers.go +++ b/internal/api/tickers.go @@ -252,6 +252,10 @@ func (h *handler) DeleteTicker(c *gin.Context) { if err != nil { log.WithError(err).Error("failed to delete message for ticker") } + err = h.storage.DeleteUploadsByTicker(ticker) + if err != nil { + log.WithError(err).Error("failed to delete uploads for ticker") + } err = h.storage.DeleteTicker(ticker) if err != nil { c.JSON(http.StatusNotFound, response.ErrorResponse(response.CodeNotFound, response.StorageError)) diff --git a/internal/api/tickers_test.go b/internal/api/tickers_test.go index 91f97824..61d7beed 100644 --- a/internal/api/tickers_test.go +++ b/internal/api/tickers_test.go @@ -622,6 +622,7 @@ func TestDeleteTickerStorageError(t *testing.T) { c.Set("ticker", storage.Ticker{}) s := &storage.MockStorage{} s.On("DeleteMessages", mock.Anything).Return(errors.New("storage error")) + s.On("DeleteUploadsByTicker", mock.Anything).Return(errors.New("storage error")) s.On("DeleteTicker", mock.Anything).Return(errors.New("storage error")) h := handler{ storage: s, @@ -639,6 +640,7 @@ func TestDeleteTicker(t *testing.T) { c.Set("ticker", storage.Ticker{}) s := &storage.MockStorage{} s.On("DeleteMessages", mock.Anything).Return(nil) + s.On("DeleteUploadsByTicker", mock.Anything).Return(nil) s.On("DeleteTicker", mock.Anything).Return(nil) h := handler{ storage: s, diff --git a/internal/storage/mock_Storage.go b/internal/storage/mock_Storage.go index 34eccccb..810055c5 100644 --- a/internal/storage/mock_Storage.go +++ b/internal/storage/mock_Storage.go @@ -51,6 +51,20 @@ func (_m *MockStorage) CountUser() (int, error) { return r0, r1 } +// DeleteAttachmentsByMessage provides a mock function with given fields: message +func (_m *MockStorage) DeleteAttachmentsByMessage(message Message) error { + ret := _m.Called(message) + + var r0 error + if rf, ok := ret.Get(0).(func(Message) error); ok { + r0 = rf(message) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // DeleteMessage provides a mock function with given fields: message func (_m *MockStorage) DeleteMessage(message Message) error { ret := _m.Called(message) diff --git a/internal/storage/sql_storage.go b/internal/storage/sql_storage.go index 973a3224..351faef6 100644 --- a/internal/storage/sql_storage.go +++ b/internal/storage/sql_storage.go @@ -233,13 +233,36 @@ func (s *SqlStorage) SaveMessage(message *Message) error { } func (s *SqlStorage) DeleteMessage(message Message) error { - return s.db.Delete(&message).Error + var err error + err = s.db.Delete(&message).Error + if err != nil { + return err + } + + if len(message.Attachments) > 0 { + err = s.DeleteAttachmentsByMessage(message) + } + + return err } func (s *SqlStorage) DeleteMessages(ticker Ticker) error { - err := s.db.Where("ticker_id = ?", ticker.ID).Delete(&Message{}).Error + var msgIds []int + err := s.db.Model(&Message{}).Where("ticker_id = ?", ticker.ID).Pluck("id", &msgIds).Error + if err != nil { + return err + } - return err + err = s.db.Where("message_id IN ?", msgIds).Delete(&Attachment{}).Error + if err != nil { + return err + } + + return s.db.Where("ticker_id = ?", ticker.ID).Delete(&Message{}).Error +} + +func (s *SqlStorage) DeleteAttachmentsByMessage(message Message) error { + return s.db.Where("message_id = ?", message.ID).Delete(&Attachment{}).Error } func (s *SqlStorage) GetInactiveSettings() InactiveSettings { diff --git a/internal/storage/sql_storage_test.go b/internal/storage/sql_storage_test.go index 1c33a292..c069cfee 100644 --- a/internal/storage/sql_storage_test.go +++ b/internal/storage/sql_storage_test.go @@ -708,6 +708,32 @@ var _ = Describe("SqlStorage", func() { Expect(err).ToNot(HaveOccurred()) Expect(count).To(Equal(int64(0))) }) + + It("deletes the message with attachments", func() { + message := NewMessage() + message.Attachments = []Attachment{ + {ID: 1, MessageID: 1, UUID: "uuid", ContentType: "image/jpg", Extension: "jpg"}, + } + + err = store.SaveMessage(&message) + Expect(err).ToNot(HaveOccurred()) + + var count int64 + err = db.Model(&Message{}).Count(&count).Error + Expect(err).ToNot(HaveOccurred()) + Expect(count).To(Equal(int64(1))) + + err = store.DeleteMessage(message) + Expect(err).ToNot(HaveOccurred()) + + err = db.Model(&Message{}).Count(&count).Error + Expect(err).ToNot(HaveOccurred()) + Expect(count).To(Equal(int64(0))) + + err = db.Model(&Attachment{}).Count(&count).Error + Expect(err).ToNot(HaveOccurred()) + Expect(count).To(Equal(int64(0))) + }) }) Describe("DeleteMessages", func() { @@ -718,6 +744,9 @@ var _ = Describe("SqlStorage", func() { message := NewMessage() message.TickerID = ticker.ID + message.Attachments = []Attachment{ + {ID: 1, MessageID: 1, UUID: "uuid", ContentType: "image/jpg", Extension: "jpg"}, + } err = store.SaveMessage(&message) Expect(err).ToNot(HaveOccurred()) @@ -732,6 +761,10 @@ var _ = Describe("SqlStorage", func() { err = db.Model(&Message{}).Count(&count).Error Expect(err).ToNot(HaveOccurred()) Expect(count).To(Equal(int64(0))) + + err = db.Model(&Attachment{}).Count(&count).Error + Expect(err).ToNot(HaveOccurred()) + Expect(count).To(Equal(int64(0))) }) }) diff --git a/internal/storage/storage.go b/internal/storage/storage.go index bbcece04..5670562e 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -27,6 +27,8 @@ type Storage interface { SaveTicker(ticker *Ticker) error DeleteTicker(ticker Ticker) error SaveUpload(upload *Upload) error + FindUploadByUUID(uuid string) (Upload, error) + FindUploadsByIDs(ids []int) ([]Upload, error) DeleteUpload(upload Upload) error DeleteUploads(uploads []Upload) DeleteUploadsByTicker(ticker Ticker) error @@ -36,11 +38,10 @@ type Storage interface { SaveMessage(message *Message) error DeleteMessage(message Message) error DeleteMessages(ticker Ticker) error + DeleteAttachmentsByMessage(message Message) error GetInactiveSettings() InactiveSettings GetRefreshIntervalSettings() RefreshIntervalSettings SaveInactiveSettings(inactiveSettings InactiveSettings) error SaveRefreshIntervalSettings(refreshInterval RefreshIntervalSettings) error - FindUploadByUUID(uuid string) (Upload, error) - FindUploadsByIDs(ids []int) ([]Upload, error) UploadPath() string }