diff --git a/generate.go b/generate.go index ff77a943..1ae8ca29 100644 --- a/generate.go +++ b/generate.go @@ -1,6 +1,6 @@ package radio //go:generate go generate ./rpc/generate.go -//go:generate moq -out mocks/radio.gen.go -pkg mocks . StorageService StorageTx TrackStorage SubmissionStorage UserStorage ManagerService +//go:generate moq -out mocks/radio.gen.go -pkg mocks . StorageService StorageTx TrackStorage SubmissionStorage UserStorage ManagerService SearchService RequestStorage //go:generate moq -out mocks/templates.gen.go -pkg mocks ./templates/ Executor TemplateSelectable //go:generate moq -out mocks/util.gen.go -pkg mocks ./mocks/ FS File FileInfo diff --git a/mocks/radio.gen.go b/mocks/radio.gen.go index ebcb30b6..aebc2269 100644 --- a/mocks/radio.gen.go +++ b/mocks/radio.gen.go @@ -3004,3 +3004,297 @@ func (mock *ManagerServiceMock) UpdateUserCalls() []struct { mock.lockUpdateUser.RUnlock() return calls } + +// Ensure, that SearchServiceMock does implement radio.SearchService. +// If this is not the case, regenerate this file with moq. +var _ radio.SearchService = &SearchServiceMock{} + +// SearchServiceMock is a mock implementation of radio.SearchService. +// +// func TestSomethingThatUsesSearchService(t *testing.T) { +// +// // make and configure a mocked radio.SearchService +// mockedSearchService := &SearchServiceMock{ +// DeleteFunc: func(contextMoqParam context.Context, trackIDs ...radio.TrackID) error { +// panic("mock out the Delete method") +// }, +// SearchFunc: func(ctx context.Context, query string, limit int64, offset int64) (*radio.SearchResult, error) { +// panic("mock out the Search method") +// }, +// UpdateFunc: func(contextMoqParam context.Context, songs ...radio.Song) error { +// panic("mock out the Update method") +// }, +// } +// +// // use mockedSearchService in code that requires radio.SearchService +// // and then make assertions. +// +// } +type SearchServiceMock struct { + // DeleteFunc mocks the Delete method. + DeleteFunc func(contextMoqParam context.Context, trackIDs ...radio.TrackID) error + + // SearchFunc mocks the Search method. + SearchFunc func(ctx context.Context, query string, limit int64, offset int64) (*radio.SearchResult, error) + + // UpdateFunc mocks the Update method. + UpdateFunc func(contextMoqParam context.Context, songs ...radio.Song) error + + // calls tracks calls to the methods. + calls struct { + // Delete holds details about calls to the Delete method. + Delete []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // TrackIDs is the trackIDs argument value. + TrackIDs []radio.TrackID + } + // Search holds details about calls to the Search method. + Search []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Query is the query argument value. + Query string + // Limit is the limit argument value. + Limit int64 + // Offset is the offset argument value. + Offset int64 + } + // Update holds details about calls to the Update method. + Update []struct { + // ContextMoqParam is the contextMoqParam argument value. + ContextMoqParam context.Context + // Songs is the songs argument value. + Songs []radio.Song + } + } + lockDelete sync.RWMutex + lockSearch sync.RWMutex + lockUpdate sync.RWMutex +} + +// Delete calls DeleteFunc. +func (mock *SearchServiceMock) Delete(contextMoqParam context.Context, trackIDs ...radio.TrackID) error { + if mock.DeleteFunc == nil { + panic("SearchServiceMock.DeleteFunc: method is nil but SearchService.Delete was just called") + } + callInfo := struct { + ContextMoqParam context.Context + TrackIDs []radio.TrackID + }{ + ContextMoqParam: contextMoqParam, + TrackIDs: trackIDs, + } + mock.lockDelete.Lock() + mock.calls.Delete = append(mock.calls.Delete, callInfo) + mock.lockDelete.Unlock() + return mock.DeleteFunc(contextMoqParam, trackIDs...) +} + +// DeleteCalls gets all the calls that were made to Delete. +// Check the length with: +// +// len(mockedSearchService.DeleteCalls()) +func (mock *SearchServiceMock) DeleteCalls() []struct { + ContextMoqParam context.Context + TrackIDs []radio.TrackID +} { + var calls []struct { + ContextMoqParam context.Context + TrackIDs []radio.TrackID + } + mock.lockDelete.RLock() + calls = mock.calls.Delete + mock.lockDelete.RUnlock() + return calls +} + +// Search calls SearchFunc. +func (mock *SearchServiceMock) Search(ctx context.Context, query string, limit int64, offset int64) (*radio.SearchResult, error) { + if mock.SearchFunc == nil { + panic("SearchServiceMock.SearchFunc: method is nil but SearchService.Search was just called") + } + callInfo := struct { + Ctx context.Context + Query string + Limit int64 + Offset int64 + }{ + Ctx: ctx, + Query: query, + Limit: limit, + Offset: offset, + } + mock.lockSearch.Lock() + mock.calls.Search = append(mock.calls.Search, callInfo) + mock.lockSearch.Unlock() + return mock.SearchFunc(ctx, query, limit, offset) +} + +// SearchCalls gets all the calls that were made to Search. +// Check the length with: +// +// len(mockedSearchService.SearchCalls()) +func (mock *SearchServiceMock) SearchCalls() []struct { + Ctx context.Context + Query string + Limit int64 + Offset int64 +} { + var calls []struct { + Ctx context.Context + Query string + Limit int64 + Offset int64 + } + mock.lockSearch.RLock() + calls = mock.calls.Search + mock.lockSearch.RUnlock() + return calls +} + +// Update calls UpdateFunc. +func (mock *SearchServiceMock) Update(contextMoqParam context.Context, songs ...radio.Song) error { + if mock.UpdateFunc == nil { + panic("SearchServiceMock.UpdateFunc: method is nil but SearchService.Update was just called") + } + callInfo := struct { + ContextMoqParam context.Context + Songs []radio.Song + }{ + ContextMoqParam: contextMoqParam, + Songs: songs, + } + mock.lockUpdate.Lock() + mock.calls.Update = append(mock.calls.Update, callInfo) + mock.lockUpdate.Unlock() + return mock.UpdateFunc(contextMoqParam, songs...) +} + +// UpdateCalls gets all the calls that were made to Update. +// Check the length with: +// +// len(mockedSearchService.UpdateCalls()) +func (mock *SearchServiceMock) UpdateCalls() []struct { + ContextMoqParam context.Context + Songs []radio.Song +} { + var calls []struct { + ContextMoqParam context.Context + Songs []radio.Song + } + mock.lockUpdate.RLock() + calls = mock.calls.Update + mock.lockUpdate.RUnlock() + return calls +} + +// Ensure, that RequestStorageMock does implement radio.RequestStorage. +// If this is not the case, regenerate this file with moq. +var _ radio.RequestStorage = &RequestStorageMock{} + +// RequestStorageMock is a mock implementation of radio.RequestStorage. +// +// func TestSomethingThatUsesRequestStorage(t *testing.T) { +// +// // make and configure a mocked radio.RequestStorage +// mockedRequestStorage := &RequestStorageMock{ +// LastRequestFunc: func(identifier string) (time.Time, error) { +// panic("mock out the LastRequest method") +// }, +// UpdateLastRequestFunc: func(identifier string) error { +// panic("mock out the UpdateLastRequest method") +// }, +// } +// +// // use mockedRequestStorage in code that requires radio.RequestStorage +// // and then make assertions. +// +// } +type RequestStorageMock struct { + // LastRequestFunc mocks the LastRequest method. + LastRequestFunc func(identifier string) (time.Time, error) + + // UpdateLastRequestFunc mocks the UpdateLastRequest method. + UpdateLastRequestFunc func(identifier string) error + + // calls tracks calls to the methods. + calls struct { + // LastRequest holds details about calls to the LastRequest method. + LastRequest []struct { + // Identifier is the identifier argument value. + Identifier string + } + // UpdateLastRequest holds details about calls to the UpdateLastRequest method. + UpdateLastRequest []struct { + // Identifier is the identifier argument value. + Identifier string + } + } + lockLastRequest sync.RWMutex + lockUpdateLastRequest sync.RWMutex +} + +// LastRequest calls LastRequestFunc. +func (mock *RequestStorageMock) LastRequest(identifier string) (time.Time, error) { + if mock.LastRequestFunc == nil { + panic("RequestStorageMock.LastRequestFunc: method is nil but RequestStorage.LastRequest was just called") + } + callInfo := struct { + Identifier string + }{ + Identifier: identifier, + } + mock.lockLastRequest.Lock() + mock.calls.LastRequest = append(mock.calls.LastRequest, callInfo) + mock.lockLastRequest.Unlock() + return mock.LastRequestFunc(identifier) +} + +// LastRequestCalls gets all the calls that were made to LastRequest. +// Check the length with: +// +// len(mockedRequestStorage.LastRequestCalls()) +func (mock *RequestStorageMock) LastRequestCalls() []struct { + Identifier string +} { + var calls []struct { + Identifier string + } + mock.lockLastRequest.RLock() + calls = mock.calls.LastRequest + mock.lockLastRequest.RUnlock() + return calls +} + +// UpdateLastRequest calls UpdateLastRequestFunc. +func (mock *RequestStorageMock) UpdateLastRequest(identifier string) error { + if mock.UpdateLastRequestFunc == nil { + panic("RequestStorageMock.UpdateLastRequestFunc: method is nil but RequestStorage.UpdateLastRequest was just called") + } + callInfo := struct { + Identifier string + }{ + Identifier: identifier, + } + mock.lockUpdateLastRequest.Lock() + mock.calls.UpdateLastRequest = append(mock.calls.UpdateLastRequest, callInfo) + mock.lockUpdateLastRequest.Unlock() + return mock.UpdateLastRequestFunc(identifier) +} + +// UpdateLastRequestCalls gets all the calls that were made to UpdateLastRequest. +// Check the length with: +// +// len(mockedRequestStorage.UpdateLastRequestCalls()) +func (mock *RequestStorageMock) UpdateLastRequestCalls() []struct { + Identifier string +} { + var calls []struct { + Identifier string + } + mock.lockUpdateLastRequest.RLock() + calls = mock.calls.UpdateLastRequest + mock.lockUpdateLastRequest.RUnlock() + return calls +} diff --git a/website/public/search.go b/website/public/search.go index 32ecf7cc..47d414e7 100644 --- a/website/public/search.go +++ b/website/public/search.go @@ -78,6 +78,18 @@ func NewSearchSharedInput(s radio.SearchService, rs radio.RequestStorage, r *htt cd, ok := radio.CalculateCooldown(requestDelay, lastRequest) + // we also use this input if we're making a request, in which case our url + // will be something other than /search that we can't use for the pagination + // logic. We can detect this by looking for a trackid argument and changing + // the url to the expected /search path + uri := r.URL + if r.URL.Query().Has("trackid") { + query := uri.Query() + query.Del("trackid") + uri.RawQuery = query.Encode() + uri.Path = "/search" + } + return &SearchSharedInput{ CSRFTokenInput: csrf.TemplateField(r), Query: query, @@ -86,7 +98,7 @@ func NewSearchSharedInput(s radio.SearchService, rs radio.RequestStorage, r *htt RequestCooldown: cd, Page: shared.NewPagination( page, shared.PageCount(int64(searchResult.TotalHits), searchPageSize), - r.URL, + uri, ), }, nil } diff --git a/website/public/search_test.go b/website/public/search_test.go new file mode 100644 index 00000000..860932cc --- /dev/null +++ b/website/public/search_test.go @@ -0,0 +1,92 @@ +package public + +import ( + "context" + "net/http" + "net/http/httptest" + "regexp" + "strings" + "testing" + "time" + + radio "github.com/R-a-dio/valkyrie" + "github.com/R-a-dio/valkyrie/mocks" + "github.com/R-a-dio/valkyrie/util/secret" + "github.com/go-chi/chi/v5" + "github.com/gorilla/csrf" + "github.com/jxskiss/base62" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewSearchSharedInputURLFix(t *testing.T) { + ss := &mocks.SearchServiceMock{ + SearchFunc: func(ctx context.Context, query string, limit, offset int64) (*radio.SearchResult, error) { + return &radio.SearchResult{}, nil + }, + } + rs := &mocks.RequestStorageMock{ + LastRequestFunc: func(identifier string) (time.Time, error) { + return time.Now().Add(-time.Hour * 12), nil + }, + } + + r := httptest.NewRequest(http.MethodPost, "/v1/request?page=5&q=test&trackid=100", nil) + + input, err := NewSearchSharedInput(ss, rs, r, time.Hour, searchPageSize) + require.NoError(t, err) + require.NotNil(t, input) + + assert.Contains(t, "/search", input.Page.BaseURL()) + assert.NotContains(t, "/v1/request", input.Page.BaseURL()) +} + +func TestCSRFLegacyFix(t *testing.T) { + // this is the regular expression the old (and new) android apps use + // to retrieve the csrf token from the search page + re := regexp.MustCompile("value=\"(\\w+)\"") + + // setup a chi router with the csrf middleware + key, err := secret.NewKey(32) + require.NoError(t, err) + + router := chi.NewRouter() + router.Use(csrf.Protect(key, + csrf.Secure(false), + // regexp above doesn't allow non-alphanumeric which means base64 + // is out of the running, use base62 instead which is alphanumeric + csrf.Encoding(base62.StdEncoding), + )) + router.Get("/search", func(w http.ResponseWriter, r *http.Request) { + // generate the "legacy fix" which is just a HTML comment with + // contents that the app expects, which is a line starting with + // '