Skip to content

Commit

Permalink
Flag media as restricted on upload (MSC3911)
Browse files Browse the repository at this point in the history
  • Loading branch information
turt2live committed Nov 25, 2023
1 parent a066bda commit f600ea4
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 0 deletions.
5 changes: 5 additions & 0 deletions api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ func buildRoutes() http.Handler {
register([]string{"GET"}, PrefixFederation, "media/download/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireServerAuth(unstable.FederationDownloadMedia), "download", counter))
register([]string{"GET"}, PrefixFederation, "media/thumbnail/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireServerAuth(unstable.FederationThumbnailMedia), "thumbnail", counter))

// MSC3911 - Linking media to events
register([]string{"POST"}, PrefixClient, "media/upload", msc3911, router, makeRoute(_routers.RequireAccessToken(unstable.ClientUploadMediaSync), "upload", counter))
register([]string{"POST"}, PrefixClient, "media/create", msc3911, router, makeRoute(_routers.RequireAccessToken(unstable.ClientCreateMedia), "create", counter))

// Custom features
register([]string{"GET"}, PrefixMedia, "local_copy/:server/:mediaId", mxUnstable, router, makeRoute(_routers.RequireAccessToken(unstable.LocalCopy), "local_copy", counter))
register([]string{"GET"}, PrefixMedia, "info/:server/:mediaId", mxUnstable, router, makeRoute(_routers.RequireAccessToken(unstable.MediaInfo), "info", counter))
Expand Down Expand Up @@ -143,6 +147,7 @@ var (
mxUnstable matrixVersions = []string{"unstable", "unstable/io.t2bot.media"}
msc4034 matrixVersions = []string{"unstable/org.matrix.msc4034"}
msc3916 matrixVersions = []string{"unstable/org.matrix.msc3916"}
msc3911 matrixVersions = []string{"unstable/org.matrix.msc3911"}
mxSpecV3Transition matrixVersions = []string{"r0", "v1", "v3"}
mxSpecV3TransitionCS matrixVersions = []string{"r0", "v3"}
mxR0 matrixVersions = []string{"r0"}
Expand Down
50 changes: 50 additions & 0 deletions api/unstable/msc3911_create.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package unstable

import (
"net/http"

"github.com/getsentry/sentry-go"
"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
v1 "github.com/turt2live/matrix-media-repo/api/v1"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/database"
"github.com/turt2live/matrix-media-repo/pipelines/pipeline_create"
"github.com/turt2live/matrix-media-repo/util"
)

func ClientCreateMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
id, err := restrictAsyncMediaId(rctx, r.Host, user.UserId)
if err != nil {
rctx.Log.Error("Unexpected error creating media ID:", err)
sentry.CaptureException(err)
return _responses.InternalServerError("unexpected error")
}

return &v1.MediaCreatedResponse{
ContentUri: util.MxcUri(id.Origin, id.MediaId),
ExpiresTs: id.ExpiresTs,
}
}

func restrictAsyncMediaId(ctx rcontext.RequestContext, host string, userId string) (*database.DbExpiringMedia, error) {
id, err := pipeline_create.Execute(ctx, host, userId, pipeline_create.DefaultExpirationTime)
if err != nil {
return nil, err
}

db := database.GetInstance().RestrictedMedia.Prepare(ctx)
err = db.Insert(id.Origin, id.MediaId, database.RestrictedToUser, id.UserId)
if err != nil {
// Try to clean up the expiring record, but don't fail if it fails
err2 := database.GetInstance().ExpiringMedia.Prepare(ctx).SetExpiry(id.Origin, id.MediaId, util.NowMillis())
if err2 != nil {
ctx.Log.Warn("Non-fatal error when trying to clean up interstitial expiring media: ", err2)
sentry.CaptureException(err2)
}

return nil, err
}

return id, nil
}
36 changes: 36 additions & 0 deletions api/unstable/msc3911_upload_sync.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package unstable

import (
"net/http"

"github.com/getsentry/sentry-go"
"github.com/turt2live/matrix-media-repo/api/_apimeta"
"github.com/turt2live/matrix-media-repo/api/_responses"
"github.com/turt2live/matrix-media-repo/api/_routers"
"github.com/turt2live/matrix-media-repo/api/r0"
"github.com/turt2live/matrix-media-repo/common/rcontext"
"github.com/turt2live/matrix-media-repo/util"
)

func ClientUploadMediaSync(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} {
// We're a bit fancy here. Instead of mirroring the "upload sync" endpoint to include restricted media, we
// internally create an async media ID then claim it immediately.

id, err := restrictAsyncMediaId(rctx, r.Host, user.UserId)
if err != nil {
rctx.Log.Error("Unexpected error creating media ID:", err)
sentry.CaptureException(err)
return _responses.InternalServerError("unexpected error")
}

r = _routers.ForceSetParam("server", id.Origin, r)
r = _routers.ForceSetParam("mediaId", id.MediaId, r)

resp := r0.UploadMediaAsync(r, rctx, user)
if _, ok := resp.(*r0.MediaUploadedResponse); ok {
return &r0.MediaUploadedResponse{
ContentUri: util.MxcUri(id.Origin, id.MediaId),
}
}
return resp
}
4 changes: 4 additions & 0 deletions database/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type Database struct {
Tasks *tasksTableStatements
Exports *exportsTableStatements
ExportParts *exportPartsTableStatements
RestrictedMedia *restrictedMediaTableStatements
}

var instance *Database
Expand Down Expand Up @@ -124,6 +125,9 @@ func openDatabase(connectionString string, maxConns int, maxIdleConns int) error
if d.ExportParts, err = prepareExportPartsTables(d.conn); err != nil {
return errors.New("failed to create export parts table accessor: " + err.Error())
}
if d.RestrictedMedia, err = prepareRestrictedMediaTables(d.conn); err != nil {
return errors.New("failed to create restricted media table accessor: " + err.Error())
}

instance = d
return nil
Expand Down
10 changes: 10 additions & 0 deletions database/table_expiring_media.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ const insertExpiringMedia = "INSERT INTO expiring_media (origin, media_id, user_
const selectExpiringMediaByUserCount = "SELECT COUNT(*) FROM expiring_media WHERE user_id = $1 AND expires_ts >= $2;"
const selectExpiringMediaById = "SELECT origin, media_id, user_id, expires_ts FROM expiring_media WHERE origin = $1 AND media_id = $2;"
const deleteExpiringMediaById = "DELETE FROM expiring_media WHERE origin = $1 AND media_id = $2;"
const updateExpiringMediaExpiration = "UPDATE expiring_media SET expires_ts = $3 WHERE origin = $1 AND media_id = $2;"

// Dev note: there is an UPDATE query in the Upload test suite.

Expand All @@ -31,6 +32,7 @@ type expiringMediaTableStatements struct {
selectExpiringMediaByUserCount *sql.Stmt
selectExpiringMediaById *sql.Stmt
deleteExpiringMediaById *sql.Stmt
updateExpiringMediaExpiration *sql.Stmt
}

type expiringMediaTableWithContext struct {
Expand All @@ -54,6 +56,9 @@ func prepareExpiringMediaTables(db *sql.DB) (*expiringMediaTableStatements, erro
if stmts.deleteExpiringMediaById, err = db.Prepare(deleteExpiringMediaById); err != nil {
return nil, errors.New("error preparing deleteExpiringMediaById: " + err.Error())
}
if stmts.updateExpiringMediaExpiration, err = db.Prepare(updateExpiringMediaExpiration); err != nil {
return nil, errors.New("error preparing updateExpiringMediaExpiration: " + err.Error())
}

return stmts, nil
}
Expand Down Expand Up @@ -96,3 +101,8 @@ func (s *expiringMediaTableWithContext) Delete(origin string, mediaId string) er
_, err := s.statements.deleteExpiringMediaById.ExecContext(s.ctx, origin, mediaId)
return err
}

func (s *expiringMediaTableWithContext) SetExpiry(origin string, mediaId string, expiresTs int64) error {
_, err := s.statements.updateExpiringMediaExpiration.ExecContext(s.ctx, origin, mediaId, expiresTs)
return err
}
81 changes: 81 additions & 0 deletions database/table_restricted_media.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package database

import (
"database/sql"
"errors"

"github.com/turt2live/matrix-media-repo/common/rcontext"
)

type RestrictedCondition string

const RestrictedToEvent RestrictedCondition = "event_id" // MSC3911
const RestrictedToProfile RestrictedCondition = "profile_user_id" // MSC3911
const RestrictedToUser RestrictedCondition = "io.t2bot.user_id" // Internal extension

type DbRestrictedMedia struct {
Origin string
MediaId string
Condition RestrictedCondition
ConditionValue string
}

const insertRestrictedMedia = "INSERT INTO restricted_media (origin, media_id, condition_type, condition_value) VALUES ($1, $2, $3, $4);"
const updateRestrictedMedia = "UPDATE restricted_media SET condition_type = $3, condition_value = $4 WHERE origin = $1 AND media_id = $2;"
const selectRestrictedMedia = "SELECT origin, media_id, condition_type, condition_value FROM restricted_media WHERE origin = $1 AND media_id = $2 LIMIT 1;"

type restrictedMediaTableStatements struct {
insertRestrictedMedia *sql.Stmt
updateRestrictedMedia *sql.Stmt
selectRestrictedMedia *sql.Stmt
}

type restrictedMediaTableWithContext struct {
statements *restrictedMediaTableStatements
ctx rcontext.RequestContext
}

func prepareRestrictedMediaTables(db *sql.DB) (*restrictedMediaTableStatements, error) {
var err error
var stmts = &restrictedMediaTableStatements{}

if stmts.insertRestrictedMedia, err = db.Prepare(insertRestrictedMedia); err != nil {
return nil, errors.New("error preparing insertRestrictedMedia: " + err.Error())
}
if stmts.updateRestrictedMedia, err = db.Prepare(updateRestrictedMedia); err != nil {
return nil, errors.New("error preparing updateRestrictedMedia: " + err.Error())
}
if stmts.selectRestrictedMedia, err = db.Prepare(selectRestrictedMedia); err != nil {
return nil, errors.New("error preparing selectRestrictedMedia: " + err.Error())
}

return stmts, nil
}

func (s *restrictedMediaTableStatements) Prepare(ctx rcontext.RequestContext) *restrictedMediaTableWithContext {
return &restrictedMediaTableWithContext{
statements: s,
ctx: ctx,
}
}

func (s *restrictedMediaTableWithContext) Insert(origin string, mediaId string, condition RestrictedCondition, conditionValue string) error {
_, err := s.statements.insertRestrictedMedia.ExecContext(s.ctx, origin, mediaId, condition, conditionValue)
return err
}

func (s *restrictedMediaTableWithContext) Update(origin string, mediaId string, condition RestrictedCondition, conditionValue string) error {
_, err := s.statements.updateRestrictedMedia.ExecContext(s.ctx, origin, mediaId, condition, conditionValue)
return err
}

func (s *restrictedMediaTableWithContext) GetById(origin string, mediaId string) (*DbRestrictedMedia, error) {
row := s.statements.selectRestrictedMedia.QueryRowContext(s.ctx, origin, mediaId)
val := &DbRestrictedMedia{}
err := row.Scan(&val.Origin, &val.MediaId, &val.Condition, &val.ConditionValue)
if errors.Is(err, sql.ErrNoRows) {
err = nil
val = nil
}
return val, err
}
2 changes: 2 additions & 0 deletions migrations/27_create_restricted_media_table_down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
DROP INDEX IF EXISTS idx_restricted_media;
DROP TABLE IF EXISTS restricted_media;
2 changes: 2 additions & 0 deletions migrations/27_create_restricted_media_table_up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
CREATE TABLE IF NOT EXISTS restricted_media (origin TEXT NOT NULL, media_id TEXT NOT NULL, condition_type TEXT NOT NULL, condition_value TEXT NOT NULL);
CREATE UNIQUE INDEX IF NOT EXISTS idx_restricted_media ON restricted_media (origin, media_id);

0 comments on commit f600ea4

Please sign in to comment.