Skip to content

Commit

Permalink
website/public: implement CSRF token on search page
Browse files Browse the repository at this point in the history
website/api/v1: refactor search api to reuse code a bit more so that
implementing the CSRF token and requestability is possible

templates: added SongPair function, which takes a radio.Song and an any
to pass to another template
  • Loading branch information
Wessie committed Apr 30, 2024
1 parent a80a5e2 commit 26c84ca
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 60 deletions.
13 changes: 13 additions & 0 deletions templates/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,19 @@ var fnMap = map[string]any{
"CalculateSubmissionCooldown": radio.CalculateSubmissionCooldown,
"AllUserPermissions": radio.AllUserPermissions,
"HasField": HasField,
"SongPair": SongPair,
}

type SongPairing struct {
*radio.Song
Data any
}

func SongPair(song radio.Song, data any) SongPairing {
return SongPairing{
Song: &song,
Data: data,
}
}

func HasField(v any, name string) bool {
Expand Down
71 changes: 45 additions & 26 deletions website/api/v1/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,48 @@ import (
radio "github.com/R-a-dio/valkyrie"
"github.com/R-a-dio/valkyrie/errors"
"github.com/R-a-dio/valkyrie/util"
"github.com/R-a-dio/valkyrie/website/public"
"github.com/rs/zerolog/hlog"
)

func (a *API) PostRequest(w http.ResponseWriter, r *http.Request) {
res := a.postRequest(r)
var message string

err := a.postRequest(r)
if err != nil {
switch {
case errors.Is(errors.SongCooldown, err):
message = "Song is on cooldown"
case errors.Is(errors.UserCooldown, err):
message = "You can't request yet"
case errors.Is(errors.StreamerNoRequests, err):
message = "Requests are currently disabled"
case errors.Is(errors.InvalidForm, err):
message = "Invalid form in request"
case errors.Is(errors.SongUnknown, err):
message = "Unknown song, how did you get here?"
default:
message = "something broke, report to IRC."
hlog.FromRequest(r).Error().Err(err).Msg("request failed")
}
}

input, err := public.NewSearchInput(
a.Search,
a.storage.Request(r.Context()),
r,
time.Duration(a.Config.Conf().UserRequestDelay),
)
if err != nil {
hlog.FromRequest(r).Error().Err(err).Msg("")
return
}
if message == "" {
input.Message = "Thank you for requesting"
input.IsError = true
} else {
input.Message = message
}

if !util.IsHTMX(r) {
// for non-htmx users we redirect them back to where they came from
Expand All @@ -21,7 +58,7 @@ func (a *API) PostRequest(w http.ResponseWriter, r *http.Request) {
return
}

err := a.Templates.Execute(w, r, &res)
err = a.Templates.Execute(w, r, input)
if err != nil {
hlog.FromRequest(r).Error().Err(err).Msg("template failure")
return
Expand All @@ -43,43 +80,25 @@ func (RequestInput) TemplateName() string {
return "request-response"
}

func (a *API) postRequest(r *http.Request) RequestInput {
var res RequestInput
func (a *API) postRequest(r *http.Request) error {
const op errors.Op = "website/api/v1/API.postRequest"

ctx := r.Context()

tid, err := radio.ParseTrackID(r.FormValue("trackid"))
if err != nil {
hlog.FromRequest(r).Error().Err(err).Msg("invalid request form")
res.Error = "Invalid Request"
return res
return errors.E(op, err, errors.InvalidForm)
}

song, err := a.storage.Track(ctx).Get(tid)
if err != nil {
hlog.FromRequest(r).Error().Err(err).Msg("invalid request form")
res.Error = "Unknown Song"
return res
return errors.E(op, err, errors.SongUnknown)
}
res.Song = *song

err = a.streamer.RequestSong(ctx, *song, r.RemoteAddr)
if err != nil {
switch {
case errors.Is(errors.SongCooldown, err):
res.Error = "Song is on cooldown"
case errors.Is(errors.UserCooldown, err):
res.Error = "You can't request yet"
case errors.Is(errors.StreamerNoRequests, err):
res.Error = "Requests are disabled"
default:
res.Error = "something broke, report to IRC."
hlog.FromRequest(r).Error().Err(err).Msg("request failed")
}
return res
return err
}

res.Song.LastRequested = time.Now()
res.Message = "Thanks for requesting"
return res
return nil
}
49 changes: 24 additions & 25 deletions website/api/v1/search.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,51 +2,50 @@ package v1

import (
"net/http"
"time"

radio "github.com/R-a-dio/valkyrie"
"github.com/R-a-dio/valkyrie/errors"
"github.com/R-a-dio/valkyrie/website/public"
"github.com/rs/zerolog/hlog"
)

type SearchInput struct {
Result *radio.SearchResult
}

func (SearchInput) TemplateBundle() string {
return "search"
}

func (SearchInput) TemplateName() string {
return "search-api"
}
const searchPageSize = 50

func (a *API) SearchHTML(w http.ResponseWriter, r *http.Request) {
const op errors.Op = "website/api/v1.API.SearchHTML"

err := r.ParseForm()
input, err := public.NewSearchSharedInput(
a.Search,
a.storage.Request(r.Context()),
r,
time.Duration(a.Config.Conf().UserRequestDelay),
searchPageSize,
)
if err != nil {
hlog.FromRequest(r).Error().Err(err)
return
}

res, err := a.Search.Search(r.Context(), r.Form.Get("q"), 50, 0)
if err != nil {
err = errors.E(op, err, errors.InternalServer)
hlog.FromRequest(r).Error().Err(err).Msg("database error")
if len(input.Songs) == 0 {
return
}
input := SearchInput{
Result: res,
}

if input.Result.TotalHits == 0 {
return
}

err = a.Templates.Execute(w, r, input)
err = a.Templates.Execute(w, r, SearchInput{*input})
if err != nil {
err = errors.E(op, err, errors.InternalServer)
hlog.FromRequest(r).Error().Err(err).Msg("template error")
return
}
}

type SearchInput struct {
public.SearchSharedInput
}

func (SearchInput) TemplateName() string {
return "search-api"
}

func (SearchInput) TemplateBundle() string {
return "search"
}
54 changes: 45 additions & 9 deletions website/public/search.go
Original file line number Diff line number Diff line change
@@ -1,32 +1,64 @@
package public

import (
"fmt"
"html/template"
"net/http"
"time"

radio "github.com/R-a-dio/valkyrie"
"github.com/R-a-dio/valkyrie/errors"
"github.com/R-a-dio/valkyrie/website/middleware"
"github.com/R-a-dio/valkyrie/website/shared"
"github.com/gorilla/csrf"
)

const searchPageSize = 20

type SearchInput struct {
middleware.Input
SearchSharedInput
CSRFLegacyFix template.HTML

// IsError indicates if the message given is an error
IsError bool
// Message to show at the top of the page
Message string
}

func (SearchInput) TemplateBundle() string {
return "search"
}

func NewSearchInput(s radio.SearchService, rs radio.RequestStorage, r *http.Request, requestDelay time.Duration) (*SearchInput, error) {
const op errors.Op = "website/public.NewSearchInput"

sharedInput, err := NewSearchSharedInput(s, rs, r, requestDelay, searchPageSize)
if err != nil {
return nil, errors.E(op, err)
}

return &SearchInput{
Input: middleware.InputFromRequest(r),
SearchSharedInput: *sharedInput,
CSRFLegacyFix: csrfLegacyFix(r),
}, nil
}

type SearchSharedInput struct {
CSRFTokenInput template.HTML
Query string
Songs []radio.Song
CanRequest bool
RequestCooldown time.Duration
Page *shared.Pagination
}

func NewSearchInput(s radio.SearchService, rs radio.RequestStorage, r *http.Request, requestDelay time.Duration) (*SearchInput, error) {
const op errors.Op = "website/public.NewSearchInput"
func NewSearchSharedInput(s radio.SearchService, rs radio.RequestStorage, r *http.Request, requestDelay time.Duration, pageSize int64) (*SearchSharedInput, error) {
const op errors.Op = "website/public.NewSearchSharedInput"
ctx := r.Context()

page, offset, err := getPageOffset(r, searchPageSize)
page, offset, err := getPageOffset(r, pageSize)
if err != nil {
return nil, errors.E(op, err)
}
Expand All @@ -46,8 +78,8 @@ func NewSearchInput(s radio.SearchService, rs radio.RequestStorage, r *http.Requ

cd, ok := radio.CalculateCooldown(requestDelay, lastRequest)

return &SearchInput{
Input: middleware.InputFromRequest(r),
return &SearchSharedInput{
CSRFTokenInput: csrf.TemplateField(r),
Query: query,
Songs: searchResult.Songs,
CanRequest: ok,
Expand All @@ -59,10 +91,6 @@ func NewSearchInput(s radio.SearchService, rs radio.RequestStorage, r *http.Requ
}, nil
}

func (SearchInput) TemplateBundle() string {
return "search"
}

func (s State) GetSearch(w http.ResponseWriter, r *http.Request) {
input, err := NewSearchInput(
s.Search,
Expand All @@ -81,3 +109,11 @@ func (s State) GetSearch(w http.ResponseWriter, r *http.Request) {
return
}
}

func csrfLegacyFix(r *http.Request) template.HTML {
return template.HTML(fmt.Sprintf(`
<!--
<form %s
-->
`, csrf.TemplateField(r)))
}

0 comments on commit 26c84ca

Please sign in to comment.