Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MSC3860 download redirection behaviour #543

Merged
merged 1 commit into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Exporting MMR's data to Synapse is now possible with `import_to_synapse`. To use it, first run `gdpr_export` or similar.
* Errors encountered during a background task, such as an API-induced export, are exposed as `error_message` in the admin API.
* MMR will follow redirects on federated downloads up to 5 hops.
* S3-backed datastores can have download requests redirected to a public-facing CDN rather than being proxied through MMR. See `publicBaseUrl` under the S3 datastore config.

### Changed

Expand Down
9 changes: 9 additions & 0 deletions api/_responses/redirect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package _responses

type RedirectResponse struct {
ToUrl string
}

func Redirect(url string) *RedirectResponse {
return &RedirectResponse{ToUrl: url}
}
8 changes: 8 additions & 0 deletions api/_routers/98-use-rcontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,14 @@ func (c *RContextRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {

headers := w.Header()

// Check for redirection early
if redirect, isRedirect := res.(*_responses.RedirectResponse); isRedirect {
log.Infof("Replying with result: %T <%s>", res, redirect.ToUrl)
headers.Set("Location", redirect.ToUrl)
r = writeStatusCode(w, r, http.StatusTemporaryRedirect)
return // we're done here
}

// Check for HTML response and reply accordingly
if htmlRes, isHtml := res.(*_responses.HtmlResponse); isHtml {
log.Infof("Replying with result: %T <%d chars of html>", res, len(htmlRes.HTML))
Expand Down
24 changes: 20 additions & 4 deletions api/r0/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/api/_routers"
"github.com/turt2live/matrix-media-repo/datastores"
"github.com/turt2live/matrix-media-repo/pipelines/pipeline_download"
"github.com/turt2live/matrix-media-repo/util"

Expand All @@ -22,6 +23,7 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
mediaId := _routers.GetParam("mediaId", r)
filename := _routers.GetParam("filename", r)
allowRemote := r.URL.Query().Get("allow_remote")
allowRedirect := r.URL.Query().Get("allow_redirect")
timeoutMs := r.URL.Query().Get("timeout_ms")

if !_routers.ServerNameRegex.MatchString(server) {
Expand All @@ -37,16 +39,26 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
downloadRemote = parsedFlag
}

canRedirect := false
if allowRedirect != "" {
parsedFlag, err := strconv.ParseBool(allowRedirect)
if err != nil {
return _responses.BadRequest("allow_redirect flag does not appear to be a boolean")
}
canRedirect = parsedFlag
}

blockFor, err := util.CalcBlockForDuration(timeoutMs)
if err != nil {
return _responses.BadRequest("timeout_ms does not appear to be an integer")
}

rctx = rctx.LogWithFields(logrus.Fields{
"mediaId": mediaId,
"server": server,
"filename": filename,
"allowRemote": downloadRemote,
"mediaId": mediaId,
"server": server,
"filename": filename,
"allowRemote": downloadRemote,
"allowRedirect": canRedirect,
})

if !util.IsGlobalAdmin(user.UserId) && util.IsHostIgnored(server) {
Expand All @@ -57,8 +69,10 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
media, stream, err := pipeline_download.Execute(rctx, server, mediaId, pipeline_download.DownloadOpts{
FetchRemoteIfNeeded: downloadRemote,
BlockForReadUntil: blockFor,
CanRedirect: canRedirect,
})
if err != nil {
var redirect datastores.RedirectError
if errors.Is(err, common.ErrMediaNotFound) {
return _responses.NotFoundError()
} else if errors.Is(err, common.ErrMediaTooLarge) {
Expand All @@ -72,6 +86,8 @@ func DownloadMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.
}
} else if errors.Is(err, common.ErrMediaNotYetUploaded) {
return _responses.NotYetUploaded()
} else if errors.As(err, &redirect) {
return _responses.Redirect(redirect.RedirectUrl)
}
rctx.Log.Error("Unexpected error locating media: ", err)
sentry.CaptureException(err)
Expand Down
22 changes: 19 additions & 3 deletions api/r0/thumbnail.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/api/_routers"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/datastores"
"github.com/turt2live/matrix-media-repo/pipelines/pipeline_download"
"github.com/turt2live/matrix-media-repo/pipelines/pipeline_thumbnail"
"github.com/turt2live/matrix-media-repo/util"
Expand All @@ -23,6 +24,7 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
server := _routers.GetParam("server", r)
mediaId := _routers.GetParam("mediaId", r)
allowRemote := r.URL.Query().Get("allow_remote")
allowRedirect := r.URL.Query().Get("allow_redirect")
timeoutMs := r.URL.Query().Get("timeout_ms")

if !_routers.ServerNameRegex.MatchString(server) {
Expand All @@ -38,15 +40,25 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
downloadRemote = parsedFlag
}

canRedirect := false
if allowRedirect != "" {
parsedFlag, err := strconv.ParseBool(allowRedirect)
if err != nil {
return _responses.BadRequest("allow_redirect flag does not appear to be a boolean")
}
canRedirect = parsedFlag
}

blockFor, err := util.CalcBlockForDuration(timeoutMs)
if err != nil {
return _responses.BadRequest("timeout_ms does not appear to be an integer")
}

rctx = rctx.LogWithFields(logrus.Fields{
"mediaId": mediaId,
"server": server,
"allowRemote": downloadRemote,
"mediaId": mediaId,
"server": server,
"allowRemote": downloadRemote,
"allowRedirect": canRedirect,
})

if !util.IsGlobalAdmin(user.UserId) && util.IsHostIgnored(server) {
Expand Down Expand Up @@ -111,13 +123,15 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
FetchRemoteIfNeeded: downloadRemote,
BlockForReadUntil: blockFor,
RecordOnly: false, // overridden
CanRedirect: canRedirect,
},
Width: width,
Height: height,
Method: method,
Animated: animated,
})
if err != nil {
var redirect datastores.RedirectError
if errors.Is(err, common.ErrMediaNotFound) {
return _responses.NotFoundError()
} else if errors.Is(err, common.ErrMediaTooLarge) {
Expand Down Expand Up @@ -152,6 +166,8 @@ func ThumbnailMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta
TargetDisposition: "infer",
}
}
} else if errors.As(err, &redirect) {
return _responses.Redirect(redirect.RedirectUrl)
}
rctx.Log.Error("Unexpected error locating media: ", err)
sentry.CaptureException(err)
Expand Down
9 changes: 9 additions & 0 deletions config.sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,15 @@ datastores:
# An optional storage class for tuning how the media is stored at s3.
# See https://aws.amazon.com/s3/storage-classes/ for details; uncomment to use.
#storageClass: STANDARD
# When set, if the requesting user/server supports being redirected, and MMR is capable
# of performing that redirection, they will be redirected to the given object location.
# The object ID used in S3 is assumed to be the file name, and will simply be appended.
# It is therefore important to include any trailing slashes or path information. For
# example, an object with ID "hello/world" will get converted to "https://mycdn.example.org/hello/world".
# Note that MMR may not redirect in all cases, even if the client/server requests the
# capability. MMR may still be responsible for bandwidth charges incurred from going to
# the bucket directly.
#publicBaseUrl: "https://mycdn.example.org/"

# Options for controlling archives. Archives are exports of a particular user's content for
# the purpose of GDPR or moving media to a different server.
Expand Down
19 changes: 19 additions & 0 deletions datastores/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package datastores

import (
"errors"
"fmt"
"io"
"os"
"path"
Expand Down Expand Up @@ -35,3 +36,21 @@ func Download(ctx rcontext.RequestContext, ds config.DatastoreConfig, dsFileName

return rsc, err
}

func DownloadOrRedirect(ctx rcontext.RequestContext, ds config.DatastoreConfig, dsFileName string) (io.ReadSeekCloser, error) {
if ds.Type != "s3" {
return Download(ctx, ds, dsFileName)
}

s3c, err := getS3(ds)
if err != nil {
return nil, err
}

if s3c.publicBaseUrl != "" {
metrics.S3Operations.With(prometheus.Labels{"operation": "RedirectGetObject"}).Inc()
return nil, redirect(fmt.Sprintf("%s%s", s3c.publicBaseUrl, dsFileName))
}

return Download(ctx, ds, dsFileName)
}
15 changes: 15 additions & 0 deletions datastores/redirect.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package datastores

import "errors"

type RedirectError struct {
error
RedirectUrl string
}

func redirect(url string) RedirectError {
return RedirectError{
error: errors.New("redirection"),
RedirectUrl: url,
}
}
15 changes: 9 additions & 6 deletions datastores/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ import (
var s3clients = &sync.Map{}

type s3 struct {
client *minio.Client
storageClass string
bucket string
client *minio.Client
storageClass string
bucket string
publicBaseUrl string
}

func ResetS3Clients() {
Expand All @@ -37,6 +38,7 @@ func getS3(ds config.DatastoreConfig) (*s3, error) {
region := ds.Options["region"]
storageClass, hasStorageClass := ds.Options["storageClass"]
useSslStr, hasSsl := ds.Options["ssl"]
publicBaseUrl := ds.Options["publicBaseUrl"]

if !hasStorageClass {
storageClass = "STANDARD"
Expand All @@ -59,9 +61,10 @@ func getS3(ds config.DatastoreConfig) (*s3, error) {
}

s3c := &s3{
client: client,
storageClass: storageClass,
bucket: bucket,
client: client,
storageClass: storageClass,
bucket: bucket,
publicBaseUrl: publicBaseUrl,
}
s3clients.Store(ds.Id, s3c)
return s3c, nil
Expand Down
33 changes: 30 additions & 3 deletions pipelines/_steps/download/open_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"io"

"github.com/turt2live/matrix-media-repo/common/config"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/datastores"
Expand All @@ -12,16 +13,42 @@ import (
)

func OpenStream(ctx rcontext.RequestContext, media *database.Locatable) (io.ReadSeekCloser, error) {
reader, ds, err := doOpenStream(ctx, media)
if err != nil {
return nil, err
}
if reader != nil {
ctx.Log.Debugf("Got %s from cache", media.Sha256Hash)
return readers.NopSeekCloser(reader), nil
}

return datastores.Download(ctx, ds, media.Location)
}

func OpenOrRedirect(ctx rcontext.RequestContext, media *database.Locatable) (io.ReadSeekCloser, error) {
reader, ds, err := doOpenStream(ctx, media)
if err != nil {
return nil, err
}
if reader != nil {
ctx.Log.Debugf("Got %s from cache", media.Sha256Hash)
return readers.NopSeekCloser(reader), nil
}

return datastores.DownloadOrRedirect(ctx, ds, media.Location)
}

func doOpenStream(ctx rcontext.RequestContext, media *database.Locatable) (io.ReadSeekCloser, config.DatastoreConfig, error) {
reader, err := redislib.TryGetMedia(ctx, media.Sha256Hash)
if err != nil || reader != nil {
ctx.Log.Debugf("Got %s from cache", media.Sha256Hash)
return readers.NopSeekCloser(reader), err
return readers.NopSeekCloser(reader), config.DatastoreConfig{}, err
}

ds, ok := datastores.Get(ctx, media.DatastoreId)
if !ok {
return nil, errors.New("unable to locate datastore for media")
return nil, ds, errors.New("unable to locate datastore for media")
}

return datastores.Download(ctx, ds, media.Location)
return nil, ds, nil
}
9 changes: 7 additions & 2 deletions pipelines/pipeline_download/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ type DownloadOpts struct {
FetchRemoteIfNeeded bool
BlockForReadUntil time.Duration
RecordOnly bool
CanRedirect bool
}

func (o DownloadOpts) String() string {
return fmt.Sprintf("f=%t,b=%s,r=%t", o.FetchRemoteIfNeeded, o.BlockForReadUntil.String(), o.RecordOnly)
return fmt.Sprintf("f=%t,b=%s,r=%t,d=%t", o.FetchRemoteIfNeeded, o.BlockForReadUntil.String(), o.RecordOnly, o.CanRedirect)
}

func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts DownloadOpts) (*database.DbMedia, io.ReadCloser, error) {
Expand Down Expand Up @@ -71,7 +72,11 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Do
if opts.RecordOnly {
return nil, nil
}
return download.OpenStream(ctx, record.Locatable)
if opts.CanRedirect {
return download.OpenOrRedirect(ctx, record.Locatable)
} else {
return download.OpenStream(ctx, record.Locatable)
}
}

// Step 4: Media record unknown - download it (if possible)
Expand Down
13 changes: 11 additions & 2 deletions pipelines/pipeline_thumbnail/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,23 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, opts Th
if opts.RecordOnly {
return nil, nil
}
return download.OpenStream(ctx, record.Locatable)
if opts.CanRedirect {
return download.OpenOrRedirect(ctx, record.Locatable)
} else {
return download.OpenStream(ctx, record.Locatable)
}
}

// Step 6: Generate the thumbnail and return that
record, r, err := thumbnails.Generate(ctx, mediaRecord, opts.Width, opts.Height, opts.Method, opts.Animated)
if err != nil {
if !opts.RecordOnly && errors.Is(err, common.ErrMediaDimensionsTooSmall) {
d, err := download.OpenStream(ctx, mediaRecord.Locatable)
var d io.ReadSeekCloser
if opts.CanRedirect {
d, err = download.OpenOrRedirect(ctx, mediaRecord.Locatable)
} else {
d, err = download.OpenStream(ctx, mediaRecord.Locatable)
}
if err != nil {
return nil, err
} else {
Expand Down
Loading