Skip to content

Commit

Permalink
website: change ServeFileFS to ServeContent
Browse files Browse the repository at this point in the history
Due to how io/fs.FS expects to have a "root" when creating the FS,
mapping it from an afero.OsFs does not work. So instead we open
the file ourselves and pass it to ServeContent instead.
  • Loading branch information
Wessie committed Apr 3, 2024
1 parent 605c245 commit ea2a1d0
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 9 deletions.
2 changes: 1 addition & 1 deletion util/secret/secret.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
const keySize = 256

func NewSecretWithKey(length int, key []byte) Secret {
length = min(length, MinLength)
length = max(length, MinLength)
return &secret{length, key, time.Now}
}

Expand Down
54 changes: 54 additions & 0 deletions util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,30 @@ package util

import (
"context"
"fmt"
"mime"
"net/http"
"net/url"
"path/filepath"
"strings"
"sync/atomic"
"time"

"github.com/R-a-dio/valkyrie/util/eventstream"
"github.com/rs/zerolog"
)

func init() {
must := func(err error) {
if err != nil {
panic(err)
}
}
must(mime.AddExtensionType(".opus", "audio/ogg"))
must(mime.AddExtensionType(".mp3", "audio/mpeg"))
must(mime.AddExtensionType(".flac", "audio/flac"))
}

// IsHTMX checks if a request was made by HTMX through the Hx-Request header
func IsHTMX(r *http.Request) bool {
return r.Header.Get("Hx-Request") == "true"
Expand All @@ -37,6 +51,46 @@ func AbsolutePath(dir string, path string) string {
return filepath.Join(dir, path)
}

const headerContentDisposition = "Content-Disposition"

func AddContentDispositionSong(w http.ResponseWriter, metadata, filename string) {
filename = metadata + filepath.Ext(filename)
AddContentDisposition(w, filename)
}

var headerReplacer = strings.NewReplacer(
"\r", "", "\n", "", // newlines
"+", "%20", // spaces from the query escape
)

var rfc2616 = strings.NewReplacer(
`\`, `\\`, // escape character
`"`, `\"`, // quotes
)

func AddContentDisposition(w http.ResponseWriter, filename string) {
disposition := "attachment; " + makeHeader(filename)
w.Header().Set(headerContentDisposition, disposition)
// also add a content-type header if we can get a mimetype
ct := mime.TypeByExtension(filepath.Ext(filename))
if ct != "" {
w.Header().Set("Content-Type", ct)
}
}

func makeHeader(filename string) string {
// For some reason Go doesn't provide access to the internal percent
// encoding routines, meaning we have to do this to get a fully
// percent-encoded string including spaces as %20.
encoded := url.QueryEscape(filename)
encoded = headerReplacer.Replace(encoded)
// RFC2616 quoted string encoded
escaped := rfc2616.Replace(filename)
// RFC5987 regular and extended header value encoding
disposition := fmt.Sprintf(`filename="%s"; filename*=UTF-8''%s`, escaped, encoded)
return disposition
}

type StreamFn[T any] func(context.Context) (eventstream.Stream[T], error)

type StreamCallbackFn[T any] func(context.Context, T)
Expand Down
15 changes: 11 additions & 4 deletions website/admin/pending.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"github.com/go-chi/chi/v5"
"github.com/rs/xid"
"github.com/rs/zerolog/hlog"
"github.com/spf13/afero"
)

