From 26c84cab7117e6e03d7edc2b2727483ef3824466 Mon Sep 17 00:00:00 2001 From: Wessie Date: Tue, 30 Apr 2024 19:46:31 +0100 Subject: [PATCH] website/public: implement CSRF token on search page website/api/v1: refactor search api to reuse code a bit more so that implementing the CSRF token and requestability is possible templates: added SongPair function, which takes a radio.Song and an any to pass to another template --- templates/functions.go | 13 +++++++ website/api/v1/request.go | 71 +++++++++++++++++++++++++-------------- website/api/v1/search.go | 49 +++++++++++++-------------- website/public/search.go | 54 ++++++++++++++++++++++++----- 4 files changed, 127 insertions(+), 60 deletions(-) diff --git a/templates/functions.go b/templates/functions.go index 6920f4b6..56b940c7 100644 --- a/templates/functions.go +++ b/templates/functions.go @@ -32,6 +32,19 @@ var fnMap = map[string]any{ "CalculateSubmissionCooldown": radio.CalculateSubmissionCooldown, "AllUserPermissions": radio.AllUserPermissions, "HasField": HasField, + "SongPair": SongPair, +} + +type SongPairing struct { + *radio.Song + Data any +} + +func SongPair(song radio.Song, data any) SongPairing { + return SongPairing{ + Song: &song, + Data: data, + } } func HasField(v any, name string) bool { diff --git a/website/api/v1/request.go b/website/api/v1/request.go index 36a883a6..f9d39f8b 100644 --- a/website/api/v1/request.go +++ b/website/api/v1/request.go @@ -7,11 +7,48 @@ import ( radio "github.com/R-a-dio/valkyrie" "github.com/R-a-dio/valkyrie/errors" "github.com/R-a-dio/valkyrie/util" + "github.com/R-a-dio/valkyrie/website/public" "github.com/rs/zerolog/hlog" ) func (a *API) PostRequest(w http.ResponseWriter, r *http.Request) { - res := a.postRequest(r) + var message string + + err := a.postRequest(r) + if err != nil { + switch { + case errors.Is(errors.SongCooldown, err): + message = "Song is on cooldown" + case errors.Is(errors.UserCooldown, err): + message = "You can't request yet" + case errors.Is(errors.StreamerNoRequests, err): + message = "Requests are currently disabled" + case errors.Is(errors.InvalidForm, err): + message = "Invalid form in request" + case errors.Is(errors.SongUnknown, err): + message = "Unknown song, how did you get here?" + default: + message = "something broke, report to IRC." + hlog.FromRequest(r).Error().Err(err).Msg("request failed") + } + } + + input, err := public.NewSearchInput( + a.Search, + a.storage.Request(r.Context()), + r, + time.Duration(a.Config.Conf().UserRequestDelay), + ) + if err != nil { + hlog.FromRequest(r).Error().Err(err).Msg("") + return + } + if message == "" { + input.Message = "Thank you for requesting" + input.IsError = true + } else { + input.Message = message + } if !util.IsHTMX(r) { // for non-htmx users we redirect them back to where they came from @@ -21,7 +58,7 @@ func (a *API) PostRequest(w http.ResponseWriter, r *http.Request) { return } - err := a.Templates.Execute(w, r, &res) + err = a.Templates.Execute(w, r, input) if err != nil { hlog.FromRequest(r).Error().Err(err).Msg("template failure") return @@ -43,43 +80,25 @@ func (RequestInput) TemplateName() string { return "request-response" } -func (a *API) postRequest(r *http.Request) RequestInput { - var res RequestInput +func (a *API) postRequest(r *http.Request) error { + const op errors.Op = "website/api/v1/API.postRequest" ctx := r.Context() tid, err := radio.ParseTrackID(r.FormValue("trackid")) if err != nil { - hlog.FromRequest(r).Error().Err(err).Msg("invalid request form") - res.Error = "Invalid Request" - return res + return errors.E(op, err, errors.InvalidForm) } song, err := a.storage.Track(ctx).Get(tid) if err != nil { - hlog.FromRequest(r).Error().Err(err).Msg("invalid request form") - res.Error = "Unknown Song" - return res + return errors.E(op, err, errors.SongUnknown) } - res.Song = *song err = a.streamer.RequestSong(ctx, *song, r.RemoteAddr) if err != nil { - switch { - case errors.Is(errors.SongCooldown, err): - res.Error = "Song is on cooldown" - case errors.Is(errors.UserCooldown, err): - res.Error = "You can't request yet" - case errors.Is(errors.StreamerNoRequests, err): - res.Error = "Requests are disabled" - default: - res.Error = "something broke, report to IRC." - hlog.FromRequest(r).Error().Err(err).Msg("request failed") - } - return res + return err } - res.Song.LastRequested = time.Now() - res.Message = "Thanks for requesting" - return res + return nil } diff --git a/website/api/v1/search.go b/website/api/v1/search.go index 5a9bde0f..e25b59c1 100644 --- a/website/api/v1/search.go +++ b/website/api/v1/search.go @@ -2,51 +2,50 @@ package v1 import ( "net/http" + "time" - radio "github.com/R-a-dio/valkyrie" "github.com/R-a-dio/valkyrie/errors" + "github.com/R-a-dio/valkyrie/website/public" "github.com/rs/zerolog/hlog" ) -type SearchInput struct { - Result *radio.SearchResult -} - -func (SearchInput) TemplateBundle() string { - return "search" -} - -func (SearchInput) TemplateName() string { - return "search-api" -} +const searchPageSize = 50 func (a *API) SearchHTML(w http.ResponseWriter, r *http.Request) { const op errors.Op = "website/api/v1.API.SearchHTML" - err := r.ParseForm() + input, err := public.NewSearchSharedInput( + a.Search, + a.storage.Request(r.Context()), + r, + time.Duration(a.Config.Conf().UserRequestDelay), + searchPageSize, + ) if err != nil { hlog.FromRequest(r).Error().Err(err) return } - res, err := a.Search.Search(r.Context(), r.Form.Get("q"), 50, 0) - if err != nil { - err = errors.E(op, err, errors.InternalServer) - hlog.FromRequest(r).Error().Err(err).Msg("database error") + if len(input.Songs) == 0 { return } - input := SearchInput{ - Result: res, - } - if input.Result.TotalHits == 0 { - return - } - - err = a.Templates.Execute(w, r, input) + err = a.Templates.Execute(w, r, SearchInput{*input}) if err != nil { err = errors.E(op, err, errors.InternalServer) hlog.FromRequest(r).Error().Err(err).Msg("template error") return } } + +type SearchInput struct { + public.SearchSharedInput +} + +func (SearchInput) TemplateName() string { + return "search-api" +} + +func (SearchInput) TemplateBundle() string { + return "search" +} diff --git a/website/public/search.go b/website/public/search.go index 13e87a4d..32ecf7cc 100644 --- a/website/public/search.go +++ b/website/public/search.go @@ -1,6 +1,8 @@ package public import ( + "fmt" + "html/template" "net/http" "time" @@ -8,13 +10,43 @@ import ( "github.com/R-a-dio/valkyrie/errors" "github.com/R-a-dio/valkyrie/website/middleware" "github.com/R-a-dio/valkyrie/website/shared" + "github.com/gorilla/csrf" ) const searchPageSize = 20 type SearchInput struct { middleware.Input + SearchSharedInput + CSRFLegacyFix template.HTML + // IsError indicates if the message given is an error + IsError bool + // Message to show at the top of the page + Message string +} + +func (SearchInput) TemplateBundle() string { + return "search" +} + +func NewSearchInput(s radio.SearchService, rs radio.RequestStorage, r *http.Request, requestDelay time.Duration) (*SearchInput, error) { + const op errors.Op = "website/public.NewSearchInput" + + sharedInput, err := NewSearchSharedInput(s, rs, r, requestDelay, searchPageSize) + if err != nil { + return nil, errors.E(op, err) + } + + return &SearchInput{ + Input: middleware.InputFromRequest(r), + SearchSharedInput: *sharedInput, + CSRFLegacyFix: csrfLegacyFix(r), + }, nil +} + +type SearchSharedInput struct { + CSRFTokenInput template.HTML Query string Songs []radio.Song CanRequest bool @@ -22,11 +54,11 @@ type SearchInput struct { Page *shared.Pagination } -func NewSearchInput(s radio.SearchService, rs radio.RequestStorage, r *http.Request, requestDelay time.Duration) (*SearchInput, error) { - const op errors.Op = "website/public.NewSearchInput" +func NewSearchSharedInput(s radio.SearchService, rs radio.RequestStorage, r *http.Request, requestDelay time.Duration, pageSize int64) (*SearchSharedInput, error) { + const op errors.Op = "website/public.NewSearchSharedInput" ctx := r.Context() - page, offset, err := getPageOffset(r, searchPageSize) + page, offset, err := getPageOffset(r, pageSize) if err != nil { return nil, errors.E(op, err) } @@ -46,8 +78,8 @@ func NewSearchInput(s radio.SearchService, rs radio.RequestStorage, r *http.Requ cd, ok := radio.CalculateCooldown(requestDelay, lastRequest) - return &SearchInput{ - Input: middleware.InputFromRequest(r), + return &SearchSharedInput{ + CSRFTokenInput: csrf.TemplateField(r), Query: query, Songs: searchResult.Songs, CanRequest: ok, @@ -59,10 +91,6 @@ func NewSearchInput(s radio.SearchService, rs radio.RequestStorage, r *http.Requ }, nil } -func (SearchInput) TemplateBundle() string { - return "search" -} - func (s State) GetSearch(w http.ResponseWriter, r *http.Request) { input, err := NewSearchInput( s.Search, @@ -81,3 +109,11 @@ func (s State) GetSearch(w http.ResponseWriter, r *http.Request) { return } } + +func csrfLegacyFix(r *http.Request) template.HTML { + return template.HTML(fmt.Sprintf(` + + `, csrf.TemplateField(r))) +}