Skip to content

Commit

Permalink
Fix remote downloads being trimmed when downloads can be bigger than …
Browse files Browse the repository at this point in the history
…uploads (#542)
  • Loading branch information
turt2live authored Jan 14, 2024
1 parent 3915d58 commit 960691a
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Requests requiring authentication, but lack a provided access token, will return HTTP 401 instead of HTTP 500 now.
* Downloads when using a self-hosted MinIO instance are no longer slower than expected.
* The `DELETE /_matrix/media/unstable/admin/export/:exportId` endpoint has been reinstated as described.
* If a server's `downloads.maxSize` is greater than the `uploads.maxSize`, remote media is no longer cut off at `uploads.maxSize`. The media will instead be downloaded at `downloads.maxSize` and error if greater.

## [1.3.3] - October 31, 2023

Expand Down
10 changes: 8 additions & 2 deletions pipelines/_steps/download/try_download.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/turt2live/matrix-media-repo/pipelines/_steps/datastore_op"
"github.com/turt2live/matrix-media-repo/pool"
"github.com/turt2live/matrix-media-repo/util"
"github.com/turt2live/matrix-media-repo/util/readers"
)

type downloadResult struct {
Expand Down Expand Up @@ -79,11 +80,16 @@ func TryDownload(ctx rcontext.RequestContext, origin string, mediaId string) (*d
}
}

if contentLength > 0 && ctx.Config.Downloads.MaxSizeBytes > 0 && contentLength > ctx.Config.Downloads.MaxSizeBytes {
if contentLength != 0 && ctx.Config.Downloads.MaxSizeBytes > 0 && contentLength > ctx.Config.Downloads.MaxSizeBytes {
errFn(common.ErrMediaTooLarge)
return
}

r := resp.Body
if ctx.Config.Downloads.MaxSizeBytes > 0 {
r = readers.LimitReaderWithOverrunError(resp.Body, ctx.Config.Downloads.MaxSizeBytes)
}

contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/octet-stream" // binary
Expand All @@ -96,7 +102,7 @@ func TryDownload(ctx rcontext.RequestContext, origin string, mediaId string) (*d
}

ch <- downloadResult{
r: resp.Body,
r: r,
filename: fileName,
contentType: contentType,
err: nil,
Expand Down
3 changes: 2 additions & 1 deletion pipelines/_steps/upload/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import (
"io"

"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/util/readers"
)

func LimitStream(ctx rcontext.RequestContext, r io.ReadCloser) io.ReadCloser {
if ctx.Config.Uploads.MaxSizeBytes > 0 {
return io.NopCloser(io.LimitReader(r, ctx.Config.Uploads.MaxSizeBytes))
return readers.LimitReaderWithOverrunError(r, ctx.Config.Uploads.MaxSizeBytes)
} else {
return r
}
Expand Down
4 changes: 3 additions & 1 deletion pipelines/pipeline_upload/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ func Execute(ctx rcontext.RequestContext, origin string, mediaId string, r io.Re
}

// Step 1: Limit the stream's length
r = upload.LimitStream(ctx, r)
if kind == datastores.LocalMediaKind {
r = upload.LimitStream(ctx, r)
}

// Step 2: Create a media ID (if needed)
mustUseMediaId := true
Expand Down
42 changes: 42 additions & 0 deletions util/readers/error_limit_reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package readers

import (
"io"

"github.com/turt2live/matrix-media-repo/common"
)

func LimitReaderWithOverrunError(r io.ReadCloser, n int64) io.ReadCloser {
return &limitedReader{r: r, n: n}
}

type limitedReader struct {
r io.ReadCloser
n int64
}

func (r *limitedReader) Read(p []byte) (int, error) {
if r.n <= 0 {
// See if we can read one more byte, indicating the stream is too big
b := make([]byte, 1)
n, err := r.r.Read(b)
p[0] = b[0]
if err != nil {
// ignore - we're at the end anyways
return n, io.EOF
}
if n > 0 {
return n, common.ErrMediaTooLarge
}

return n, io.EOF
}

n, err := r.r.Read(p)
r.n -= int64(n)
return n, err
}

func (r *limitedReader) Close() error {
return r.r.Close()
}

0 comments on commit 960691a

Please sign in to comment.