From 6a318b9f76263db720c5f9e6e5ffc2369c9608fa Mon Sep 17 00:00:00 2001 From: devin ivy Date: Tue, 23 Jan 2024 21:17:32 -0500 Subject: [PATCH] Appview v1 maintaining device tokens and pushing notifications w/ courier (#2073) * add courier proto to bsky, build * update registerPush on appview to support registering device tokens with courier * setup bsky notifications to send to either gorush or courier * wire courier push into indexer, test * courier push retries * tidy and build --- .../workflows/build-and-push-bsky-aws.yaml | 2 +- packages/bsky/package.json | 5 +- packages/bsky/proto/courier.proto | 56 +++ .../api/app/bsky/notification/registerPush.ts | 50 +- packages/bsky/src/config.ts | 42 ++ packages/bsky/src/context.ts | 12 +- packages/bsky/src/courier.ts | 41 ++ packages/bsky/src/index.ts | 20 +- packages/bsky/src/indexer/config.ts | 36 ++ packages/bsky/src/indexer/context.ts | 6 + packages/bsky/src/indexer/index.ts | 30 +- packages/bsky/src/notifications.ts | 314 ++++++------ packages/bsky/src/proto/courier_connect.ts | 50 ++ packages/bsky/src/proto/courier_pb.ts | 473 ++++++++++++++++++ packages/bsky/src/services/actor/index.ts | 19 +- .../bsky/src/services/indexing/processor.ts | 2 +- packages/bsky/src/util/retry.ts | 12 + .../bsky/tests/notification-server.test.ts | 78 ++- pnpm-lock.yaml | 7 + 19 files changed, 1059 insertions(+), 196 deletions(-) create mode 100644 packages/bsky/proto/courier.proto create mode 100644 packages/bsky/src/courier.ts create mode 100644 packages/bsky/src/proto/courier_connect.ts create mode 100644 packages/bsky/src/proto/courier_pb.ts diff --git a/.github/workflows/build-and-push-bsky-aws.yaml b/.github/workflows/build-and-push-bsky-aws.yaml index d56d56d0030..f534e015ea5 100644 --- a/.github/workflows/build-and-push-bsky-aws.yaml +++ b/.github/workflows/build-and-push-bsky-aws.yaml @@ -3,7 +3,7 @@ on: push: branches: - main - - appview-v1-sync-mutes + - appview-v1-courier env: REGISTRY: ${{ secrets.AWS_ECR_REGISTRY_USEAST2_PACKAGES_REGISTRY }} USERNAME: ${{ secrets.AWS_ECR_REGISTRY_USEAST2_PACKAGES_USERNAME }} diff --git a/packages/bsky/package.json b/packages/bsky/package.json index 754d3c614fe..15e03c3dd86 100644 --- a/packages/bsky/package.json +++ b/packages/bsky/package.json @@ -29,16 +29,16 @@ "test:log": "tail -50 test.log | pino-pretty", "test:updateSnapshot": "jest --updateSnapshot", "migration:create": "ts-node ./bin/migration-create.ts", - "buf:gen": "buf generate ../bsync/proto" + "buf:gen": "buf generate ../bsync/proto && buf generate ./proto" }, "dependencies": { "@atproto/api": "workspace:^", "@atproto/common": "workspace:^", "@atproto/crypto": "workspace:^", - "@atproto/syntax": "workspace:^", "@atproto/identity": "workspace:^", "@atproto/lexicon": "workspace:^", "@atproto/repo": "workspace:^", + "@atproto/syntax": "workspace:^", "@atproto/xrpc-server": "workspace:^", "@bufbuild/protobuf": "^1.5.0", "@connectrpc/connect": "^1.1.4", @@ -55,6 +55,7 @@ "ioredis": "^5.3.2", "kysely": "^0.22.0", "multiformats": "^9.9.0", + "murmurhash": "^2.0.1", "p-queue": "^6.6.2", "pg": "^8.10.0", "pino": "^8.15.0", diff --git a/packages/bsky/proto/courier.proto b/packages/bsky/proto/courier.proto new file mode 100644 index 00000000000..8a63bb86e47 --- /dev/null +++ b/packages/bsky/proto/courier.proto @@ -0,0 +1,56 @@ +syntax = "proto3"; + +package courier; +option go_package = "./;courier"; + +import "google/protobuf/struct.proto"; +import "google/protobuf/timestamp.proto"; + +// +// Messages +// + +// Ping +message PingRequest {} +message PingResponse {} + +// Notifications + +enum AppPlatform { + APP_PLATFORM_UNSPECIFIED = 0; + APP_PLATFORM_IOS = 1; + APP_PLATFORM_ANDROID = 2; + APP_PLATFORM_WEB = 3; +} + +message Notification { + string id = 1; + string recipient_did = 2; + string title = 3; + string message = 4; + string collapse_key = 5; + bool always_deliver = 6; + google.protobuf.Timestamp timestamp = 7; + google.protobuf.Struct additional = 8; +} + +message PushNotificationsRequest { + repeated Notification notifications = 1; +} + +message PushNotificationsResponse {} + +message RegisterDeviceTokenRequest { + string did = 1; + string token = 2; + string app_id = 3; + AppPlatform platform = 4; +} + +message RegisterDeviceTokenResponse {} + +service Service { + rpc Ping(PingRequest) returns (PingResponse); + rpc PushNotifications(PushNotificationsRequest) returns (PushNotificationsResponse); + rpc RegisterDeviceToken(RegisterDeviceTokenRequest) returns (RegisterDeviceTokenResponse); +} diff --git a/packages/bsky/src/api/app/bsky/notification/registerPush.ts b/packages/bsky/src/api/app/bsky/notification/registerPush.ts index 9645cd76c83..abce1cd096c 100644 --- a/packages/bsky/src/api/app/bsky/notification/registerPush.ts +++ b/packages/bsky/src/api/app/bsky/notification/registerPush.ts @@ -1,29 +1,63 @@ +import assert from 'node:assert' import { InvalidRequestError } from '@atproto/xrpc-server' import { Server } from '../../../../lexicon' import AppContext from '../../../../context' import { Platform } from '../../../../notifications' +import { CourierClient } from '../../../../courier' +import { AppPlatform } from '../../../../proto/courier_pb' export default function (server: Server, ctx: AppContext) { server.app.bsky.notification.registerPush({ auth: ctx.authVerifier.standard, - handler: async ({ auth, input }) => { + handler: async ({ req, auth, input }) => { const { token, platform, serviceDid, appId } = input.body const did = auth.credentials.iss if (serviceDid !== auth.credentials.aud) { throw new InvalidRequestError('Invalid serviceDid.') } - const { notifServer } = ctx if (platform !== 'ios' && platform !== 'android' && platform !== 'web') { throw new InvalidRequestError( 'Unsupported platform: must be "ios", "android", or "web".', ) } - await notifServer.registerDeviceForPushNotifications( - did, - token, - platform as Platform, - appId, - ) + + const db = ctx.db.getPrimary() + + const registerDeviceWithAppview = async () => { + await ctx.services + .actor(db) + .registerPushDeviceToken(did, token, platform as Platform, appId) + } + + const registerDeviceWithCourier = async ( + courierClient: CourierClient, + ) => { + await courierClient.registerDeviceToken({ + did, + token, + platform: + platform === 'ios' + ? AppPlatform.IOS + : platform === 'android' + ? AppPlatform.ANDROID + : AppPlatform.WEB, + appId, + }) + } + + if (ctx.cfg.courierOnlyRegistration) { + assert(ctx.courierClient) + await registerDeviceWithCourier(ctx.courierClient) + } else { + await registerDeviceWithAppview() + if (ctx.courierClient) { + try { + await registerDeviceWithCourier(ctx.courierClient) + } catch (err) { + req.log.warn(err, 'failed to register device token with courier') + } + } + } }, }) } diff --git a/packages/bsky/src/config.ts b/packages/bsky/src/config.ts index da695cef4fb..1d88abb588a 100644 --- a/packages/bsky/src/config.ts +++ b/packages/bsky/src/config.ts @@ -36,6 +36,11 @@ export interface ServerConfigValues { bsyncHttpVersion?: '1.1' | '2' bsyncIgnoreBadTls?: boolean bsyncOnlyMutes?: boolean + courierUrl?: string + courierApiKey?: string + courierHttpVersion?: '1.1' | '2' + courierIgnoreBadTls?: boolean + courierOnlyRegistration?: boolean adminPassword: string moderatorPassword: string triagePassword: string @@ -100,6 +105,18 @@ export class ServerConfig { const bsyncOnlyMutes = process.env.BSKY_BSYNC_ONLY_MUTES === 'true' assert(!bsyncOnlyMutes || bsyncUrl, 'bsync-only mutes requires a bsync url') assert(bsyncHttpVersion === '1.1' || bsyncHttpVersion === '2') + const courierUrl = process.env.BSKY_COURIER_URL || undefined + const courierApiKey = process.env.BSKY_COURIER_API_KEY || undefined + const courierHttpVersion = process.env.BSKY_COURIER_HTTP_VERSION || '2' + const courierIgnoreBadTls = + process.env.BSKY_COURIER_IGNORE_BAD_TLS === 'true' + const courierOnlyRegistration = + process.env.BSKY_COURIER_ONLY_REGISTRATION === 'true' + assert( + !courierOnlyRegistration || courierUrl, + 'courier-only registration requires a courier url', + ) + assert(courierHttpVersion === '1.1' || courierHttpVersion === '2') const dbPrimaryPostgresUrl = overrides?.dbPrimaryPostgresUrl || process.env.DB_PRIMARY_POSTGRES_URL let dbReplicaPostgresUrls = overrides?.dbReplicaPostgresUrls @@ -169,6 +186,11 @@ export class ServerConfig { bsyncHttpVersion, bsyncIgnoreBadTls, bsyncOnlyMutes, + courierUrl, + courierApiKey, + courierHttpVersion, + courierIgnoreBadTls, + courierOnlyRegistration, adminPassword, moderatorPassword, triagePassword, @@ -305,6 +327,26 @@ export class ServerConfig { return this.cfg.bsyncIgnoreBadTls } + get courierUrl() { + return this.cfg.courierUrl + } + + get courierApiKey() { + return this.cfg.courierApiKey + } + + get courierHttpVersion() { + return this.cfg.courierHttpVersion + } + + get courierIgnoreBadTls() { + return this.cfg.courierIgnoreBadTls + } + + get courierOnlyRegistration() { + return this.cfg.courierOnlyRegistration + } + get adminPassword() { return this.cfg.adminPassword } diff --git a/packages/bsky/src/context.ts b/packages/bsky/src/context.ts index fb5993a5aa4..2eeb83fea4b 100644 --- a/packages/bsky/src/context.ts +++ b/packages/bsky/src/context.ts @@ -10,10 +10,10 @@ import { Services } from './services' import DidRedisCache from './did-cache' import { BackgroundQueue } from './background' import { MountedAlgos } from './feed-gen/types' -import { NotificationServer } from './notifications' import { Redis } from './redis' import { AuthVerifier } from './auth-verifier' import { BsyncClient } from './bsync' +import { CourierClient } from './courier' export class AppContext { constructor( @@ -29,8 +29,8 @@ export class AppContext { backgroundQueue: BackgroundQueue searchAgent?: AtpAgent bsyncClient?: BsyncClient + courierClient?: CourierClient algos: MountedAlgos - notifServer: NotificationServer authVerifier: AuthVerifier }, ) {} @@ -71,10 +71,6 @@ export class AppContext { return this.opts.redis } - get notifServer(): NotificationServer { - return this.opts.notifServer - } - get searchAgent(): AtpAgent | undefined { return this.opts.searchAgent } @@ -83,6 +79,10 @@ export class AppContext { return this.opts.bsyncClient } + get courierClient(): CourierClient | undefined { + return this.opts.courierClient + } + get authVerifier(): AuthVerifier { return this.opts.authVerifier } diff --git a/packages/bsky/src/courier.ts b/packages/bsky/src/courier.ts new file mode 100644 index 00000000000..aeb095898f6 --- /dev/null +++ b/packages/bsky/src/courier.ts @@ -0,0 +1,41 @@ +import { Service } from './proto/courier_connect' +import { + Code, + ConnectError, + PromiseClient, + createPromiseClient, + Interceptor, +} from '@connectrpc/connect' +import { + createConnectTransport, + ConnectTransportOptions, +} from '@connectrpc/connect-node' + +export type CourierClient = PromiseClient + +export const createCourierClient = ( + opts: ConnectTransportOptions, +): CourierClient => { + const transport = createConnectTransport(opts) + return createPromiseClient(Service, transport) +} + +export { Code } + +export const isCourierError = ( + err: unknown, + code?: Code, +): err is ConnectError => { + if (err instanceof ConnectError) { + return !code || err.code === code + } + return false +} + +export const authWithApiKey = + (apiKey: string): Interceptor => + (next) => + (req) => { + req.header.set('authorization', `Bearer ${apiKey}`) + return next(req) + } diff --git a/packages/bsky/src/index.ts b/packages/bsky/src/index.ts index 48dcc82fc39..5de34201e87 100644 --- a/packages/bsky/src/index.ts +++ b/packages/bsky/src/index.ts @@ -29,12 +29,12 @@ import { } from './image/invalidator' import { BackgroundQueue } from './background' import { MountedAlgos } from './feed-gen/types' -import { NotificationServer } from './notifications' import { AtpAgent } from '@atproto/api' import { Keypair } from '@atproto/crypto' import { Redis } from './redis' import { AuthVerifier } from './auth-verifier' -import { authWithApiKey, createBsyncClient } from './bsync' +import { authWithApiKey as bsyncAuth, createBsyncClient } from './bsync' +import { authWithApiKey as courierAuth, createCourierClient } from './courier' export type { ServerConfigValues } from './config' export type { MountedAlgos } from './feed-gen/types' @@ -113,7 +113,6 @@ export class BskyAppView { const backgroundQueue = new BackgroundQueue(db.getPrimary()) - const notifServer = new NotificationServer(db.getPrimary()) const searchAgent = config.searchEndpoint ? new AtpAgent({ service: config.searchEndpoint }) : undefined @@ -142,7 +141,18 @@ export class BskyAppView { httpVersion: config.bsyncHttpVersion ?? '2', nodeOptions: { rejectUnauthorized: !config.bsyncIgnoreBadTls }, interceptors: config.bsyncApiKey - ? [authWithApiKey(config.bsyncApiKey)] + ? [bsyncAuth(config.bsyncApiKey)] + : [], + }) + : undefined + + const courierClient = config.courierUrl + ? createCourierClient({ + baseUrl: config.courierUrl, + httpVersion: config.courierHttpVersion ?? '2', + nodeOptions: { rejectUnauthorized: !config.courierIgnoreBadTls }, + interceptors: config.courierApiKey + ? [courierAuth(config.courierApiKey)] : [], }) : undefined @@ -159,8 +169,8 @@ export class BskyAppView { backgroundQueue, searchAgent, bsyncClient, + courierClient, algos, - notifServer, authVerifier, }) diff --git a/packages/bsky/src/indexer/config.ts b/packages/bsky/src/indexer/config.ts index 6acf86f9543..be7129580f2 100644 --- a/packages/bsky/src/indexer/config.ts +++ b/packages/bsky/src/indexer/config.ts @@ -22,6 +22,10 @@ export interface IndexerConfigValues { fuzzyFalsePositiveB64?: string labelerKeywords: Record moderationPushUrl: string + courierUrl?: string + courierApiKey?: string + courierHttpVersion?: '1.1' | '2' + courierIgnoreBadTls?: boolean indexerConcurrency?: number indexerPartitionIds: number[] indexerPartitionBatchSize?: number @@ -72,6 +76,18 @@ export class IndexerConfig { process.env.MODERATION_PUSH_URL || undefined assert(moderationPushUrl) + const courierUrl = + overrides?.courierUrl || process.env.BSKY_COURIER_URL || undefined + const courierApiKey = + overrides?.courierApiKey || process.env.BSKY_COURIER_API_KEY || undefined + const courierHttpVersion = + overrides?.courierHttpVersion || + process.env.BSKY_COURIER_HTTP_VERSION || + '2' + const courierIgnoreBadTls = + overrides?.courierIgnoreBadTls || + process.env.BSKY_COURIER_IGNORE_BAD_TLS === 'true' + assert(courierHttpVersion === '1.1' || courierHttpVersion === '2') const hiveApiKey = process.env.HIVE_API_KEY || undefined const abyssEndpoint = process.env.ABYSS_ENDPOINT const abyssPassword = process.env.ABYSS_PASSWORD @@ -114,6 +130,10 @@ export class IndexerConfig { didCacheMaxTTL, handleResolveNameservers, moderationPushUrl, + courierUrl, + courierApiKey, + courierHttpVersion, + courierIgnoreBadTls, hiveApiKey, abyssEndpoint, abyssPassword, @@ -185,6 +205,22 @@ export class IndexerConfig { return this.cfg.moderationPushUrl } + get courierUrl() { + return this.cfg.courierUrl + } + + get courierApiKey() { + return this.cfg.courierApiKey + } + + get courierHttpVersion() { + return this.cfg.courierHttpVersion + } + + get courierIgnoreBadTls() { + return this.cfg.courierIgnoreBadTls + } + get hiveApiKey() { return this.cfg.hiveApiKey } diff --git a/packages/bsky/src/indexer/context.ts b/packages/bsky/src/indexer/context.ts index 1ce2fbf1ea2..a4c1f1f2ea0 100644 --- a/packages/bsky/src/indexer/context.ts +++ b/packages/bsky/src/indexer/context.ts @@ -6,6 +6,7 @@ import { BackgroundQueue } from '../background' import DidSqlCache from '../did-cache' import { Redis } from '../redis' import { AutoModerator } from '../auto-moderator' +import { NotificationServer } from '../notifications' export class IndexerContext { constructor( @@ -19,6 +20,7 @@ export class IndexerContext { didCache: DidSqlCache backgroundQueue: BackgroundQueue autoMod: AutoModerator + notifServer?: NotificationServer }, ) {} @@ -57,6 +59,10 @@ export class IndexerContext { get autoMod(): AutoModerator { return this.opts.autoMod } + + get notifServer(): NotificationServer | undefined { + return this.opts.notifServer + } } export default IndexerContext diff --git a/packages/bsky/src/indexer/index.ts b/packages/bsky/src/indexer/index.ts index fec81faa374..7d012304573 100644 --- a/packages/bsky/src/indexer/index.ts +++ b/packages/bsky/src/indexer/index.ts @@ -11,8 +11,13 @@ import { createServices } from './services' import { IndexerSubscription } from './subscription' import { AutoModerator } from '../auto-moderator' import { Redis } from '../redis' -import { NotificationServer } from '../notifications' +import { + CourierNotificationServer, + GorushNotificationServer, + NotificationServer, +} from '../notifications' import { CloseFn, createServer, startServer } from './server' +import { authWithApiKey as courierAuth, createCourierClient } from '../courier' export { IndexerConfig } from './config' export type { IndexerConfigValues } from './config' @@ -60,9 +65,27 @@ export class BskyIndexer { backgroundQueue, }) - const notifServer = cfg.pushNotificationEndpoint - ? new NotificationServer(db, cfg.pushNotificationEndpoint) + const courierClient = cfg.courierUrl + ? createCourierClient({ + baseUrl: cfg.courierUrl, + httpVersion: cfg.courierHttpVersion ?? '2', + nodeOptions: { rejectUnauthorized: !cfg.courierIgnoreBadTls }, + interceptors: cfg.courierApiKey + ? [courierAuth(cfg.courierApiKey)] + : [], + }) : undefined + + let notifServer: NotificationServer | undefined + if (courierClient) { + notifServer = new CourierNotificationServer(db, courierClient) + } else if (cfg.pushNotificationEndpoint) { + notifServer = new GorushNotificationServer( + db, + cfg.pushNotificationEndpoint, + ) + } + const services = createServices({ idResolver, autoMod, @@ -79,6 +102,7 @@ export class BskyIndexer { didCache, backgroundQueue, autoMod, + notifServer, }) const sub = new IndexerSubscription(ctx, { partitionIds: cfg.indexerPartitionIds, diff --git a/packages/bsky/src/notifications.ts b/packages/bsky/src/notifications.ts index fdf24919d19..d29913ec7d4 100644 --- a/packages/bsky/src/notifications.ts +++ b/packages/bsky/src/notifications.ts @@ -1,6 +1,8 @@ import axios from 'axios' import { Insertable, sql } from 'kysely' import TTLCache from '@isaacs/ttlcache' +import { Struct, Timestamp } from '@bufbuild/protobuf' +import murmur from 'murmurhash' import { AtUri } from '@atproto/api' import { MINUTE, chunkArray } from '@atproto/common' import Database from './db/primary' @@ -9,11 +11,13 @@ import { NotificationPushToken as PushToken } from './db/tables/notification-pus import logger from './indexer/logger' import { notSoftDeletedClause, valuesList } from './db/util' import { ids } from './lexicon/lexicons' -import { retryHttp } from './util/retry' +import { retryConnect, retryHttp } from './util/retry' +import { Notification as CourierNotification } from './proto/courier_pb' +import { CourierClient } from './courier' export type Platform = 'ios' | 'android' | 'web' -type PushNotification = { +type GorushNotification = { tokens: string[] platform: 1 | 2 // 1 = ios, 2 = android title: string @@ -26,161 +30,24 @@ type PushNotification = { collapse_key?: string } -type InsertableNotif = Insertable +type NotifRow = Insertable -type NotifDisplay = { +type NotifView = { key: string rateLimit: boolean title: string body: string - notif: InsertableNotif + notif: NotifRow } -export class NotificationServer { - private rateLimiter = new RateLimiter(1, 30 * MINUTE) - - constructor(public db: Database, public pushEndpoint?: string) {} - - async getTokensByDid(dids: string[]) { - if (!dids.length) return {} - const tokens = await this.db.db - .selectFrom('notification_push_token') - .where('did', 'in', dids) - .selectAll() - .execute() - return tokens.reduce((acc, token) => { - acc[token.did] ??= [] - acc[token.did].push(token) - return acc - }, {} as Record) - } - - async prepareNotifsToSend(notifications: InsertableNotif[]) { - const now = Date.now() - const notifsToSend: PushNotification[] = [] - const tokensByDid = await this.getTokensByDid( - unique(notifications.map((n) => n.did)), - ) - // views for all notifications that have tokens - const notificationViews = await this.getNotificationDisplayAttributes( - notifications.filter((n) => tokensByDid[n.did]), - ) - - for (const notifView of notificationViews) { - if (!isRecent(notifView.notif.sortAt, 10 * MINUTE)) { - continue // if the notif is from > 10 minutes ago, don't send push notif - } - const { did: userDid } = notifView.notif - const userTokens = tokensByDid[userDid] ?? [] - for (const t of userTokens) { - const { appId, platform, token } = t - if (notifView.rateLimit && !this.rateLimiter.check(token, now)) { - continue - } - if (platform === 'ios' || platform === 'android') { - notifsToSend.push({ - tokens: [token], - platform: platform === 'ios' ? 1 : 2, - title: notifView.title, - message: notifView.body, - topic: appId, - data: { - reason: notifView.notif.reason, - recordUri: notifView.notif.recordUri, - recordCid: notifView.notif.recordCid, - }, - collapse_id: notifView.key, - collapse_key: notifView.key, - }) - } else { - // @TODO: Handle web notifs - logger.warn({ did: userDid }, 'cannot send web notification to user') - } - } - } +export abstract class NotificationServer { + constructor(public db: Database) {} - return notifsToSend - } + abstract prepareNotifications(notifs: NotifRow[]): Promise - /** - * The function `addNotificationsToQueue` adds push notifications to a queue, taking into account rate - * limiting and batching the notifications for efficient processing. - * @param {PushNotification[]} notifs - An array of PushNotification objects. Each PushNotification - * object has a "tokens" property which is an array of tokens. - * @returns void - */ - async processNotifications(notifs: PushNotification[]) { - for (const batch of chunkArray(notifs, 20)) { - try { - await this.sendPushNotifications(batch) - } catch (err) { - logger.error({ err, batch }, 'notification push batch failed') - } - } - } + abstract processNotifications(prepared: N[]): Promise - /** 1. Get the user's token (APNS or FCM for iOS and Android respectively) from the database - User token will be in the format: - did || token || platform (1 = iOS, 2 = Android, 3 = Web) - 2. Send notification to `gorush` server with token - Notification will be in the format: - "notifications": [ - { - "tokens": string[], - "platform": 1 | 2, - "message": string, - "title": string, - "priority": "normal" | "high", - "image": string, (Android only) - "expiration": number, (iOS only) - "badge": number, (iOS only) - } - ] - 3. `gorush` will send notification to APNS or FCM - 4. store response from `gorush` which contains the ID of the notification - 5. If notification needs to be updated or deleted, find the ID of the notification from the database and send a new notification to `gorush` with the ID (repeat step 2) - */ - private async sendPushNotifications(notifications: PushNotification[]) { - // if pushEndpoint is not defined, we are not running in the indexer service, so we can't send push notifications - if (!this.pushEndpoint) { - throw new Error('Push endpoint not defined') - } - // if no notifications, skip and return early - if (notifications.length === 0) { - return - } - const pushEndpoint = this.pushEndpoint - await retryHttp(() => - axios.post( - pushEndpoint, - { notifications }, - { - headers: { - 'Content-Type': 'application/json', - accept: 'application/json', - }, - }, - ), - ) - } - - async registerDeviceForPushNotifications( - did: string, - token: string, - platform: Platform, - appId: string, - ) { - // if token doesn't exist, insert it, on conflict do nothing - await this.db.db - .insertInto('notification_push_token') - .values({ did, token, platform, appId }) - .onConflict((oc) => oc.doNothing()) - .execute() - } - - async getNotificationDisplayAttributes( - notifs: InsertableNotif[], - ): Promise { + async getNotificationViews(notifs: NotifRow[]): Promise { const { ref } = this.db.db.dynamic const authorDids = notifs.map((n) => n.author) const subjectUris = notifs.flatMap((n) => n.reasonSubject ?? []) @@ -219,7 +86,7 @@ export class NotificationServer { return acc }, {} as Record) - const results: NotifDisplay[] = [] + const results: NotifView[] = [] for (const notif of notifs) { const { @@ -310,7 +177,7 @@ export class NotificationServer { return results } - async findBlocksAndMutes(notifs: InsertableNotif[]) { + private async findBlocksAndMutes(notifs: NotifRow[]) { const pairs = notifs.map((n) => ({ author: n.author, receiver: n.did })) const { ref } = this.db.db.dynamic const blockQb = this.db.db @@ -353,6 +220,155 @@ export class NotificationServer { } } +export class GorushNotificationServer extends NotificationServer { + private rateLimiter = new RateLimiter(1, 30 * MINUTE) + + constructor(public db: Database, public pushEndpoint: string) { + super(db) + } + + async prepareNotifications( + notifs: NotifRow[], + ): Promise { + const now = Date.now() + const notifsToSend: GorushNotification[] = [] + const tokensByDid = await this.getTokensByDid( + unique(notifs.map((n) => n.did)), + ) + // views for all notifications that have tokens + const notificationViews = await this.getNotificationViews( + notifs.filter((n) => tokensByDid[n.did]), + ) + + for (const notifView of notificationViews) { + if (!isRecent(notifView.notif.sortAt, 10 * MINUTE)) { + continue // if the notif is from > 10 minutes ago, don't send push notif + } + const { did: userDid } = notifView.notif + const userTokens = tokensByDid[userDid] ?? [] + for (const t of userTokens) { + const { appId, platform, token } = t + if (notifView.rateLimit && !this.rateLimiter.check(token, now)) { + continue + } + if (platform === 'ios' || platform === 'android') { + notifsToSend.push({ + tokens: [token], + platform: platform === 'ios' ? 1 : 2, + title: notifView.title, + message: notifView.body, + topic: appId, + data: { + reason: notifView.notif.reason, + recordUri: notifView.notif.recordUri, + recordCid: notifView.notif.recordCid, + }, + collapse_id: notifView.key, + collapse_key: notifView.key, + }) + } else { + // @TODO: Handle web notifs + logger.warn({ did: userDid }, 'cannot send web notification to user') + } + } + } + return notifsToSend + } + + async getTokensByDid(dids: string[]) { + if (!dids.length) return {} + const tokens = await this.db.db + .selectFrom('notification_push_token') + .where('did', 'in', dids) + .selectAll() + .execute() + return tokens.reduce((acc, token) => { + acc[token.did] ??= [] + acc[token.did].push(token) + return acc + }, {} as Record) + } + + async processNotifications(prepared: GorushNotification[]): Promise { + for (const batch of chunkArray(prepared, 20)) { + try { + await this.sendToGorush(batch) + } catch (err) { + logger.error({ err, batch }, 'notification push batch failed') + } + } + } + + private async sendToGorush(prepared: GorushNotification[]) { + // if no notifications, skip and return early + if (prepared.length === 0) { + return + } + const pushEndpoint = this.pushEndpoint + await retryHttp(() => + axios.post( + pushEndpoint, + { notifications: prepared }, + { + headers: { + 'content-type': 'application/json', + accept: 'application/json', + }, + }, + ), + ) + } +} + +export class CourierNotificationServer extends NotificationServer { + constructor(public db: Database, public courierClient: CourierClient) { + super(db) + } + + async prepareNotifications( + notifs: NotifRow[], + ): Promise { + const notificationViews = await this.getNotificationViews(notifs) + const notifsToSend = notificationViews.map((n) => { + return new CourierNotification({ + id: getCourierId(n), + recipientDid: n.notif.did, + title: n.title, + message: n.body, + collapseKey: n.key, + alwaysDeliver: !n.rateLimit, + timestamp: Timestamp.fromDate(new Date(n.notif.sortAt)), + additional: Struct.fromJson({ + uri: n.notif.recordUri, + reason: n.notif.reason, + subject: n.notif.reasonSubject || '', + }), + }) + }) + return notifsToSend + } + + async processNotifications(prepared: CourierNotification[]): Promise { + try { + await retryConnect(() => + this.courierClient.pushNotifications({ notifications: prepared }), + ) + } catch (err) { + logger.error({ err }, 'notification push to courier failed') + } + } +} + +const getCourierId = (notif: NotifView) => { + const key = [ + notif.notif.recordUri, + notif.notif.did, + notif.notif.reason, + notif.notif.reasonSubject || '', + ].join('::') + return murmur.v3(key).toString(16) +} + const isRecent = (isoTime: string, timeDiff: number): boolean => { const diff = Date.now() - new Date(isoTime).getTime() return diff < timeDiff diff --git a/packages/bsky/src/proto/courier_connect.ts b/packages/bsky/src/proto/courier_connect.ts new file mode 100644 index 00000000000..04d482e0788 --- /dev/null +++ b/packages/bsky/src/proto/courier_connect.ts @@ -0,0 +1,50 @@ +// @generated by protoc-gen-connect-es v1.3.0 with parameter "target=ts,import_extension=.ts" +// @generated from file courier.proto (package courier, syntax proto3) +/* eslint-disable */ +// @ts-nocheck + +import { + PingRequest, + PingResponse, + PushNotificationsRequest, + PushNotificationsResponse, + RegisterDeviceTokenRequest, + RegisterDeviceTokenResponse, +} from './courier_pb.ts' +import { MethodKind } from '@bufbuild/protobuf' + +/** + * @generated from service courier.Service + */ +export const Service = { + typeName: 'courier.Service', + methods: { + /** + * @generated from rpc courier.Service.Ping + */ + ping: { + name: 'Ping', + I: PingRequest, + O: PingResponse, + kind: MethodKind.Unary, + }, + /** + * @generated from rpc courier.Service.PushNotifications + */ + pushNotifications: { + name: 'PushNotifications', + I: PushNotificationsRequest, + O: PushNotificationsResponse, + kind: MethodKind.Unary, + }, + /** + * @generated from rpc courier.Service.RegisterDeviceToken + */ + registerDeviceToken: { + name: 'RegisterDeviceToken', + I: RegisterDeviceTokenRequest, + O: RegisterDeviceTokenResponse, + kind: MethodKind.Unary, + }, + }, +} as const diff --git a/packages/bsky/src/proto/courier_pb.ts b/packages/bsky/src/proto/courier_pb.ts new file mode 100644 index 00000000000..447b47211d9 --- /dev/null +++ b/packages/bsky/src/proto/courier_pb.ts @@ -0,0 +1,473 @@ +// @generated by protoc-gen-es v1.6.0 with parameter "target=ts,import_extension=.ts" +// @generated from file courier.proto (package courier, syntax proto3) +/* eslint-disable */ +// @ts-nocheck + +import type { + BinaryReadOptions, + FieldList, + JsonReadOptions, + JsonValue, + PartialMessage, + PlainMessage, +} from '@bufbuild/protobuf' +import { Message, proto3, Struct, Timestamp } from '@bufbuild/protobuf' + +/** + * @generated from enum courier.AppPlatform + */ +export enum AppPlatform { + /** + * @generated from enum value: APP_PLATFORM_UNSPECIFIED = 0; + */ + UNSPECIFIED = 0, + + /** + * @generated from enum value: APP_PLATFORM_IOS = 1; + */ + IOS = 1, + + /** + * @generated from enum value: APP_PLATFORM_ANDROID = 2; + */ + ANDROID = 2, + + /** + * @generated from enum value: APP_PLATFORM_WEB = 3; + */ + WEB = 3, +} +// Retrieve enum metadata with: proto3.getEnumType(AppPlatform) +proto3.util.setEnumType(AppPlatform, 'courier.AppPlatform', [ + { no: 0, name: 'APP_PLATFORM_UNSPECIFIED' }, + { no: 1, name: 'APP_PLATFORM_IOS' }, + { no: 2, name: 'APP_PLATFORM_ANDROID' }, + { no: 3, name: 'APP_PLATFORM_WEB' }, +]) + +/** + * Ping + * + * @generated from message courier.PingRequest + */ +export class PingRequest extends Message { + constructor(data?: PartialMessage) { + super() + proto3.util.initPartial(data, this) + } + + static readonly runtime: typeof proto3 = proto3 + static readonly typeName = 'courier.PingRequest' + static readonly fields: FieldList = proto3.util.newFieldList(() => []) + + static fromBinary( + bytes: Uint8Array, + options?: Partial, + ): PingRequest { + return new PingRequest().fromBinary(bytes, options) + } + + static fromJson( + jsonValue: JsonValue, + options?: Partial, + ): PingRequest { + return new PingRequest().fromJson(jsonValue, options) + } + + static fromJsonString( + jsonString: string, + options?: Partial, + ): PingRequest { + return new PingRequest().fromJsonString(jsonString, options) + } + + static equals( + a: PingRequest | PlainMessage | undefined, + b: PingRequest | PlainMessage | undefined, + ): boolean { + return proto3.util.equals(PingRequest, a, b) + } +} + +/** + * @generated from message courier.PingResponse + */ +export class PingResponse extends Message { + constructor(data?: PartialMessage) { + super() + proto3.util.initPartial(data, this) + } + + static readonly runtime: typeof proto3 = proto3 + static readonly typeName = 'courier.PingResponse' + static readonly fields: FieldList = proto3.util.newFieldList(() => []) + + static fromBinary( + bytes: Uint8Array, + options?: Partial, + ): PingResponse { + return new PingResponse().fromBinary(bytes, options) + } + + static fromJson( + jsonValue: JsonValue, + options?: Partial, + ): PingResponse { + return new PingResponse().fromJson(jsonValue, options) + } + + static fromJsonString( + jsonString: string, + options?: Partial, + ): PingResponse { + return new PingResponse().fromJsonString(jsonString, options) + } + + static equals( + a: PingResponse | PlainMessage | undefined, + b: PingResponse | PlainMessage | undefined, + ): boolean { + return proto3.util.equals(PingResponse, a, b) + } +} + +/** + * @generated from message courier.Notification + */ +export class Notification extends Message { + /** + * @generated from field: string id = 1; + */ + id = '' + + /** + * @generated from field: string recipient_did = 2; + */ + recipientDid = '' + + /** + * @generated from field: string title = 3; + */ + title = '' + + /** + * @generated from field: string message = 4; + */ + message = '' + + /** + * @generated from field: string collapse_key = 5; + */ + collapseKey = '' + + /** + * @generated from field: bool always_deliver = 6; + */ + alwaysDeliver = false + + /** + * @generated from field: google.protobuf.Timestamp timestamp = 7; + */ + timestamp?: Timestamp + + /** + * @generated from field: google.protobuf.Struct additional = 8; + */ + additional?: Struct + + constructor(data?: PartialMessage) { + super() + proto3.util.initPartial(data, this) + } + + static readonly runtime: typeof proto3 = proto3 + static readonly typeName = 'courier.Notification' + static readonly fields: FieldList = proto3.util.newFieldList(() => [ + { no: 1, name: 'id', kind: 'scalar', T: 9 /* ScalarType.STRING */ }, + { + no: 2, + name: 'recipient_did', + kind: 'scalar', + T: 9 /* ScalarType.STRING */, + }, + { no: 3, name: 'title', kind: 'scalar', T: 9 /* ScalarType.STRING */ }, + { no: 4, name: 'message', kind: 'scalar', T: 9 /* ScalarType.STRING */ }, + { + no: 5, + name: 'collapse_key', + kind: 'scalar', + T: 9 /* ScalarType.STRING */, + }, + { + no: 6, + name: 'always_deliver', + kind: 'scalar', + T: 8 /* ScalarType.BOOL */, + }, + { no: 7, name: 'timestamp', kind: 'message', T: Timestamp }, + { no: 8, name: 'additional', kind: 'message', T: Struct }, + ]) + + static fromBinary( + bytes: Uint8Array, + options?: Partial, + ): Notification { + return new Notification().fromBinary(bytes, options) + } + + static fromJson( + jsonValue: JsonValue, + options?: Partial, + ): Notification { + return new Notification().fromJson(jsonValue, options) + } + + static fromJsonString( + jsonString: string, + options?: Partial, + ): Notification { + return new Notification().fromJsonString(jsonString, options) + } + + static equals( + a: Notification | PlainMessage | undefined, + b: Notification | PlainMessage | undefined, + ): boolean { + return proto3.util.equals(Notification, a, b) + } +} + +/** + * @generated from message courier.PushNotificationsRequest + */ +export class PushNotificationsRequest extends Message { + /** + * @generated from field: repeated courier.Notification notifications = 1; + */ + notifications: Notification[] = [] + + constructor(data?: PartialMessage) { + super() + proto3.util.initPartial(data, this) + } + + static readonly runtime: typeof proto3 = proto3 + static readonly typeName = 'courier.PushNotificationsRequest' + static readonly fields: FieldList = proto3.util.newFieldList(() => [ + { + no: 1, + name: 'notifications', + kind: 'message', + T: Notification, + repeated: true, + }, + ]) + + static fromBinary( + bytes: Uint8Array, + options?: Partial, + ): PushNotificationsRequest { + return new PushNotificationsRequest().fromBinary(bytes, options) + } + + static fromJson( + jsonValue: JsonValue, + options?: Partial, + ): PushNotificationsRequest { + return new PushNotificationsRequest().fromJson(jsonValue, options) + } + + static fromJsonString( + jsonString: string, + options?: Partial, + ): PushNotificationsRequest { + return new PushNotificationsRequest().fromJsonString(jsonString, options) + } + + static equals( + a: + | PushNotificationsRequest + | PlainMessage + | undefined, + b: + | PushNotificationsRequest + | PlainMessage + | undefined, + ): boolean { + return proto3.util.equals(PushNotificationsRequest, a, b) + } +} + +/** + * @generated from message courier.PushNotificationsResponse + */ +export class PushNotificationsResponse extends Message { + constructor(data?: PartialMessage) { + super() + proto3.util.initPartial(data, this) + } + + static readonly runtime: typeof proto3 = proto3 + static readonly typeName = 'courier.PushNotificationsResponse' + static readonly fields: FieldList = proto3.util.newFieldList(() => []) + + static fromBinary( + bytes: Uint8Array, + options?: Partial, + ): PushNotificationsResponse { + return new PushNotificationsResponse().fromBinary(bytes, options) + } + + static fromJson( + jsonValue: JsonValue, + options?: Partial, + ): PushNotificationsResponse { + return new PushNotificationsResponse().fromJson(jsonValue, options) + } + + static fromJsonString( + jsonString: string, + options?: Partial, + ): PushNotificationsResponse { + return new PushNotificationsResponse().fromJsonString(jsonString, options) + } + + static equals( + a: + | PushNotificationsResponse + | PlainMessage + | undefined, + b: + | PushNotificationsResponse + | PlainMessage + | undefined, + ): boolean { + return proto3.util.equals(PushNotificationsResponse, a, b) + } +} + +/** + * @generated from message courier.RegisterDeviceTokenRequest + */ +export class RegisterDeviceTokenRequest extends Message { + /** + * @generated from field: string did = 1; + */ + did = '' + + /** + * @generated from field: string token = 2; + */ + token = '' + + /** + * @generated from field: string app_id = 3; + */ + appId = '' + + /** + * @generated from field: courier.AppPlatform platform = 4; + */ + platform = AppPlatform.UNSPECIFIED + + constructor(data?: PartialMessage) { + super() + proto3.util.initPartial(data, this) + } + + static readonly runtime: typeof proto3 = proto3 + static readonly typeName = 'courier.RegisterDeviceTokenRequest' + static readonly fields: FieldList = proto3.util.newFieldList(() => [ + { no: 1, name: 'did', kind: 'scalar', T: 9 /* ScalarType.STRING */ }, + { no: 2, name: 'token', kind: 'scalar', T: 9 /* ScalarType.STRING */ }, + { no: 3, name: 'app_id', kind: 'scalar', T: 9 /* ScalarType.STRING */ }, + { + no: 4, + name: 'platform', + kind: 'enum', + T: proto3.getEnumType(AppPlatform), + }, + ]) + + static fromBinary( + bytes: Uint8Array, + options?: Partial, + ): RegisterDeviceTokenRequest { + return new RegisterDeviceTokenRequest().fromBinary(bytes, options) + } + + static fromJson( + jsonValue: JsonValue, + options?: Partial, + ): RegisterDeviceTokenRequest { + return new RegisterDeviceTokenRequest().fromJson(jsonValue, options) + } + + static fromJsonString( + jsonString: string, + options?: Partial, + ): RegisterDeviceTokenRequest { + return new RegisterDeviceTokenRequest().fromJsonString(jsonString, options) + } + + static equals( + a: + | RegisterDeviceTokenRequest + | PlainMessage + | undefined, + b: + | RegisterDeviceTokenRequest + | PlainMessage + | undefined, + ): boolean { + return proto3.util.equals(RegisterDeviceTokenRequest, a, b) + } +} + +/** + * @generated from message courier.RegisterDeviceTokenResponse + */ +export class RegisterDeviceTokenResponse extends Message { + constructor(data?: PartialMessage) { + super() + proto3.util.initPartial(data, this) + } + + static readonly runtime: typeof proto3 = proto3 + static readonly typeName = 'courier.RegisterDeviceTokenResponse' + static readonly fields: FieldList = proto3.util.newFieldList(() => []) + + static fromBinary( + bytes: Uint8Array, + options?: Partial, + ): RegisterDeviceTokenResponse { + return new RegisterDeviceTokenResponse().fromBinary(bytes, options) + } + + static fromJson( + jsonValue: JsonValue, + options?: Partial, + ): RegisterDeviceTokenResponse { + return new RegisterDeviceTokenResponse().fromJson(jsonValue, options) + } + + static fromJsonString( + jsonString: string, + options?: Partial, + ): RegisterDeviceTokenResponse { + return new RegisterDeviceTokenResponse().fromJsonString(jsonString, options) + } + + static equals( + a: + | RegisterDeviceTokenResponse + | PlainMessage + | undefined, + b: + | RegisterDeviceTokenResponse + | PlainMessage + | undefined, + ): boolean { + return proto3.util.equals(RegisterDeviceTokenResponse, a, b) + } +} diff --git a/packages/bsky/src/services/actor/index.ts b/packages/bsky/src/services/actor/index.ts index b8898570688..096bf18be9b 100644 --- a/packages/bsky/src/services/actor/index.ts +++ b/packages/bsky/src/services/actor/index.ts @@ -12,6 +12,7 @@ import { GraphService } from '../graph' import { LabelService } from '../label' import { AtUri } from '@atproto/syntax' import { ids } from '../../lexicon/lexicons' +import { Platform } from '../../notifications' export * from './types' @@ -21,8 +22,8 @@ export class ActorService { constructor( public db: Database, public imgUriBuilder: ImageUriBuilder, - private graph: FromDb, - private label: FromDb, + graph: FromDb, + label: FromDb, ) { this.views = new ActorViews(this.db, this.imgUriBuilder, graph, label) } @@ -214,6 +215,20 @@ export class ActorService { } } } + + async registerPushDeviceToken( + did: string, + token: string, + platform: Platform, + appId: string, + ) { + await this.db + .asPrimary() + .db.insertInto('notification_push_token') + .values({ did, token, platform, appId }) + .onConflict((oc) => oc.doNothing()) + .execute() + } } type ActorResult = Actor diff --git a/packages/bsky/src/services/indexing/processor.ts b/packages/bsky/src/services/indexing/processor.ts index 2a02c61125e..0dad405b9ef 100644 --- a/packages/bsky/src/services/indexing/processor.ts +++ b/packages/bsky/src/services/indexing/processor.ts @@ -257,7 +257,7 @@ export class RecordProcessor { const notifServer = this.notifServer sendOnCommit.push(async () => { try { - const preparedNotifs = await notifServer.prepareNotifsToSend(chunk) + const preparedNotifs = await notifServer.prepareNotifications(chunk) await notifServer.processNotifications(preparedNotifs) } catch (error) { dbLogger.error({ error }, 'error sending push notifications') diff --git a/packages/bsky/src/util/retry.ts b/packages/bsky/src/util/retry.ts index ab96998642a..62b1747815e 100644 --- a/packages/bsky/src/util/retry.ts +++ b/packages/bsky/src/util/retry.ts @@ -1,6 +1,7 @@ import { AxiosError } from 'axios' import { XRPCError, ResponseType } from '@atproto/xrpc' import { RetryOptions, retry } from '@atproto/common' +import { Code, ConnectError } from '@connectrpc/connect' export async function retryHttp( fn: () => Promise, @@ -24,3 +25,14 @@ export function retryableHttp(err: unknown) { const retryableHttpStatusCodes = new Set([ 408, 425, 429, 500, 502, 503, 504, 522, 524, ]) + +export async function retryConnect( + fn: () => Promise, + opts: RetryOptions = {}, +): Promise { + return retry(fn, { retryable: retryableConnect, ...opts }) +} + +export function retryableConnect(err: unknown) { + return err instanceof ConnectError && err.code === Code.Unavailable +} diff --git a/packages/bsky/tests/notification-server.test.ts b/packages/bsky/tests/notification-server.test.ts index 11b9f2395e8..0efd1e448b4 100644 --- a/packages/bsky/tests/notification-server.test.ts +++ b/packages/bsky/tests/notification-server.test.ts @@ -1,14 +1,18 @@ import AtpAgent, { AtUri } from '@atproto/api' import { TestNetwork, SeedClient, basicSeed } from '@atproto/dev-env' -import { NotificationServer } from '../src/notifications' +import { + CourierNotificationServer, + GorushNotificationServer, +} from '../src/notifications' import { Database } from '../src' +import { createCourierClient } from '../src/courier' describe('notification server', () => { let network: TestNetwork let agent: AtpAgent let pdsAgent: AtpAgent let sc: SeedClient - let notifServer: NotificationServer + let notifServer: GorushNotificationServer // account dids, for convenience let alice: string @@ -24,14 +28,17 @@ describe('notification server', () => { await network.processAll() await network.bsky.processAll() alice = sc.dids.alice - notifServer = network.bsky.ctx.notifServer + notifServer = new GorushNotificationServer( + network.bsky.ctx.db.getPrimary(), + 'http://mock', + ) }) afterAll(async () => { await network.close() }) - describe('registerPushNotification', () => { + describe('registerPush', () => { it('registers push notification token and device.', async () => { const res = await agent.api.app.bsky.notification.registerPush( { @@ -95,19 +102,14 @@ describe('notification server', () => { }) describe('NotificationServer', () => { - it('gets user tokens from db', async () => { - const tokens = await notifServer.getTokensByDid([alice]) - expect(tokens[alice][0].token).toEqual('123') - }) - it('gets notification display attributes: title and body', async () => { const db = network.bsky.ctx.db.getPrimary() const notif = await getLikeNotification(db, alice) if (!notif) throw new Error('no notification found') - const attrs = await notifServer.getNotificationDisplayAttributes([notif]) - if (!attrs.length) + const views = await notifServer.getNotificationViews([notif]) + if (!views.length) throw new Error('no notification display attributes found') - expect(attrs[0].title).toEqual('bobby liked your post') + expect(views[0].title).toEqual('bobby liked your post') }) it('filters notifications that violate blocks', async () => { @@ -126,11 +128,11 @@ describe('notification server', () => { did: notif.author, author: notif.did, } - const attrs = await notifServer.getNotificationDisplayAttributes([ + const views = await notifServer.getNotificationViews([ notif, flippedNotif, ]) - expect(attrs.length).toBe(0) + expect(views.length).toBe(0) const uri = new AtUri(blockRef.uri) await pdsAgent.api.app.bsky.graph.block.delete( { repo: alice, rkey: uri.rkey }, @@ -147,8 +149,8 @@ describe('notification server', () => { { actor: notif.author }, { headers: sc.getHeaders(alice), encoding: 'application/json' }, ) - const attrs = await notifServer.getNotificationDisplayAttributes([notif]) - expect(attrs.length).toBe(0) + const views = await notifServer.getNotificationViews([notif]) + expect(views.length).toBe(0) await pdsAgent.api.app.bsky.graph.unmuteActor( { actor: notif.author }, { headers: sc.getHeaders(alice), encoding: 'application/json' }, @@ -182,13 +184,20 @@ describe('notification server', () => { { list: listRef.uri }, { headers: sc.getHeaders(alice), encoding: 'application/json' }, ) - const attrs = await notifServer.getNotificationDisplayAttributes([notif]) - expect(attrs.length).toBe(0) + const views = await notifServer.getNotificationViews([notif]) + expect(views.length).toBe(0) await pdsAgent.api.app.bsky.graph.unmuteActorList( { list: listRef.uri }, { headers: sc.getHeaders(alice), encoding: 'application/json' }, ) }) + }) + + describe('GorushNotificationServer', () => { + it('gets user tokens from db', async () => { + const tokens = await notifServer.getTokensByDid([alice]) + expect(tokens[alice][0].token).toEqual('123') + }) it('prepares notification to be sent', async () => { const db = network.bsky.ctx.db.getPrimary() @@ -198,7 +207,7 @@ describe('notification server', () => { notif, notif /* second one will get dropped by rate limit */, ] - const prepared = await notifServer.prepareNotifsToSend(notifAsArray) + const prepared = await notifServer.prepareNotifications(notifAsArray) expect(prepared).toEqual([ { collapse_id: 'like', @@ -218,6 +227,37 @@ describe('notification server', () => { }) }) + describe('CourierNotificationServer', () => { + it('prepares notification to be sent', async () => { + const db = network.bsky.ctx.db.getPrimary() + const notif = await getLikeNotification(db, alice) + if (!notif) throw new Error('no notification found') + const courierNotifServer = new CourierNotificationServer( + db, + createCourierClient({ baseUrl: 'http://mock', httpVersion: '2' }), + ) + const prepared = await courierNotifServer.prepareNotifications([notif]) + expect(prepared[0]?.id).toBeTruthy() + expect(prepared.map((p) => p.toJson())).toEqual([ + { + id: prepared[0].id, // already ensured it exists + recipientDid: notif.did, + title: 'bobby liked your post', + message: 'again', + collapseKey: 'like', + timestamp: notif.sortAt, + // this is missing, appears to be a quirk of toJson() + // alwaysDeliver: false, + additional: { + reason: notif.reason, + uri: notif.recordUri, + subject: notif.reasonSubject, + }, + }, + ]) + }) + }) + async function getLikeNotification(db: Database, did: string) { return await db.db .selectFrom('notification') diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index d7dcdc7db48..d96fa71b5ef 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -234,6 +234,9 @@ importers: multiformats: specifier: ^9.9.0 version: 9.9.0 + murmurhash: + specifier: ^2.0.1 + version: 2.0.1 p-queue: specifier: ^6.6.2 version: 6.6.2 @@ -9790,6 +9793,10 @@ packages: /multiformats@9.9.0: resolution: {integrity: sha512-HoMUjhH9T8DDBNT+6xzkrd9ga/XiBI4xLr58LJACwK6G3HTOPeMz4nB4KJs33L2BelrIJa7P0VuNaVF3hMYfjg==} + /murmurhash@2.0.1: + resolution: {integrity: sha512-5vQEh3y+DG/lMPM0mCGPDnyV8chYg/g7rl6v3Gd8WMF9S429ox3Xk8qrk174kWhG767KQMqqxLD1WnGd77hiew==} + dev: false + /napi-build-utils@1.0.2: resolution: {integrity: sha512-ONmRUqK7zj7DWX0D9ADe03wbwOBZxNAfF20PlGfCWQcD3+/MakShIHrMqx9YwPTfxDdF1zLeL+RGZiR9kGMLdg==}