diff --git a/api/routes.go b/api/routes.go index 3b36c468..49580f18 100644 --- a/api/routes.go +++ b/api/routes.go @@ -56,6 +56,10 @@ func buildRoutes() http.Handler { register([]string{"GET"}, PrefixFederation, "media/download/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireServerAuth(unstable.FederationDownloadMedia), "download", counter)) register([]string{"GET"}, PrefixFederation, "media/thumbnail/:server/:mediaId", msc3916, router, makeRoute(_routers.RequireServerAuth(unstable.FederationThumbnailMedia), "thumbnail", counter)) + // MSC3911 - Linking media to events + register([]string{"POST"}, PrefixClient, "media/upload", msc3911, router, makeRoute(_routers.RequireAccessToken(unstable.ClientUploadMediaSync), "upload", counter)) + register([]string{"POST"}, PrefixClient, "media/create", msc3911, router, makeRoute(_routers.RequireAccessToken(unstable.ClientCreateMedia), "create", counter)) + // Custom features register([]string{"GET"}, PrefixMedia, "local_copy/:server/:mediaId", mxUnstable, router, makeRoute(_routers.RequireAccessToken(unstable.LocalCopy), "local_copy", counter)) register([]string{"GET"}, PrefixMedia, "info/:server/:mediaId", mxUnstable, router, makeRoute(_routers.RequireAccessToken(unstable.MediaInfo), "info", counter)) @@ -143,6 +147,7 @@ var ( mxUnstable matrixVersions = []string{"unstable", "unstable/io.t2bot.media"} msc4034 matrixVersions = []string{"unstable/org.matrix.msc4034"} msc3916 matrixVersions = []string{"unstable/org.matrix.msc3916"} + msc3911 matrixVersions = []string{"unstable/org.matrix.msc3911"} mxSpecV3Transition matrixVersions = []string{"r0", "v1", "v3"} mxSpecV3TransitionCS matrixVersions = []string{"r0", "v3"} mxR0 matrixVersions = []string{"r0"} diff --git a/api/unstable/msc3911_create.go b/api/unstable/msc3911_create.go new file mode 100644 index 00000000..bea65622 --- /dev/null +++ b/api/unstable/msc3911_create.go @@ -0,0 +1,50 @@ +package unstable + +import ( + "net/http" + + "github.com/getsentry/sentry-go" + "github.com/turt2live/matrix-media-repo/api/_apimeta" + "github.com/turt2live/matrix-media-repo/api/_responses" + v1 "github.com/turt2live/matrix-media-repo/api/v1" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/database" + "github.com/turt2live/matrix-media-repo/pipelines/pipeline_create" + "github.com/turt2live/matrix-media-repo/util" +) + +func ClientCreateMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { + id, err := restrictAsyncMediaId(rctx, r.Host, user.UserId) + if err != nil { + rctx.Log.Error("Unexpected error creating media ID:", err) + sentry.CaptureException(err) + return _responses.InternalServerError("unexpected error") + } + + return &v1.MediaCreatedResponse{ + ContentUri: util.MxcUri(id.Origin, id.MediaId), + ExpiresTs: id.ExpiresTs, + } +} + +func restrictAsyncMediaId(ctx rcontext.RequestContext, host string, userId string) (*database.DbExpiringMedia, error) { + id, err := pipeline_create.Execute(ctx, host, userId, pipeline_create.DefaultExpirationTime) + if err != nil { + return nil, err + } + + db := database.GetInstance().RestrictedMedia.Prepare(ctx) + err = db.Insert(id.Origin, id.MediaId, database.RestrictedToUser, id.UserId) + if err != nil { + // Try to clean up the expiring record, but don't fail if it fails + err2 := database.GetInstance().ExpiringMedia.Prepare(ctx).SetExpiry(id.Origin, id.MediaId, util.NowMillis()) + if err2 != nil { + ctx.Log.Warn("Non-fatal error when trying to clean up interstitial expiring media: ", err2) + sentry.CaptureException(err2) + } + + return nil, err + } + + return id, nil +} diff --git a/api/unstable/msc3911_upload_sync.go b/api/unstable/msc3911_upload_sync.go new file mode 100644 index 00000000..b03a43db --- /dev/null +++ b/api/unstable/msc3911_upload_sync.go @@ -0,0 +1,36 @@ +package unstable + +import ( + "net/http" + + "github.com/getsentry/sentry-go" + "github.com/turt2live/matrix-media-repo/api/_apimeta" + "github.com/turt2live/matrix-media-repo/api/_responses" + "github.com/turt2live/matrix-media-repo/api/_routers" + "github.com/turt2live/matrix-media-repo/api/r0" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/util" +) + +func ClientUploadMediaSync(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { + // We're a bit fancy here. Instead of mirroring the "upload sync" endpoint to include restricted media, we + // internally create an async media ID then claim it immediately. + + id, err := restrictAsyncMediaId(rctx, r.Host, user.UserId) + if err != nil { + rctx.Log.Error("Unexpected error creating media ID:", err) + sentry.CaptureException(err) + return _responses.InternalServerError("unexpected error") + } + + r = _routers.ForceSetParam("server", id.Origin, r) + r = _routers.ForceSetParam("mediaId", id.MediaId, r) + + resp := r0.UploadMediaAsync(r, rctx, user) + if _, ok := resp.(*r0.MediaUploadedResponse); ok { + return &r0.MediaUploadedResponse{ + ContentUri: util.MxcUri(id.Origin, id.MediaId), + } + } + return resp +} diff --git a/database/db.go b/database/db.go index 9b027ce7..dabb4941 100644 --- a/database/db.go +++ b/database/db.go @@ -28,6 +28,7 @@ type Database struct { Tasks *tasksTableStatements Exports *exportsTableStatements ExportParts *exportPartsTableStatements + RestrictedMedia *restrictedMediaTableStatements } var instance *Database @@ -124,6 +125,9 @@ func openDatabase(connectionString string, maxConns int, maxIdleConns int) error if d.ExportParts, err = prepareExportPartsTables(d.conn); err != nil { return errors.New("failed to create export parts table accessor: " + err.Error()) } + if d.RestrictedMedia, err = prepareRestrictedMediaTables(d.conn); err != nil { + return errors.New("failed to create restricted media table accessor: " + err.Error()) + } instance = d return nil diff --git a/database/table_expiring_media.go b/database/table_expiring_media.go index 9dfa0854..2bde322b 100644 --- a/database/table_expiring_media.go +++ b/database/table_expiring_media.go @@ -23,6 +23,7 @@ const insertExpiringMedia = "INSERT INTO expiring_media (origin, media_id, user_ const selectExpiringMediaByUserCount = "SELECT COUNT(*) FROM expiring_media WHERE user_id = $1 AND expires_ts >= $2;" const selectExpiringMediaById = "SELECT origin, media_id, user_id, expires_ts FROM expiring_media WHERE origin = $1 AND media_id = $2;" const deleteExpiringMediaById = "DELETE FROM expiring_media WHERE origin = $1 AND media_id = $2;" +const updateExpiringMediaExpiration = "UPDATE expiring_media SET expires_ts = $3 WHERE origin = $1 AND media_id = $2;" // Dev note: there is an UPDATE query in the Upload test suite. @@ -31,6 +32,7 @@ type expiringMediaTableStatements struct { selectExpiringMediaByUserCount *sql.Stmt selectExpiringMediaById *sql.Stmt deleteExpiringMediaById *sql.Stmt + updateExpiringMediaExpiration *sql.Stmt } type expiringMediaTableWithContext struct { @@ -54,6 +56,9 @@ func prepareExpiringMediaTables(db *sql.DB) (*expiringMediaTableStatements, erro if stmts.deleteExpiringMediaById, err = db.Prepare(deleteExpiringMediaById); err != nil { return nil, errors.New("error preparing deleteExpiringMediaById: " + err.Error()) } + if stmts.updateExpiringMediaExpiration, err = db.Prepare(updateExpiringMediaExpiration); err != nil { + return nil, errors.New("error preparing updateExpiringMediaExpiration: " + err.Error()) + } return stmts, nil } @@ -96,3 +101,8 @@ func (s *expiringMediaTableWithContext) Delete(origin string, mediaId string) er _, err := s.statements.deleteExpiringMediaById.ExecContext(s.ctx, origin, mediaId) return err } + +func (s *expiringMediaTableWithContext) SetExpiry(origin string, mediaId string, expiresTs int64) error { + _, err := s.statements.updateExpiringMediaExpiration.ExecContext(s.ctx, origin, mediaId, expiresTs) + return err +} diff --git a/database/table_restricted_media.go b/database/table_restricted_media.go new file mode 100644 index 00000000..5f1fff65 --- /dev/null +++ b/database/table_restricted_media.go @@ -0,0 +1,81 @@ +package database + +import ( + "database/sql" + "errors" + + "github.com/turt2live/matrix-media-repo/common/rcontext" +) + +type RestrictedCondition string + +const RestrictedToEvent RestrictedCondition = "event_id" // MSC3911 +const RestrictedToProfile RestrictedCondition = "profile_user_id" // MSC3911 +const RestrictedToUser RestrictedCondition = "io.t2bot.user_id" // Internal extension + +type DbRestrictedMedia struct { + Origin string + MediaId string + Condition RestrictedCondition + ConditionValue string +} + +const insertRestrictedMedia = "INSERT INTO restricted_media (origin, media_id, condition_type, condition_value) VALUES ($1, $2, $3, $4);" +const updateRestrictedMedia = "UPDATE restricted_media SET condition_type = $3, condition_value = $4 WHERE origin = $1 AND media_id = $2;" +const selectRestrictedMedia = "SELECT origin, media_id, condition_type, condition_value FROM restricted_media WHERE origin = $1 AND media_id = $2 LIMIT 1;" + +type restrictedMediaTableStatements struct { + insertRestrictedMedia *sql.Stmt + updateRestrictedMedia *sql.Stmt + selectRestrictedMedia *sql.Stmt +} + +type restrictedMediaTableWithContext struct { + statements *restrictedMediaTableStatements + ctx rcontext.RequestContext +} + +func prepareRestrictedMediaTables(db *sql.DB) (*restrictedMediaTableStatements, error) { + var err error + var stmts = &restrictedMediaTableStatements{} + + if stmts.insertRestrictedMedia, err = db.Prepare(insertRestrictedMedia); err != nil { + return nil, errors.New("error preparing insertRestrictedMedia: " + err.Error()) + } + if stmts.updateRestrictedMedia, err = db.Prepare(updateRestrictedMedia); err != nil { + return nil, errors.New("error preparing updateRestrictedMedia: " + err.Error()) + } + if stmts.selectRestrictedMedia, err = db.Prepare(selectRestrictedMedia); err != nil { + return nil, errors.New("error preparing selectRestrictedMedia: " + err.Error()) + } + + return stmts, nil +} + +func (s *restrictedMediaTableStatements) Prepare(ctx rcontext.RequestContext) *restrictedMediaTableWithContext { + return &restrictedMediaTableWithContext{ + statements: s, + ctx: ctx, + } +} + +func (s *restrictedMediaTableWithContext) Insert(origin string, mediaId string, condition RestrictedCondition, conditionValue string) error { + _, err := s.statements.insertRestrictedMedia.ExecContext(s.ctx, origin, mediaId, condition, conditionValue) + return err +} + +func (s *restrictedMediaTableWithContext) Update(origin string, mediaId string, condition RestrictedCondition, conditionValue string) error { + _, err := s.statements.updateRestrictedMedia.ExecContext(s.ctx, origin, mediaId, condition, conditionValue) + return err +} + +func (s *restrictedMediaTableWithContext) GetById(origin string, mediaId string) (*DbRestrictedMedia, error) { + row := s.statements.selectRestrictedMedia.QueryRowContext(s.ctx, origin, mediaId) + val := &DbRestrictedMedia{} + err := row.Scan(&val.Origin, &val.MediaId, &val.Condition, &val.ConditionValue) + if errors.Is(err, sql.ErrNoRows) { + err = nil + val = nil + } + return val, err +} diff --git a/migrations/27_create_restricted_media_table_down.sql b/migrations/27_create_restricted_media_table_down.sql new file mode 100644 index 00000000..62ae20d6 --- /dev/null +++ b/migrations/27_create_restricted_media_table_down.sql @@ -0,0 +1,2 @@ +DROP INDEX IF EXISTS idx_restricted_media; +DROP TABLE IF EXISTS restricted_media; diff --git a/migrations/27_create_restricted_media_table_up.sql b/migrations/27_create_restricted_media_table_up.sql new file mode 100644 index 00000000..13b0fa6c --- /dev/null +++ b/migrations/27_create_restricted_media_table_up.sql @@ -0,0 +1,2 @@ +CREATE TABLE IF NOT EXISTS restricted_media (origin TEXT NOT NULL, media_id TEXT NOT NULL, condition_type TEXT NOT NULL, condition_value TEXT NOT NULL); +CREATE UNIQUE INDEX IF NOT EXISTS idx_restricted_media ON restricted_media (origin, media_id);