Appview v1 maintaining device tokens and pushing notifications w/ courier ()

* add courier proto to bsky, build

* update registerPush on appview to support registering device tokens with courier

* setup bsky notifications to send to either gorush or courier

* wire courier push into indexer, test

* courier push retries

* tidy and build
This commit is contained in:
devin ivy 2024-01-23 21:17:32 -05:00 committed by GitHub
parent d108310575
commit 6a318b9f76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 1059 additions and 196 deletions

@ -3,7 +3,7 @@ on:
push:
branches:
- main
- appview-v1-sync-mutes
- appview-v1-courier
env:
REGISTRY: ${{ secrets.AWS_ECR_REGISTRY_USEAST2_PACKAGES_REGISTRY }}
USERNAME: ${{ secrets.AWS_ECR_REGISTRY_USEAST2_PACKAGES_USERNAME }}

@ -29,16 +29,16 @@
"test:log": "tail -50 test.log | pino-pretty",
"test:updateSnapshot": "jest --updateSnapshot",
"migration:create": "ts-node ./bin/migration-create.ts",
"buf:gen": "buf generate ../bsync/proto"
"buf:gen": "buf generate ../bsync/proto && buf generate ./proto"
},
"dependencies": {
"@atproto/api": "workspace:^",
"@atproto/common": "workspace:^",
"@atproto/crypto": "workspace:^",
"@atproto/syntax": "workspace:^",
"@atproto/identity": "workspace:^",
"@atproto/lexicon": "workspace:^",
"@atproto/repo": "workspace:^",
"@atproto/syntax": "workspace:^",
"@atproto/xrpc-server": "workspace:^",
"@bufbuild/protobuf": "^1.5.0",
"@connectrpc/connect": "^1.1.4",
@ -55,6 +55,7 @@
"ioredis": "^5.3.2",
"kysely": "^0.22.0",
"multiformats": "^9.9.0",
"murmurhash": "^2.0.1",
"p-queue": "^6.6.2",
"pg": "^8.10.0",
"pino": "^8.15.0",

@ -0,0 +1,56 @@
syntax = "proto3";
package courier;
option go_package = "./;courier";
import "google/protobuf/struct.proto";
import "google/protobuf/timestamp.proto";
//
// Messages
//
// Ping
message PingRequest {}
message PingResponse {}
// Notifications
enum AppPlatform {
APP_PLATFORM_UNSPECIFIED = 0;
APP_PLATFORM_IOS = 1;
APP_PLATFORM_ANDROID = 2;
APP_PLATFORM_WEB = 3;
}
message Notification {
string id = 1;
string recipient_did = 2;
string title = 3;
string message = 4;
string collapse_key = 5;
bool always_deliver = 6;
google.protobuf.Timestamp timestamp = 7;
google.protobuf.Struct additional = 8;
}
message PushNotificationsRequest {
repeated Notification notifications = 1;
}
message PushNotificationsResponse {}
message RegisterDeviceTokenRequest {
string did = 1;
string token = 2;
string app_id = 3;
AppPlatform platform = 4;
}
message RegisterDeviceTokenResponse {}
service Service {
rpc Ping(PingRequest) returns (PingResponse);
rpc PushNotifications(PushNotificationsRequest) returns (PushNotificationsResponse);
rpc RegisterDeviceToken(RegisterDeviceTokenRequest) returns (RegisterDeviceTokenResponse);
}

@ -1,29 +1,63 @@
import assert from 'node:assert'
import { InvalidRequestError } from '@atproto/xrpc-server'
import { Server } from '../../../../lexicon'
import AppContext from '../../../../context'
import { Platform } from '../../../../notifications'
import { CourierClient } from '../../../../courier'
import { AppPlatform } from '../../../../proto/courier_pb'
export default function (server: Server, ctx: AppContext) {
server.app.bsky.notification.registerPush({
auth: ctx.authVerifier.standard,
handler: async ({ auth, input }) => {
handler: async ({ req, auth, input }) => {
const { token, platform, serviceDid, appId } = input.body
const did = auth.credentials.iss
if (serviceDid !== auth.credentials.aud) {
throw new InvalidRequestError('Invalid serviceDid.')
}
const { notifServer } = ctx
if (platform !== 'ios' && platform !== 'android' && platform !== 'web') {
throw new InvalidRequestError(
'Unsupported platform: must be "ios", "android", or "web".',
)
}
await notifServer.registerDeviceForPushNotifications(
did,
token,
platform as Platform,
appId,
)
const db = ctx.db.getPrimary()
const registerDeviceWithAppview = async () => {
await ctx.services
.actor(db)
.registerPushDeviceToken(did, token, platform as Platform, appId)
}
const registerDeviceWithCourier = async (
courierClient: CourierClient,
) => {
await courierClient.registerDeviceToken({
did,
token,
platform:
platform === 'ios'
? AppPlatform.IOS
: platform === 'android'
? AppPlatform.ANDROID
: AppPlatform.WEB,
appId,
})
}
if (ctx.cfg.courierOnlyRegistration) {
assert(ctx.courierClient)
await registerDeviceWithCourier(ctx.courierClient)
} else {
await registerDeviceWithAppview()
if (ctx.courierClient) {
try {
await registerDeviceWithCourier(ctx.courierClient)
} catch (err) {
req.log.warn(err, 'failed to register device token with courier')
}
}
}
},
})
}

@ -36,6 +36,11 @@ export interface ServerConfigValues {
bsyncHttpVersion?: '1.1' | '2'
bsyncIgnoreBadTls?: boolean
bsyncOnlyMutes?: boolean
courierUrl?: string
courierApiKey?: string
courierHttpVersion?: '1.1' | '2'
courierIgnoreBadTls?: boolean
courierOnlyRegistration?: boolean
adminPassword: string
moderatorPassword: string
triagePassword: string
@ -100,6 +105,18 @@ export class ServerConfig {
const bsyncOnlyMutes = process.env.BSKY_BSYNC_ONLY_MUTES === 'true'
assert(!bsyncOnlyMutes || bsyncUrl, 'bsync-only mutes requires a bsync url')
assert(bsyncHttpVersion === '1.1' || bsyncHttpVersion === '2')
const courierUrl = process.env.BSKY_COURIER_URL || undefined
const courierApiKey = process.env.BSKY_COURIER_API_KEY || undefined
const courierHttpVersion = process.env.BSKY_COURIER_HTTP_VERSION || '2'
const courierIgnoreBadTls =
process.env.BSKY_COURIER_IGNORE_BAD_TLS === 'true'
const courierOnlyRegistration =
process.env.BSKY_COURIER_ONLY_REGISTRATION === 'true'
assert(
!courierOnlyRegistration || courierUrl,
'courier-only registration requires a courier url',
)
assert(courierHttpVersion === '1.1' || courierHttpVersion === '2')
const dbPrimaryPostgresUrl =
overrides?.dbPrimaryPostgresUrl || process.env.DB_PRIMARY_POSTGRES_URL
let dbReplicaPostgresUrls = overrides?.dbReplicaPostgresUrls
@ -169,6 +186,11 @@ export class ServerConfig {
bsyncHttpVersion,
bsyncIgnoreBadTls,
bsyncOnlyMutes,
courierUrl,
courierApiKey,
courierHttpVersion,
courierIgnoreBadTls,
courierOnlyRegistration,
adminPassword,
moderatorPassword,
triagePassword,
@ -305,6 +327,26 @@ export class ServerConfig {
return this.cfg.bsyncIgnoreBadTls
}
get courierUrl() {
return this.cfg.courierUrl
}
get courierApiKey() {
return this.cfg.courierApiKey
}
get courierHttpVersion() {
return this.cfg.courierHttpVersion
}
get courierIgnoreBadTls() {
return this.cfg.courierIgnoreBadTls
}
get courierOnlyRegistration() {
return this.cfg.courierOnlyRegistration
}
get adminPassword() {
return this.cfg.adminPassword
}

@ -10,10 +10,10 @@ import { Services } from './services'
import DidRedisCache from './did-cache'
import { BackgroundQueue } from './background'
import { MountedAlgos } from './feed-gen/types'
import { NotificationServer } from './notifications'
import { Redis } from './redis'
import { AuthVerifier } from './auth-verifier'
import { BsyncClient } from './bsync'
import { CourierClient } from './courier'
export class AppContext {
constructor(
@ -29,8 +29,8 @@ export class AppContext {
backgroundQueue: BackgroundQueue
searchAgent?: AtpAgent
bsyncClient?: BsyncClient
courierClient?: CourierClient
algos: MountedAlgos
notifServer: NotificationServer
authVerifier: AuthVerifier
},
) {}
@ -71,10 +71,6 @@ export class AppContext {
return this.opts.redis
}
get notifServer(): NotificationServer {
return this.opts.notifServer
}
get searchAgent(): AtpAgent | undefined {
return this.opts.searchAgent
}
@ -83,6 +79,10 @@ export class AppContext {
return this.opts.bsyncClient
}
get courierClient(): CourierClient | undefined {
return this.opts.courierClient
}
get authVerifier(): AuthVerifier {
return this.opts.authVerifier
}

@ -0,0 +1,41 @@
import { Service } from './proto/courier_connect'
import {
Code,
ConnectError,
PromiseClient,
createPromiseClient,
Interceptor,
} from '@connectrpc/connect'
import {
createConnectTransport,
ConnectTransportOptions,
} from '@connectrpc/connect-node'
export type CourierClient = PromiseClient<typeof Service>
export const createCourierClient = (
opts: ConnectTransportOptions,
): CourierClient => {
const transport = createConnectTransport(opts)
return createPromiseClient(Service, transport)
}
export { Code }
export const isCourierError = (
err: unknown,
code?: Code,
): err is ConnectError => {
if (err instanceof ConnectError) {
return !code || err.code === code
}
return false
}
export const authWithApiKey =
(apiKey: string): Interceptor =>
(next) =>
(req) => {
req.header.set('authorization', `Bearer ${apiKey}`)
return next(req)
}

@ -29,12 +29,12 @@ import {
} from './image/invalidator'
import { BackgroundQueue } from './background'
import { MountedAlgos } from './feed-gen/types'
import { NotificationServer } from './notifications'
import { AtpAgent } from '@atproto/api'
import { Keypair } from '@atproto/crypto'
import { Redis } from './redis'
import { AuthVerifier } from './auth-verifier'
import { authWithApiKey, createBsyncClient } from './bsync'
import { authWithApiKey as bsyncAuth, createBsyncClient } from './bsync'
import { authWithApiKey as courierAuth, createCourierClient } from './courier'
export type { ServerConfigValues } from './config'
export type { MountedAlgos } from './feed-gen/types'
@ -113,7 +113,6 @@ export class BskyAppView {
const backgroundQueue = new BackgroundQueue(db.getPrimary())
const notifServer = new NotificationServer(db.getPrimary())
const searchAgent = config.searchEndpoint
? new AtpAgent({ service: config.searchEndpoint })
: undefined
@ -142,7 +141,18 @@ export class BskyAppView {
httpVersion: config.bsyncHttpVersion ?? '2',
nodeOptions: { rejectUnauthorized: !config.bsyncIgnoreBadTls },
interceptors: config.bsyncApiKey
? [authWithApiKey(config.bsyncApiKey)]
? [bsyncAuth(config.bsyncApiKey)]
: [],
})
: undefined
const courierClient = config.courierUrl
? createCourierClient({
baseUrl: config.courierUrl,
httpVersion: config.courierHttpVersion ?? '2',
nodeOptions: { rejectUnauthorized: !config.courierIgnoreBadTls },
interceptors: config.courierApiKey
? [courierAuth(config.courierApiKey)]
: [],
})
: undefined
@ -159,8 +169,8 @@ export class BskyAppView {
backgroundQueue,
searchAgent,
bsyncClient,
courierClient,
algos,
notifServer,
authVerifier,
})

@ -22,6 +22,10 @@ export interface IndexerConfigValues {
fuzzyFalsePositiveB64?: string
labelerKeywords: Record<string, string>
moderationPushUrl: string
courierUrl?: string
courierApiKey?: string
courierHttpVersion?: '1.1' | '2'
courierIgnoreBadTls?: boolean
indexerConcurrency?: number
indexerPartitionIds: number[]
indexerPartitionBatchSize?: number
@ -72,6 +76,18 @@ export class IndexerConfig {
process.env.MODERATION_PUSH_URL ||
undefined
assert(moderationPushUrl)
const courierUrl =
overrides?.courierUrl || process.env.BSKY_COURIER_URL || undefined
const courierApiKey =
overrides?.courierApiKey || process.env.BSKY_COURIER_API_KEY || undefined
const courierHttpVersion =
overrides?.courierHttpVersion ||
process.env.BSKY_COURIER_HTTP_VERSION ||
'2'
const courierIgnoreBadTls =
overrides?.courierIgnoreBadTls ||
process.env.BSKY_COURIER_IGNORE_BAD_TLS === 'true'
assert(courierHttpVersion === '1.1' || courierHttpVersion === '2')
const hiveApiKey = process.env.HIVE_API_KEY || undefined
const abyssEndpoint = process.env.ABYSS_ENDPOINT
const abyssPassword = process.env.ABYSS_PASSWORD
@ -114,6 +130,10 @@ export class IndexerConfig {
didCacheMaxTTL,
handleResolveNameservers,
moderationPushUrl,
courierUrl,
courierApiKey,
courierHttpVersion,
courierIgnoreBadTls,
hiveApiKey,
abyssEndpoint,
abyssPassword,
@ -185,6 +205,22 @@ export class IndexerConfig {
return this.cfg.moderationPushUrl
}
get courierUrl() {
return this.cfg.courierUrl
}
get courierApiKey() {
return this.cfg.courierApiKey
}
get courierHttpVersion() {
return this.cfg.courierHttpVersion
}
get courierIgnoreBadTls() {
return this.cfg.courierIgnoreBadTls
}
get hiveApiKey() {
return this.cfg.hiveApiKey
}

@ -6,6 +6,7 @@ import { BackgroundQueue } from '../background'
import DidSqlCache from '../did-cache'
import { Redis } from '../redis'
import { AutoModerator } from '../auto-moderator'
import { NotificationServer } from '../notifications'
export class IndexerContext {
constructor(
@ -19,6 +20,7 @@ export class IndexerContext {
didCache: DidSqlCache
backgroundQueue: BackgroundQueue
autoMod: AutoModerator
notifServer?: NotificationServer
},
) {}
@ -57,6 +59,10 @@ export class IndexerContext {
get autoMod(): AutoModerator {
return this.opts.autoMod
}
get notifServer(): NotificationServer | undefined {
return this.opts.notifServer
}
}
export default IndexerContext

@ -11,8 +11,13 @@ import { createServices } from './services'
import { IndexerSubscription } from './subscription'
import { AutoModerator } from '../auto-moderator'
import { Redis } from '../redis'
import { NotificationServer } from '../notifications'
import {
CourierNotificationServer,
GorushNotificationServer,
NotificationServer,
} from '../notifications'
import { CloseFn, createServer, startServer } from './server'
import { authWithApiKey as courierAuth, createCourierClient } from '../courier'
export { IndexerConfig } from './config'
export type { IndexerConfigValues } from './config'
@ -60,9 +65,27 @@ export class BskyIndexer {
backgroundQueue,
})
const notifServer = cfg.pushNotificationEndpoint
? new NotificationServer(db, cfg.pushNotificationEndpoint)
const courierClient = cfg.courierUrl
? createCourierClient({
baseUrl: cfg.courierUrl,
httpVersion: cfg.courierHttpVersion ?? '2',
nodeOptions: { rejectUnauthorized: !cfg.courierIgnoreBadTls },
interceptors: cfg.courierApiKey
? [courierAuth(cfg.courierApiKey)]
: [],
})
: undefined
let notifServer: NotificationServer | undefined
if (courierClient) {
notifServer = new CourierNotificationServer(db, courierClient)
} else if (cfg.pushNotificationEndpoint) {
notifServer = new GorushNotificationServer(
db,
cfg.pushNotificationEndpoint,
)
}
const services = createServices({
idResolver,
autoMod,
@ -79,6 +102,7 @@ export class BskyIndexer {
didCache,
backgroundQueue,
autoMod,
notifServer,
})
const sub = new IndexerSubscription(ctx, {
partitionIds: cfg.indexerPartitionIds,

@ -1,6 +1,8 @@
import axios from 'axios'
import { Insertable, sql } from 'kysely'
import TTLCache from '@isaacs/ttlcache'
import { Struct, Timestamp } from '@bufbuild/protobuf'
import murmur from 'murmurhash'
import { AtUri } from '@atproto/api'
import { MINUTE, chunkArray } from '@atproto/common'
import Database from './db/primary'
@ -9,11 +11,13 @@ import { NotificationPushToken as PushToken } from './db/tables/notification-pus
import logger from './indexer/logger'
import { notSoftDeletedClause, valuesList } from './db/util'
import { ids } from './lexicon/lexicons'
import { retryHttp } from './util/retry'
import { retryConnect, retryHttp } from './util/retry'
import { Notification as CourierNotification } from './proto/courier_pb'
import { CourierClient } from './courier'
export type Platform = 'ios' | 'android' | 'web'
type PushNotification = {
type GorushNotification = {
tokens: string[]
platform: 1 | 2 // 1 = ios, 2 = android
title: string
@ -26,161 +30,24 @@ type PushNotification = {
collapse_key?: string
}
type InsertableNotif = Insertable<Notification>
type NotifRow = Insertable<Notification>
type NotifDisplay = {
type NotifView = {
key: string
rateLimit: boolean
title: string
body: string
notif: InsertableNotif
notif: NotifRow
}
export class NotificationServer {
private rateLimiter = new RateLimiter(1, 30 * MINUTE)
export abstract class NotificationServer<N = unknown> {
constructor(public db: Database) {}
constructor(public db: Database, public pushEndpoint?: string) {}
abstract prepareNotifications(notifs: NotifRow[]): Promise<N[]>
async getTokensByDid(dids: string[]) {
if (!dids.length) return {}
const tokens = await this.db.db
.selectFrom('notification_push_token')
.where('did', 'in', dids)
.selectAll()
.execute()
return tokens.reduce((acc, token) => {
acc[token.did] ??= []
acc[token.did].push(token)
return acc
}, {} as Record<string, PushToken[]>)
}
abstract processNotifications(prepared: N[]): Promise<void>
async prepareNotifsToSend(notifications: InsertableNotif[]) {
const now = Date.now()
const notifsToSend: PushNotification[] = []
const tokensByDid = await this.getTokensByDid(
unique(notifications.map((n) => n.did)),
)
// views for all notifications that have tokens
const notificationViews = await this.getNotificationDisplayAttributes(
notifications.filter((n) => tokensByDid[n.did]),
)
for (const notifView of notificationViews) {
if (!isRecent(notifView.notif.sortAt, 10 * MINUTE)) {
continue // if the notif is from > 10 minutes ago, don't send push notif
}
const { did: userDid } = notifView.notif
const userTokens = tokensByDid[userDid] ?? []
for (const t of userTokens) {
const { appId, platform, token } = t
if (notifView.rateLimit && !this.rateLimiter.check(token, now)) {
continue
}
if (platform === 'ios' || platform === 'android') {
notifsToSend.push({
tokens: [token],
platform: platform === 'ios' ? 1 : 2,
title: notifView.title,
message: notifView.body,
topic: appId,
data: {
reason: notifView.notif.reason,
recordUri: notifView.notif.recordUri,
recordCid: notifView.notif.recordCid,
},
collapse_id: notifView.key,
collapse_key: notifView.key,
})
} else {
// @TODO: Handle web notifs
logger.warn({ did: userDid }, 'cannot send web notification to user')
}
}
}
return notifsToSend
}
/**
* The function `addNotificationsToQueue` adds push notifications to a queue, taking into account rate
* limiting and batching the notifications for efficient processing.
* @param {PushNotification[]} notifs - An array of PushNotification objects. Each PushNotification
* object has a "tokens" property which is an array of tokens.
* @returns void
*/
async processNotifications(notifs: PushNotification[]) {
for (const batch of chunkArray(notifs, 20)) {
try {
await this.sendPushNotifications(batch)
} catch (err) {
logger.error({ err, batch }, 'notification push batch failed')
}
}
}
/** 1. Get the user's token (APNS or FCM for iOS and Android respectively) from the database
User token will be in the format:
did || token || platform (1 = iOS, 2 = Android, 3 = Web)
2. Send notification to `gorush` server with token
Notification will be in the format:
"notifications": [
{
"tokens": string[],
"platform": 1 | 2,
"message": string,
"title": string,
"priority": "normal" | "high",
"image": string, (Android only)
"expiration": number, (iOS only)
"badge": number, (iOS only)
}
]
3. `gorush` will send notification to APNS or FCM
4. store response from `gorush` which contains the ID of the notification
5. If notification needs to be updated or deleted, find the ID of the notification from the database and send a new notification to `gorush` with the ID (repeat step 2)
*/
private async sendPushNotifications(notifications: PushNotification[]) {
// if pushEndpoint is not defined, we are not running in the indexer service, so we can't send push notifications
if (!this.pushEndpoint) {
throw new Error('Push endpoint not defined')
}
// if no notifications, skip and return early
if (notifications.length === 0) {
return
}
const pushEndpoint = this.pushEndpoint
await retryHttp(() =>
axios.post(
pushEndpoint,
{ notifications },
{
headers: {
'Content-Type': 'application/json',
accept: 'application/json',
},
},
),
)
}
async registerDeviceForPushNotifications(
did: string,
token: string,
platform: Platform,
appId: string,
) {
// if token doesn't exist, insert it, on conflict do nothing
await this.db.db
.insertInto('notification_push_token')
.values({ did, token, platform, appId })
.onConflict((oc) => oc.doNothing())
.execute()
}
async getNotificationDisplayAttributes(
notifs: InsertableNotif[],
): Promise<NotifDisplay[]> {
async getNotificationViews(notifs: NotifRow[]): Promise<NotifView[]> {
const { ref } = this.db.db.dynamic
const authorDids = notifs.map((n) => n.author)
const subjectUris = notifs.flatMap((n) => n.reasonSubject ?? [])
@ -219,7 +86,7 @@ export class NotificationServer {
return acc
}, {} as Record<string, { text: string }>)
const results: NotifDisplay[] = []
const results: NotifView[] = []
for (const notif of notifs) {
const {
@ -310,7 +177,7 @@ export class NotificationServer {
return results
}
async findBlocksAndMutes(notifs: InsertableNotif[]) {
private async findBlocksAndMutes(notifs: NotifRow[]) {
const pairs = notifs.map((n) => ({ author: n.author, receiver: n.did }))
const { ref } = this.db.db.dynamic
const blockQb = this.db.db
@ -353,6 +220,155 @@ export class NotificationServer {
}
}
export class GorushNotificationServer extends NotificationServer<GorushNotification> {
private rateLimiter = new RateLimiter(1, 30 * MINUTE)
constructor(public db: Database, public pushEndpoint: string) {
super(db)
}
async prepareNotifications(
notifs: NotifRow[],
): Promise<GorushNotification[]> {
const now = Date.now()
const notifsToSend: GorushNotification[] = []
const tokensByDid = await this.getTokensByDid(
unique(notifs.map((n) => n.did)),
)
// views for all notifications that have tokens
const notificationViews = await this.getNotificationViews(
notifs.filter((n) => tokensByDid[n.did]),
)
for (const notifView of notificationViews) {
if (!isRecent(notifView.notif.sortAt, 10 * MINUTE)) {
continue // if the notif is from > 10 minutes ago, don't send push notif
}
const { did: userDid } = notifView.notif
const userTokens = tokensByDid[userDid] ?? []
for (const t of userTokens) {
const { appId, platform, token } = t
if (notifView.rateLimit && !this.rateLimiter.check(token, now)) {
continue
}
if (platform === 'ios' || platform === 'android') {
notifsToSend.push({
tokens: [token],
platform: platform === 'ios' ? 1 : 2,
title: notifView.title,
message: notifView.body,
topic: appId,
data: {
reason: notifView.notif.reason,
recordUri: notifView.notif.recordUri,
recordCid: notifView.notif.recordCid,
},
collapse_id: notifView.key,
collapse_key: notifView.key,
})
} else {
// @TODO: Handle web notifs
logger.warn({ did: userDid }, 'cannot send web notification to user')
}
}
}
return notifsToSend
}
async getTokensByDid(dids: string[]) {
if (!dids.length) return {}
const tokens = await this.db.db
.selectFrom('notification_push_token')
.where('did', 'in', dids)
.selectAll()
.execute()
return tokens.reduce((acc, token) => {
acc[token.did] ??= []
acc[token.did].push(token)
return acc
}, {} as Record<string, PushToken[]>)
}
async processNotifications(prepared: GorushNotification[]): Promise<void> {
for (const batch of chunkArray(prepared, 20)) {
try {
await this.sendToGorush(batch)
} catch (err) {
logger.error({ err, batch }, 'notification push batch failed')
}
}
}
private async sendToGorush(prepared: GorushNotification[]) {
// if no notifications, skip and return early
if (prepared.length === 0) {
return
}
const pushEndpoint = this.pushEndpoint
await retryHttp(() =>
axios.post(
pushEndpoint,
{ notifications: prepared },
{
headers: {
'content-type': 'application/json',
accept: 'application/json',
},
},
),
)
}
}
export class CourierNotificationServer extends NotificationServer<CourierNotification> {
constructor(public db: Database, public courierClient: CourierClient) {
super(db)
}
async prepareNotifications(
notifs: NotifRow[],
): Promise<CourierNotification[]> {
const notificationViews = await this.getNotificationViews(notifs)
const notifsToSend = notificationViews.map((n) => {
return new CourierNotification({
id: getCourierId(n),
recipientDid: n.notif.did,
title: n.title,
message: n.body,
collapseKey: n.key,
alwaysDeliver: !n.rateLimit,
timestamp: Timestamp.fromDate(new Date(n.notif.sortAt)),
additional: Struct.fromJson({
uri: n.notif.recordUri,
reason: n.notif.reason,
subject: n.notif.reasonSubject || '',
}),
})
})
return notifsToSend
}
async processNotifications(prepared: CourierNotification[]): Promise<void> {
try {
await retryConnect(() =>
this.courierClient.pushNotifications({ notifications: prepared }),
)
} catch (err) {
logger.error({ err }, 'notification push to courier failed')
}
}
}
const getCourierId = (notif: NotifView) => {
const key = [
notif.notif.recordUri,
notif.notif.did,
notif.notif.reason,
notif.notif.reasonSubject || '',
].join('::')
return murmur.v3(key).toString(16)
}
const isRecent = (isoTime: string, timeDiff: number): boolean => {
const diff = Date.now() - new Date(isoTime).getTime()
return diff < timeDiff

@ -0,0 +1,50 @@
// @generated by protoc-gen-connect-es v1.3.0 with parameter "target=ts,import_extension=.ts"
// @generated from file courier.proto (package courier, syntax proto3)
/* eslint-disable */
// @ts-nocheck
import {
PingRequest,
PingResponse,
PushNotificationsRequest,
PushNotificationsResponse,
RegisterDeviceTokenRequest,
RegisterDeviceTokenResponse,
} from './courier_pb.ts'
import { MethodKind } from '@bufbuild/protobuf'
/**
* @generated from service courier.Service
*/
export const Service = {
typeName: 'courier.Service',
methods: {
/**
* @generated from rpc courier.Service.Ping
*/
ping: {
name: 'Ping',
I: PingRequest,
O: PingResponse,
kind: MethodKind.Unary,
},
/**
* @generated from rpc courier.Service.PushNotifications
*/
pushNotifications: {
name: 'PushNotifications',
I: PushNotificationsRequest,
O: PushNotificationsResponse,
kind: MethodKind.Unary,
},
/**
* @generated from rpc courier.Service.RegisterDeviceToken
*/
registerDeviceToken: {
name: 'RegisterDeviceToken',
I: RegisterDeviceTokenRequest,
O: RegisterDeviceTokenResponse,
kind: MethodKind.Unary,
},
},
} as const

@ -0,0 +1,473 @@
// @generated by protoc-gen-es v1.6.0 with parameter "target=ts,import_extension=.ts"
// @generated from file courier.proto (package courier, syntax proto3)
/* eslint-disable */
// @ts-nocheck
import type {
BinaryReadOptions,
FieldList,
JsonReadOptions,
JsonValue,
PartialMessage,
PlainMessage,
} from '@bufbuild/protobuf'
import { Message, proto3, Struct, Timestamp } from '@bufbuild/protobuf'
/**
* @generated from enum courier.AppPlatform
*/
export enum AppPlatform {
/**
* @generated from enum value: APP_PLATFORM_UNSPECIFIED = 0;
*/
UNSPECIFIED = 0,
/**
* @generated from enum value: APP_PLATFORM_IOS = 1;
*/
IOS = 1,
/**
* @generated from enum value: APP_PLATFORM_ANDROID = 2;
*/
ANDROID = 2,
/**
* @generated from enum value: APP_PLATFORM_WEB = 3;
*/
WEB = 3,
}
// Retrieve enum metadata with: proto3.getEnumType(AppPlatform)
proto3.util.setEnumType(AppPlatform, 'courier.AppPlatform', [
{ no: 0, name: 'APP_PLATFORM_UNSPECIFIED' },
{ no: 1, name: 'APP_PLATFORM_IOS' },
{ no: 2, name: 'APP_PLATFORM_ANDROID' },
{ no: 3, name: 'APP_PLATFORM_WEB' },
])
/**
* Ping
*
* @generated from message courier.PingRequest
*/
export class PingRequest extends Message<PingRequest> {
constructor(data?: PartialMessage<PingRequest>) {
super()
proto3.util.initPartial(data, this)
}
static readonly runtime: typeof proto3 = proto3
static readonly typeName = 'courier.PingRequest'
static readonly fields: FieldList = proto3.util.newFieldList(() => [])
static fromBinary(
bytes: Uint8Array,
options?: Partial<BinaryReadOptions>,
): PingRequest {
return new PingRequest().fromBinary(bytes, options)
}
static fromJson(
jsonValue: JsonValue,
options?: Partial<JsonReadOptions>,
): PingRequest {
return new PingRequest().fromJson(jsonValue, options)
}
static fromJsonString(
jsonString: string,
options?: Partial<JsonReadOptions>,
): PingRequest {
return new PingRequest().fromJsonString(jsonString, options)
}
static equals(
a: PingRequest | PlainMessage<PingRequest> | undefined,
b: PingRequest | PlainMessage<PingRequest> | undefined,
): boolean {
return proto3.util.equals(PingRequest, a, b)
}
}
/**
* @generated from message courier.PingResponse
*/
export class PingResponse extends Message<PingResponse> {
constructor(data?: PartialMessage<PingResponse>) {
super()
proto3.util.initPartial(data, this)
}
static readonly runtime: typeof proto3 = proto3
static readonly typeName = 'courier.PingResponse'
static readonly fields: FieldList = proto3.util.newFieldList(() => [])
static fromBinary(
bytes: Uint8Array,
options?: Partial<BinaryReadOptions>,
): PingResponse {
return new PingResponse().fromBinary(bytes, options)
}
static fromJson(
jsonValue: JsonValue,
options?: Partial<JsonReadOptions>,
): PingResponse {
return new PingResponse().fromJson(jsonValue, options)
}
static fromJsonString(
jsonString: string,
options?: Partial<JsonReadOptions>,
): PingResponse {
return new PingResponse().fromJsonString(jsonString, options)
}
static equals(
a: PingResponse | PlainMessage<PingResponse> | undefined,
b: PingResponse | PlainMessage<PingResponse> | undefined,
): boolean {
return proto3.util.equals(PingResponse, a, b)
}
}
/**
* @generated from message courier.Notification
*/
export class Notification extends Message<Notification> {
/**
* @generated from field: string id = 1;
*/
id = ''
/**
* @generated from field: string recipient_did = 2;
*/
recipientDid = ''
/**
* @generated from field: string title = 3;
*/
title = ''
/**
* @generated from field: string message = 4;
*/
message = ''
/**
* @generated from field: string collapse_key = 5;
*/
collapseKey = ''
/**
* @generated from field: bool always_deliver = 6;
*/
alwaysDeliver = false
/**
* @generated from field: google.protobuf.Timestamp timestamp = 7;
*/
timestamp?: Timestamp
/**
* @generated from field: google.protobuf.Struct additional = 8;
*/
additional?: Struct
constructor(data?: PartialMessage<Notification>) {
super()
proto3.util.initPartial(data, this)
}
static readonly runtime: typeof proto3 = proto3
static readonly typeName = 'courier.Notification'
static readonly fields: FieldList = proto3.util.newFieldList(() => [
{ no: 1, name: 'id', kind: 'scalar', T: 9 /* ScalarType.STRING */ },
{
no: 2,
name: 'recipient_did',
kind: 'scalar',
T: 9 /* ScalarType.STRING */,
},
{ no: 3, name: 'title', kind: 'scalar', T: 9 /* ScalarType.STRING */ },
{ no: 4, name: 'message', kind: 'scalar', T: 9 /* ScalarType.STRING */ },
{
no: 5,
name: 'collapse_key',
kind: 'scalar',
T: 9 /* ScalarType.STRING */,
},
{
no: 6,
name: 'always_deliver',
kind: 'scalar',
T: 8 /* ScalarType.BOOL */,
},
{ no: 7, name: 'timestamp', kind: 'message', T: Timestamp },
{ no: 8, name: 'additional', kind: 'message', T: Struct },
])
static fromBinary(
bytes: Uint8Array,
options?: Partial<BinaryReadOptions>,
): Notification {
return new Notification().fromBinary(bytes, options)
}
static fromJson(
jsonValue: JsonValue,
options?: Partial<JsonReadOptions>,
): Notification {
return new Notification().fromJson(jsonValue, options)
}
static fromJsonString(
jsonString: string,
options?: Partial<JsonReadOptions>,
): Notification {
return new Notification().fromJsonString(jsonString, options)
}
static equals(
a: Notification | PlainMessage<Notification> | undefined,
b: Notification | PlainMessage<Notification> | undefined,
): boolean {
return proto3.util.equals(Notification, a, b)
}
}
/**
* @generated from message courier.PushNotificationsRequest
*/
export class PushNotificationsRequest extends Message<PushNotificationsRequest> {
/**
* @generated from field: repeated courier.Notification notifications = 1;
*/
notifications: Notification[] = []
constructor(data?: PartialMessage<PushNotificationsRequest>) {
super()
proto3.util.initPartial(data, this)
}
static readonly runtime: typeof proto3 = proto3
static readonly typeName = 'courier.PushNotificationsRequest'
static readonly fields: FieldList = proto3.util.newFieldList(() => [
{
no: 1,
name: 'notifications',
kind: 'message',
T: Notification,
repeated: true,
},
])
static fromBinary(
bytes: Uint8Array,
options?: Partial<BinaryReadOptions>,
): PushNotificationsRequest {
return new PushNotificationsRequest().fromBinary(bytes, options)
}
static fromJson(
jsonValue: JsonValue,
options?: Partial<JsonReadOptions>,
): PushNotificationsRequest {
return new PushNotificationsRequest().fromJson(jsonValue, options)
}
static fromJsonString(
jsonString: string,
options?: Partial<JsonReadOptions>,
): PushNotificationsRequest {
return new PushNotificationsRequest().fromJsonString(jsonString, options)
}
static equals(
a:
| PushNotificationsRequest
| PlainMessage<PushNotificationsRequest>
| undefined,
b:
| PushNotificationsRequest
| PlainMessage<PushNotificationsRequest>
| undefined,
): boolean {
return proto3.util.equals(PushNotificationsRequest, a, b)
}
}
/**
* @generated from message courier.PushNotificationsResponse
*/
export class PushNotificationsResponse extends Message<PushNotificationsResponse> {
constructor(data?: PartialMessage<PushNotificationsResponse>) {
super()
proto3.util.initPartial(data, this)
}
static readonly runtime: typeof proto3 = proto3
static readonly typeName = 'courier.PushNotificationsResponse'
static readonly fields: FieldList = proto3.util.newFieldList(() => [])
static fromBinary(
bytes: Uint8Array,
options?: Partial<BinaryReadOptions>,
): PushNotificationsResponse {
return new PushNotificationsResponse().fromBinary(bytes, options)
}
static fromJson(
jsonValue: JsonValue,
options?: Partial<JsonReadOptions>,
): PushNotificationsResponse {
return new PushNotificationsResponse().fromJson(jsonValue, options)
}
static fromJsonString(
jsonString: string,
options?: Partial<JsonReadOptions>,
): PushNotificationsResponse {
return new PushNotificationsResponse().fromJsonString(jsonString, options)
}
static equals(
a:
| PushNotificationsResponse
| PlainMessage<PushNotificationsResponse>
| undefined,
b:
| PushNotificationsResponse
| PlainMessage<PushNotificationsResponse>
| undefined,
): boolean {
return proto3.util.equals(PushNotificationsResponse, a, b)
}
}
/**
* @generated from message courier.RegisterDeviceTokenRequest
*/
export class RegisterDeviceTokenRequest extends Message<RegisterDeviceTokenRequest> {
/**
* @generated from field: string did = 1;
*/
did = ''
/**
* @generated from field: string token = 2;
*/
token = ''
/**
* @generated from field: string app_id = 3;
*/
appId = ''
/**
* @generated from field: courier.AppPlatform platform = 4;
*/
platform = AppPlatform.UNSPECIFIED
constructor(data?: PartialMessage<RegisterDeviceTokenRequest>) {
super()
proto3.util.initPartial(data, this)
}
static readonly runtime: typeof proto3 = proto3
static readonly typeName = 'courier.RegisterDeviceTokenRequest'
static readonly fields: FieldList = proto3.util.newFieldList(() => [
{ no: 1, name: 'did', kind: 'scalar', T: 9 /* ScalarType.STRING */ },
{ no: 2, name: 'token', kind: 'scalar', T: 9 /* ScalarType.STRING */ },
{ no: 3, name: 'app_id', kind: 'scalar', T: 9 /* ScalarType.STRING */ },
{
no: 4,
name: 'platform',
kind: 'enum',
T: proto3.getEnumType(AppPlatform),
},
])
static fromBinary(
bytes: Uint8Array,
options?: Partial<BinaryReadOptions>,
): RegisterDeviceTokenRequest {
return new RegisterDeviceTokenRequest().fromBinary(bytes, options)
}
static fromJson(
jsonValue: JsonValue,
options?: Partial<JsonReadOptions>,
): RegisterDeviceTokenRequest {
return new RegisterDeviceTokenRequest().fromJson(jsonValue, options)
}
static fromJsonString(
jsonString: string,
options?: Partial<JsonReadOptions>,
): RegisterDeviceTokenRequest {
return new RegisterDeviceTokenRequest().fromJsonString(jsonString, options)
}
static equals(
a:
| RegisterDeviceTokenRequest
| PlainMessage<RegisterDeviceTokenRequest>
| undefined,
b:
| RegisterDeviceTokenRequest
| PlainMessage<RegisterDeviceTokenRequest>
| undefined,
): boolean {
return proto3.util.equals(RegisterDeviceTokenRequest, a, b)
}
}
/**
* @generated from message courier.RegisterDeviceTokenResponse
*/
export class RegisterDeviceTokenResponse extends Message<RegisterDeviceTokenResponse> {
constructor(data?: PartialMessage<RegisterDeviceTokenResponse>) {
super()
proto3.util.initPartial(data, this)
}
static readonly runtime: typeof proto3 = proto3
static readonly typeName = 'courier.RegisterDeviceTokenResponse'
static readonly fields: FieldList = proto3.util.newFieldList(() => [])
static fromBinary(
bytes: Uint8Array,
options?: Partial<BinaryReadOptions>,
): RegisterDeviceTokenResponse {
return new RegisterDeviceTokenResponse().fromBinary(bytes, options)
}
static fromJson(
jsonValue: JsonValue,
options?: Partial<JsonReadOptions>,
): RegisterDeviceTokenResponse {
return new RegisterDeviceTokenResponse().fromJson(jsonValue, options)
}
static fromJsonString(
jsonString: string,
options?: Partial<JsonReadOptions>,
): RegisterDeviceTokenResponse {
return new RegisterDeviceTokenResponse().fromJsonString(jsonString, options)
}
static equals(
a:
| RegisterDeviceTokenResponse
| PlainMessage<RegisterDeviceTokenResponse>
| undefined,
b:
| RegisterDeviceTokenResponse
| PlainMessage<RegisterDeviceTokenResponse>
| undefined,
): boolean {
return proto3.util.equals(RegisterDeviceTokenResponse, a, b)
}
}

@ -12,6 +12,7 @@ import { GraphService } from '../graph'
import { LabelService } from '../label'
import { AtUri } from '@atproto/syntax'
import { ids } from '../../lexicon/lexicons'
import { Platform } from '../../notifications'
export * from './types'
@ -21,8 +22,8 @@ export class ActorService {
constructor(
public db: Database,
public imgUriBuilder: ImageUriBuilder,
private graph: FromDb<GraphService>,
private label: FromDb<LabelService>,
graph: FromDb<GraphService>,
label: FromDb<LabelService>,
) {
this.views = new ActorViews(this.db, this.imgUriBuilder, graph, label)
}
@ -214,6 +215,20 @@ export class ActorService {
}
}
}
async registerPushDeviceToken(
did: string,
token: string,
platform: Platform,
appId: string,
) {
await this.db
.asPrimary()
.db.insertInto('notification_push_token')
.values({ did, token, platform, appId })
.onConflict((oc) => oc.doNothing())
.execute()
}
}
type ActorResult = Actor

@ -257,7 +257,7 @@ export class RecordProcessor<T, S> {
const notifServer = this.notifServer
sendOnCommit.push(async () => {
try {
const preparedNotifs = await notifServer.prepareNotifsToSend(chunk)
const preparedNotifs = await notifServer.prepareNotifications(chunk)
await notifServer.processNotifications(preparedNotifs)
} catch (error) {
dbLogger.error({ error }, 'error sending push notifications')

@ -1,6 +1,7 @@
import { AxiosError } from 'axios'
import { XRPCError, ResponseType } from '@atproto/xrpc'
import { RetryOptions, retry } from '@atproto/common'
import { Code, ConnectError } from '@connectrpc/connect'
export async function retryHttp<T>(
fn: () => Promise<T>,
@ -24,3 +25,14 @@ export function retryableHttp(err: unknown) {
const retryableHttpStatusCodes = new Set([
408, 425, 429, 500, 502, 503, 504, 522, 524,
])
export async function retryConnect<T>(
fn: () => Promise<T>,
opts: RetryOptions = {},
): Promise<T> {
return retry(fn, { retryable: retryableConnect, ...opts })
}
export function retryableConnect(err: unknown) {
return err instanceof ConnectError && err.code === Code.Unavailable
}

@ -1,14 +1,18 @@
import AtpAgent, { AtUri } from '@atproto/api'
import { TestNetwork, SeedClient, basicSeed } from '@atproto/dev-env'
import { NotificationServer } from '../src/notifications'
import {
CourierNotificationServer,
GorushNotificationServer,
} from '../src/notifications'
import { Database } from '../src'
import { createCourierClient } from '../src/courier'
describe('notification server', () => {
let network: TestNetwork
let agent: AtpAgent
let pdsAgent: AtpAgent
let sc: SeedClient
let notifServer: NotificationServer
let notifServer: GorushNotificationServer
// account dids, for convenience
let alice: string
@ -24,14 +28,17 @@ describe('notification server', () => {
await network.processAll()
await network.bsky.processAll()
alice = sc.dids.alice
notifServer = network.bsky.ctx.notifServer
notifServer = new GorushNotificationServer(
network.bsky.ctx.db.getPrimary(),
'http://mock',
)
})
afterAll(async () => {
await network.close()
})
describe('registerPushNotification', () => {
describe('registerPush', () => {
it('registers push notification token and device.', async () => {
const res = await agent.api.app.bsky.notification.registerPush(
{
@ -95,19 +102,14 @@ describe('notification server', () => {
})
describe('NotificationServer', () => {
it('gets user tokens from db', async () => {
const tokens = await notifServer.getTokensByDid([alice])
expect(tokens[alice][0].token).toEqual('123')
})
it('gets notification display attributes: title and body', async () => {
const db = network.bsky.ctx.db.getPrimary()
const notif = await getLikeNotification(db, alice)
if (!notif) throw new Error('no notification found')
const attrs = await notifServer.getNotificationDisplayAttributes([notif])
if (!attrs.length)
const views = await notifServer.getNotificationViews([notif])
if (!views.length)
throw new Error('no notification display attributes found')
expect(attrs[0].title).toEqual('bobby liked your post')
expect(views[0].title).toEqual('bobby liked your post')
})
it('filters notifications that violate blocks', async () => {
@ -126,11 +128,11 @@ describe('notification server', () => {
did: notif.author,
author: notif.did,
}
const attrs = await notifServer.getNotificationDisplayAttributes([
const views = await notifServer.getNotificationViews([
notif,
flippedNotif,
])
expect(attrs.length).toBe(0)
expect(views.length).toBe(0)
const uri = new AtUri(blockRef.uri)
await pdsAgent.api.app.bsky.graph.block.delete(
{ repo: alice, rkey: uri.rkey },
@ -147,8 +149,8 @@ describe('notification server', () => {
{ actor: notif.author },
{ headers: sc.getHeaders(alice), encoding: 'application/json' },
)
const attrs = await notifServer.getNotificationDisplayAttributes([notif])
expect(attrs.length).toBe(0)
const views = await notifServer.getNotificationViews([notif])
expect(views.length).toBe(0)
await pdsAgent.api.app.bsky.graph.unmuteActor(
{ actor: notif.author },
{ headers: sc.getHeaders(alice), encoding: 'application/json' },
@ -182,13 +184,20 @@ describe('notification server', () => {
{ list: listRef.uri },
{ headers: sc.getHeaders(alice), encoding: 'application/json' },
)
const attrs = await notifServer.getNotificationDisplayAttributes([notif])
expect(attrs.length).toBe(0)
const views = await notifServer.getNotificationViews([notif])
expect(views.length).toBe(0)
await pdsAgent.api.app.bsky.graph.unmuteActorList(
{ list: listRef.uri },
{ headers: sc.getHeaders(alice), encoding: 'application/json' },
)
})
})
describe('GorushNotificationServer', () => {
it('gets user tokens from db', async () => {
const tokens = await notifServer.getTokensByDid([alice])
expect(tokens[alice][0].token).toEqual('123')
})
it('prepares notification to be sent', async () => {
const db = network.bsky.ctx.db.getPrimary()
@ -198,7 +207,7 @@ describe('notification server', () => {
notif,
notif /* second one will get dropped by rate limit */,
]
const prepared = await notifServer.prepareNotifsToSend(notifAsArray)
const prepared = await notifServer.prepareNotifications(notifAsArray)
expect(prepared).toEqual([
{
collapse_id: 'like',
@ -218,6 +227,37 @@ describe('notification server', () => {
})
})
describe('CourierNotificationServer', () => {
it('prepares notification to be sent', async () => {
const db = network.bsky.ctx.db.getPrimary()
const notif = await getLikeNotification(db, alice)
if (!notif) throw new Error('no notification found')
const courierNotifServer = new CourierNotificationServer(
db,
createCourierClient({ baseUrl: 'http://mock', httpVersion: '2' }),
)
const prepared = await courierNotifServer.prepareNotifications([notif])
expect(prepared[0]?.id).toBeTruthy()
expect(prepared.map((p) => p.toJson())).toEqual([
{
id: prepared[0].id, // already ensured it exists
recipientDid: notif.did,
title: 'bobby liked your post',
message: 'again',
collapseKey: 'like',
timestamp: notif.sortAt,
// this is missing, appears to be a quirk of toJson()
// alwaysDeliver: false,
additional: {
reason: notif.reason,
uri: notif.recordUri,
subject: notif.reasonSubject,
},
},
])
})
})
async function getLikeNotification(db: Database, did: string) {
return await db.db
.selectFrom('notification')

7
pnpm-lock.yaml generated

@ -234,6 +234,9 @@ importers:
multiformats:
specifier: ^9.9.0
version: 9.9.0
murmurhash:
specifier: ^2.0.1
version: 2.0.1
p-queue:
specifier: ^6.6.2
version: 6.6.2
@ -9790,6 +9793,10 @@ packages:
/multiformats@9.9.0:
resolution: {integrity: sha512-HoMUjhH9T8DDBNT+6xzkrd9ga/XiBI4xLr58LJACwK6G3HTOPeMz4nB4KJs33L2BelrIJa7P0VuNaVF3hMYfjg==}
/murmurhash@2.0.1:
resolution: {integrity: sha512-5vQEh3y+DG/lMPM0mCGPDnyV8chYg/g7rl6v3Gd8WMF9S429ox3Xk8qrk174kWhG767KQMqqxLD1WnGd77hiew==}
dev: false
/napi-build-utils@1.0.2:
resolution: {integrity: sha512-ONmRUqK7zj7DWX0D9ADe03wbwOBZxNAfF20PlGfCWQcD3+/MakShIHrMqx9YwPTfxDdF1zLeL+RGZiR9kGMLdg==}