type PendingInput struct {
Expand Down Expand Up @@ -85,8 +84,15 @@ func (s *State) GetPendingSong(w http.ResponseWriter, r *http.Request) {

// if we want the audio file, send that back
if r.FormValue("spectrum") == "" {
w.Header().Set("Content-Disposition", "attachment")
http.ServeFileFS(w, r, afero.NewIOFS(s.FS), path)
f, err := s.FS.Open(path)
if err != nil {
hlog.FromRequest(r).Error().Err(err).Msg("fs failure")
return
}
defer f.Close()

util.AddContentDispositionSong(w, song.Metadata(), song.FilePath)
http.ServeContent(w, r, "", time.Now(), f)
return
}

Expand All @@ -98,7 +104,8 @@ func (s *State) GetPendingSong(w http.ResponseWriter, r *http.Request) {
}
defer os.Remove(specPath)

http.ServeFileFS(w, r, afero.NewIOFS(s.FS), specPath)
// TODO: use in-memory file
http.ServeFile(w, r, specPath)
}

func (s *State) GetPending(w http.ResponseWriter, r *http.Request) {
Expand Down
22 changes: 19 additions & 3 deletions website/api/v1/song.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package v1

import (
"net/http"
"os"
"time"

radio "github.com/R-a-dio/valkyrie"
"github.com/R-a-dio/valkyrie/errors"
"github.com/R-a-dio/valkyrie/util"
"github.com/rs/zerolog/hlog"
"github.com/spf13/afero"
)

func (a *API) GetSong(w http.ResponseWriter, r *http.Request) {
Expand All @@ -24,7 +26,7 @@ func (a *API) GetSong(w http.ResponseWriter, r *http.Request) {
song, err := a.storage.Track(r.Context()).Get(tid)
if err != nil {
hlog.FromRequest(r).Error().Err(err)
http.Error(w, "invalid id", http.StatusNotFound)
http.Error(w, "unknown id", http.StatusNotFound)
return
}

Expand All @@ -35,5 +37,19 @@ func (a *API) GetSong(w http.ResponseWriter, r *http.Request) {
}

path := util.AbsolutePath(a.Config.Conf().MusicPath, song.FilePath)
http.ServeFileFS(w, r, afero.NewIOFS(a.fs), path)

f, err := a.fs.Open(path)
if err != nil {
status := http.StatusInternalServerError
if errors.IsE(err, os.ErrNotExist) {
status = http.StatusNotFound
}
hlog.FromRequest(r).Error().Err(err)
http.Error(w, http.StatusText(status), status)
return
}
defer f.Close()

util.AddContentDispositionSong(w, song.Metadata, song.FilePath)
http.ServeContent(w, r, "", time.Now(), f)
}
136 changes: 136 additions & 0 deletions website/api/v1/song_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package v1

import (
"context"
"encoding/base64"
"io"
"math/rand"
"net/http"
"net/url"
"os"
"path/filepath"
"testing"

radio "github.com/R-a-dio/valkyrie"
"github.com/R-a-dio/valkyrie/config"
"github.com/R-a-dio/valkyrie/errors"
"github.com/R-a-dio/valkyrie/mocks"
"github.com/R-a-dio/valkyrie/util/secret"
"github.com/rs/zerolog"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type testAPI struct {
ctx context.Context
storageMock *mocks.StorageServiceMock
trackMock *mocks.TrackStorageMock

GetArg radio.TrackID
GetRet *radio.Song
GetErr error
}

func newTestAPI(t *testing.T) (*testAPI, *API) {
var api testAPI
var err error

ctx := context.Background()
api.ctx = zerolog.New(os.Stdout).WithContext(ctx)

cfg, err := config.LoadFile()
require.NoError(t, err)

songSecret, err := secret.NewSecret(secret.SongLength)
require.NoError(t, err)

api.trackMock = &mocks.TrackStorageMock{
GetFunc: func(trackID radio.TrackID) (*radio.Song, error) {
if trackID == api.GetArg {
return api.GetRet, api.GetErr
}
return nil, errors.E(errors.SongUnknown)
},
}
api.storageMock = &mocks.StorageServiceMock{
TrackFunc: func(contextMoqParam context.Context) radio.TrackStorage {
return api.trackMock
},
}

fs := afero.NewMemMapFs()

return &api, &API{
storage: api.storageMock,
Config: cfg,
songSecret: songSecret,
fs: fs,
}
}

func createSong(t *testing.T, api *API, id radio.TrackID, filename, metadata string) (*radio.Song, string) {
song := &radio.Song{
Metadata: metadata,
DatabaseTrack: &radio.DatabaseTrack{
TrackID: id,
FilePath: filename,
},
}
song.Hydrate()

data := make([]byte, 128)
_, err := io.ReadFull(rand.New(rand.NewSource(42)), data)
require.NoError(t, err)
sdata := base64.URLEncoding.EncodeToString(data)

fullPath := filepath.Join(api.Config.Conf().MusicPath, filename)
require.NoError(t, afero.WriteFile(api.fs, fullPath, []byte(sdata), 0775))
return song, sdata
}

func TestGetSong(t *testing.T) {
var data string
tapi, api := newTestAPI(t)

tapi.GetArg = 50
tapi.GetRet, data = createSong(t, api, tapi.GetArg, "random.mp3", "testing - hello world")

createValues := func(api *API, song *radio.Song) url.Values {
values := url.Values{}
values.Set("key", api.songSecret.Get(song.Hash[:]))
values.Set("id", song.TrackID.String())
return values
}
t.Run("success", func(t *testing.T) {
assert.HTTPStatusCode(t, api.GetSong, http.MethodGet, "/song",
createValues(api, tapi.GetRet),
http.StatusOK)
assert.HTTPBodyContains(t, api.GetSong, http.MethodGet, "/song",
createValues(api, tapi.GetRet), data)
})
t.Run("missing id", func(t *testing.T) {
values := createValues(api, tapi.GetRet)
values.Del("id")
assert.HTTPStatusCode(t, api.GetSong, http.MethodGet, "/song", values, http.StatusBadRequest)
assert.HTTPBodyContains(t, api.GetSong, http.MethodGet, "/song", values, "missing")
})
t.Run("invalid id", func(t *testing.T) {
values := createValues(api, tapi.GetRet)
values.Set("id", "this is not a number")
assert.HTTPStatusCode(t, api.GetSong, http.MethodGet, "/song", values, http.StatusBadRequest)
assert.HTTPBodyContains(t, api.GetSong, http.MethodGet, "/song", values, "invalid")
})
t.Run("unknown id", func(t *testing.T) {
values := createValues(api, tapi.GetRet)
values.Set("id", "100")
assert.HTTPStatusCode(t, api.GetSong, http.MethodGet, "/song", values, http.StatusNotFound)
assert.HTTPBodyContains(t, api.GetSong, http.MethodGet, "/song", values, "unknown")
})
t.Run("invalid key", func(t *testing.T) {
values := createValues(api, tapi.GetRet)
values.Set("key", "randomdata")
assert.HTTPStatusCode(t, api.GetSong, http.MethodGet, "/song", values, http.StatusUnauthorized)
assert.HTTPBodyContains(t, api.GetSong, http.MethodGet, "/song", values, "invalid key")
})
}
3 changes: 2 additions & 1 deletion website/public/faves.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"

radio "github.com/R-a-dio/valkyrie"
"github.com/R-a-dio/valkyrie/util"
"github.com/R-a-dio/valkyrie/website/middleware"
"github.com/R-a-dio/valkyrie/website/shared"
"github.com/go-chi/chi/v5"
Expand Down Expand Up @@ -71,7 +72,7 @@ func (s State) GetFaves(w http.ResponseWriter, r *http.Request) {
// so we need to support that for old users
if r.FormValue("dl") != "" {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s_faves.json", input.Nickname))
util.AddContentDisposition(w, fmt.Sprintf("%s_faves.json", input.Nickname))
err := json.NewEncoder(w).Encode(NewFaveDownload(input.Faves))
if err != nil {
s.errorHandler(w, r, err)
Expand Down

0 comments on commit ea2a1d0

Please sign in to comment.