From ea2a1d03a0df3734fea07e80def5f8be4c35add7 Mon Sep 17 00:00:00 2001 From: Wessie Date: Wed, 3 Apr 2024 23:06:43 +0100 Subject: [PATCH] website: change ServeFileFS to ServeContent Due to how io/fs.FS expects to have a "root" when creating the FS, mapping it from an afero.OsFs does not work. So instead we open the file ourselves and pass it to ServeContent instead. --- util/secret/secret.go | 2 +- util/util.go | 54 ++++++++++++++ website/admin/pending.go | 15 ++-- website/api/v1/song.go | 22 +++++- website/api/v1/song_test.go | 136 ++++++++++++++++++++++++++++++++++++ website/public/faves.go | 3 +- 6 files changed, 223 insertions(+), 9 deletions(-) create mode 100644 website/api/v1/song_test.go diff --git a/util/secret/secret.go b/util/secret/secret.go index f4719230..1dba4b9f 100644 --- a/util/secret/secret.go +++ b/util/secret/secret.go @@ -10,7 +10,7 @@ import ( const keySize = 256 func NewSecretWithKey(length int, key []byte) Secret { - length = min(length, MinLength) + length = max(length, MinLength) return &secret{length, key, time.Now} } diff --git a/util/util.go b/util/util.go index 632a322e..44425707 100644 --- a/util/util.go +++ b/util/util.go @@ -2,9 +2,12 @@ package util import ( "context" + "fmt" + "mime" "net/http" "net/url" "path/filepath" + "strings" "sync/atomic" "time" @@ -12,6 +15,17 @@ import ( "github.com/rs/zerolog" ) +func init() { + must := func(err error) { + if err != nil { + panic(err) + } + } + must(mime.AddExtensionType(".opus", "audio/ogg")) + must(mime.AddExtensionType(".mp3", "audio/mpeg")) + must(mime.AddExtensionType(".flac", "audio/flac")) +} + // IsHTMX checks if a request was made by HTMX through the Hx-Request header func IsHTMX(r *http.Request) bool { return r.Header.Get("Hx-Request") == "true" @@ -37,6 +51,46 @@ func AbsolutePath(dir string, path string) string { return filepath.Join(dir, path) } +const headerContentDisposition = "Content-Disposition" + +func AddContentDispositionSong(w http.ResponseWriter, metadata, filename string) { + filename = metadata + filepath.Ext(filename) + AddContentDisposition(w, filename) +} + +var headerReplacer = strings.NewReplacer( + "\r", "", "\n", "", // newlines + "+", "%20", // spaces from the query escape +) + +var rfc2616 = strings.NewReplacer( + `\`, `\\`, // escape character + `"`, `\"`, // quotes +) + +func AddContentDisposition(w http.ResponseWriter, filename string) { + disposition := "attachment; " + makeHeader(filename) + w.Header().Set(headerContentDisposition, disposition) + // also add a content-type header if we can get a mimetype + ct := mime.TypeByExtension(filepath.Ext(filename)) + if ct != "" { + w.Header().Set("Content-Type", ct) + } +} + +func makeHeader(filename string) string { + // For some reason Go doesn't provide access to the internal percent + // encoding routines, meaning we have to do this to get a fully + // percent-encoded string including spaces as %20. + encoded := url.QueryEscape(filename) + encoded = headerReplacer.Replace(encoded) + // RFC2616 quoted string encoded + escaped := rfc2616.Replace(filename) + // RFC5987 regular and extended header value encoding + disposition := fmt.Sprintf(`filename="%s"; filename*=UTF-8''%s`, escaped, encoded) + return disposition +} + type StreamFn[T any] func(context.Context) (eventstream.Stream[T], error) type StreamCallbackFn[T any] func(context.Context, T) diff --git a/website/admin/pending.go b/website/admin/pending.go index 63883809..92bfb41b 100644 --- a/website/admin/pending.go +++ b/website/admin/pending.go @@ -18,7 +18,6 @@ import ( "github.com/go-chi/chi/v5" "github.com/rs/xid" "github.com/rs/zerolog/hlog" - "github.com/spf13/afero" ) type PendingInput struct { @@ -85,8 +84,15 @@ func (s *State) GetPendingSong(w http.ResponseWriter, r *http.Request) { // if we want the audio file, send that back if r.FormValue("spectrum") == "" { - w.Header().Set("Content-Disposition", "attachment") - http.ServeFileFS(w, r, afero.NewIOFS(s.FS), path) + f, err := s.FS.Open(path) + if err != nil { + hlog.FromRequest(r).Error().Err(err).Msg("fs failure") + return + } + defer f.Close() + + util.AddContentDispositionSong(w, song.Metadata(), song.FilePath) + http.ServeContent(w, r, "", time.Now(), f) return } @@ -98,7 +104,8 @@ func (s *State) GetPendingSong(w http.ResponseWriter, r *http.Request) { } defer os.Remove(specPath) - http.ServeFileFS(w, r, afero.NewIOFS(s.FS), specPath) + // TODO: use in-memory file + http.ServeFile(w, r, specPath) } func (s *State) GetPending(w http.ResponseWriter, r *http.Request) { diff --git a/website/api/v1/song.go b/website/api/v1/song.go index 4e8cb560..0c9188af 100644 --- a/website/api/v1/song.go +++ b/website/api/v1/song.go @@ -2,11 +2,13 @@ package v1 import ( "net/http" + "os" + "time" radio "github.com/R-a-dio/valkyrie" + "github.com/R-a-dio/valkyrie/errors" "github.com/R-a-dio/valkyrie/util" "github.com/rs/zerolog/hlog" - "github.com/spf13/afero" ) func (a *API) GetSong(w http.ResponseWriter, r *http.Request) { @@ -24,7 +26,7 @@ func (a *API) GetSong(w http.ResponseWriter, r *http.Request) { song, err := a.storage.Track(r.Context()).Get(tid) if err != nil { hlog.FromRequest(r).Error().Err(err) - http.Error(w, "invalid id", http.StatusNotFound) + http.Error(w, "unknown id", http.StatusNotFound) return } @@ -35,5 +37,19 @@ func (a *API) GetSong(w http.ResponseWriter, r *http.Request) { } path := util.AbsolutePath(a.Config.Conf().MusicPath, song.FilePath) - http.ServeFileFS(w, r, afero.NewIOFS(a.fs), path) + + f, err := a.fs.Open(path) + if err != nil { + status := http.StatusInternalServerError + if errors.IsE(err, os.ErrNotExist) { + status = http.StatusNotFound + } + hlog.FromRequest(r).Error().Err(err) + http.Error(w, http.StatusText(status), status) + return + } + defer f.Close() + + util.AddContentDispositionSong(w, song.Metadata, song.FilePath) + http.ServeContent(w, r, "", time.Now(), f) } diff --git a/website/api/v1/song_test.go b/website/api/v1/song_test.go new file mode 100644 index 00000000..a00cc9e5 --- /dev/null +++ b/website/api/v1/song_test.go @@ -0,0 +1,136 @@ +package v1 + +import ( + "context" + "encoding/base64" + "io" + "math/rand" + "net/http" + "net/url" + "os" + "path/filepath" + "testing" + + radio "github.com/R-a-dio/valkyrie" + "github.com/R-a-dio/valkyrie/config" + "github.com/R-a-dio/valkyrie/errors" + "github.com/R-a-dio/valkyrie/mocks" + "github.com/R-a-dio/valkyrie/util/secret" + "github.com/rs/zerolog" + "github.com/spf13/afero" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testAPI struct { + ctx context.Context + storageMock *mocks.StorageServiceMock + trackMock *mocks.TrackStorageMock + + GetArg radio.TrackID + GetRet *radio.Song + GetErr error +} + +func newTestAPI(t *testing.T) (*testAPI, *API) { + var api testAPI + var err error + + ctx := context.Background() + api.ctx = zerolog.New(os.Stdout).WithContext(ctx) + + cfg, err := config.LoadFile() + require.NoError(t, err) + + songSecret, err := secret.NewSecret(secret.SongLength) + require.NoError(t, err) + + api.trackMock = &mocks.TrackStorageMock{ + GetFunc: func(trackID radio.TrackID) (*radio.Song, error) { + if trackID == api.GetArg { + return api.GetRet, api.GetErr + } + return nil, errors.E(errors.SongUnknown) + }, + } + api.storageMock = &mocks.StorageServiceMock{ + TrackFunc: func(contextMoqParam context.Context) radio.TrackStorage { + return api.trackMock + }, + } + + fs := afero.NewMemMapFs() + + return &api, &API{ + storage: api.storageMock, + Config: cfg, + songSecret: songSecret, + fs: fs, + } +} + +func createSong(t *testing.T, api *API, id radio.TrackID, filename, metadata string) (*radio.Song, string) { + song := &radio.Song{ + Metadata: metadata, + DatabaseTrack: &radio.DatabaseTrack{ + TrackID: id, + FilePath: filename, + }, + } + song.Hydrate() + + data := make([]byte, 128) + _, err := io.ReadFull(rand.New(rand.NewSource(42)), data) + require.NoError(t, err) + sdata := base64.URLEncoding.EncodeToString(data) + + fullPath := filepath.Join(api.Config.Conf().MusicPath, filename) + require.NoError(t, afero.WriteFile(api.fs, fullPath, []byte(sdata), 0775)) + return song, sdata +} + +func TestGetSong(t *testing.T) { + var data string + tapi, api := newTestAPI(t) + + tapi.GetArg = 50 + tapi.GetRet, data = createSong(t, api, tapi.GetArg, "random.mp3", "testing - hello world") + + createValues := func(api *API, song *radio.Song) url.Values { + values := url.Values{} + values.Set("key", api.songSecret.Get(song.Hash[:])) + values.Set("id", song.TrackID.String()) + return values + } + t.Run("success", func(t *testing.T) { + assert.HTTPStatusCode(t, api.GetSong, http.MethodGet, "/song", + createValues(api, tapi.GetRet), + http.StatusOK) + assert.HTTPBodyContains(t, api.GetSong, http.MethodGet, "/song", + createValues(api, tapi.GetRet), data) + }) + t.Run("missing id", func(t *testing.T) { + values := createValues(api, tapi.GetRet) + values.Del("id") + assert.HTTPStatusCode(t, api.GetSong, http.MethodGet, "/song", values, http.StatusBadRequest) + assert.HTTPBodyContains(t, api.GetSong, http.MethodGet, "/song", values, "missing") + }) + t.Run("invalid id", func(t *testing.T) { + values := createValues(api, tapi.GetRet) + values.Set("id", "this is not a number") + assert.HTTPStatusCode(t, api.GetSong, http.MethodGet, "/song", values, http.StatusBadRequest) + assert.HTTPBodyContains(t, api.GetSong, http.MethodGet, "/song", values, "invalid") + }) + t.Run("unknown id", func(t *testing.T) { + values := createValues(api, tapi.GetRet) + values.Set("id", "100") + assert.HTTPStatusCode(t, api.GetSong, http.MethodGet, "/song", values, http.StatusNotFound) + assert.HTTPBodyContains(t, api.GetSong, http.MethodGet, "/song", values, "unknown") + }) + t.Run("invalid key", func(t *testing.T) { + values := createValues(api, tapi.GetRet) + values.Set("key", "randomdata") + assert.HTTPStatusCode(t, api.GetSong, http.MethodGet, "/song", values, http.StatusUnauthorized) + assert.HTTPBodyContains(t, api.GetSong, http.MethodGet, "/song", values, "invalid key") + }) +} diff --git a/website/public/faves.go b/website/public/faves.go index e7c21a5f..916535f8 100644 --- a/website/public/faves.go +++ b/website/public/faves.go @@ -6,6 +6,7 @@ import ( "net/http" radio "github.com/R-a-dio/valkyrie" + "github.com/R-a-dio/valkyrie/util" "github.com/R-a-dio/valkyrie/website/middleware" "github.com/R-a-dio/valkyrie/website/shared" "github.com/go-chi/chi/v5" @@ -71,7 +72,7 @@ func (s State) GetFaves(w http.ResponseWriter, r *http.Request) { // so we need to support that for old users if r.FormValue("dl") != "" { w.Header().Set("Content-Type", "application/json") - w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s_faves.json", input.Nickname)) + util.AddContentDisposition(w, fmt.Sprintf("%s_faves.json", input.Nickname)) err := json.NewEncoder(w).Encode(NewFaveDownload(input.Faves)) if err != nil { s.errorHandler(w, r, err)