Streaming: decouple producer side more, refactor db class ()

* Begin to decouple message queue consumers from the queue itself

* Tidy

* Reorganize pds message queue code out of db

* Decouple message queue, repo, and actor functionality from db instance w/ services

* Move repo processing into repo service

* Tidy

* Move repo blobs functionality into service

* Tidy

* Ensure to close message queue in all pds tests

* Fix typo

* Force specifying a db when using a service

* Reorg pds record plugins into record service

* Rename pds stream/ to event-stream/

* Tidy and fixes
This commit is contained in:
devin ivy 2022-12-15 00:00:14 -05:00 committed by GitHub
parent 91e7828021
commit ed9556f049
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
59 changed files with 942 additions and 853 deletions

@ -6,15 +6,15 @@ import * as handleLib from '@atproto/handle'
import * as locals from '../../../../locals'
import * as lex from '../../../../lexicon/lexicons'
import { TID } from '@atproto/common'
import { UserAlreadyExistsError } from '../../../../db'
import * as repo from '../../../../repo'
import ServerAuth from '../../../../auth'
import { UserAlreadyExistsError } from '../../../../services/actor'
export default function (server: Server) {
server.app.bsky.actor.createScene({
auth: ServerAuth.verifier,
handler: async ({ auth, input, res }) => {
const { db, config, keypair, logger } = locals.get(res)
const { db, services, config, keypair, logger } = locals.get(res)
const { recoveryKey } = input.body
const requester = auth.credentials.did
@ -40,9 +40,11 @@ export default function (server: Server) {
const now = new Date().toISOString()
const result = await db.transaction(async (dbTxn) => {
const actorTxn = services.actor(dbTxn)
const repoTxn = services.repo(dbTxn)
// Pre-register before going out to PLC to get a real did
try {
await dbTxn.preregisterDid(handle, tempDid)
await actorTxn.preregisterDid(handle, tempDid)
} catch (err) {
if (err instanceof UserAlreadyExistsError) {
throw new InvalidRequestError(
@ -77,7 +79,7 @@ export default function (server: Server) {
$type: lex.ids.AppBskySystemDeclaration,
actorType: APP_BSKY_SYSTEM.ActorScene,
}
await dbTxn.finalizeDid(handle, did, tempDid, declaration)
await actorTxn.finalizeDid(handle, did, tempDid, declaration)
await dbTxn.db
.insertInto('scene')
.values({ handle, owner: requester, createdAt: now })
@ -169,9 +171,9 @@ export default function (server: Server) {
])
await Promise.all([
repo.createRepo(dbTxn, did, sceneAuth, sceneWrites, now),
repo.writeToRepo(dbTxn, requester, userAuth, userWrites, now),
repo.indexWrites(dbTxn, [...sceneWrites, ...userWrites], now),
repoTxn.createRepo(did, sceneAuth, sceneWrites, now),
repoTxn.writeToRepo(requester, userAuth, userWrites, now),
repoTxn.indexWrites([...sceneWrites, ...userWrites], now),
])
return {

@ -4,7 +4,6 @@ import { countAll, actorWhereClause } from '../../../../db/util'
import * as locals from '../../../../locals'
import { getDeclarationSimple } from '../util'
import ServerAuth from '../../../../auth'
import { CID } from 'multiformats/cid'
export default function (server: Server) {
server.app.bsky.actor.getProfile({

@ -15,11 +15,13 @@ export default function (server: Server) {
server.app.bsky.actor.updateProfile({
auth: ServerAuth.verifier,
handler: async ({ auth, input, res }) => {
const { db, blobstore } = locals.get(res)
const { db, services } = locals.get(res)
const requester = auth.credentials.did
const did = input.body.did || requester
const authorized = await db.isUserControlledRepo(did, requester)
const authorized = await services
.repo(db)
.isUserControlledRepo(did, requester)
if (!authorized) {
throw new AuthRequiredError()
}
@ -30,13 +32,15 @@ export default function (server: Server) {
async (
dbTxn,
): Promise<{ profileCid: CID; updated: Profile.Record }> => {
const recordTxn = services.record(dbTxn)
const repoTxn = services.repo(dbTxn)
const now = new Date().toISOString()
let updated
const uri = AtUri.make(did, profileNsid, 'self')
const current = (await dbTxn.getRecord(uri, null))?.value
const current = (await recordTxn.getRecord(uri, null))?.value
if (current) {
if (!db.records.profile.matchesSchema(current)) {
if (!recordTxn.records.profile.matchesSchema(current)) {
// @TODO need a way to get a profile out of a broken state
throw new InvalidRequestError('could not parse current profile')
}
@ -58,7 +62,7 @@ export default function (server: Server) {
}
}
updated = common.noUndefinedVals(updated)
if (!db.records.profile.matchesSchema(updated)) {
if (!recordTxn.records.profile.matchesSchema(updated)) {
throw new InvalidRequestError(
'requested updates do not produce a valid profile doc',
)
@ -71,14 +75,8 @@ export default function (server: Server) {
value: updated,
})
const commit = await repo.writeToRepo(
dbTxn,
did,
authStore,
writes,
now,
)
await repo.processWriteBlobs(dbTxn, blobstore, did, commit, writes)
const commit = await repoTxn.writeToRepo(did, authStore, writes, now)
await repoTxn.blobs.processWriteBlobs(did, commit, writes)
const write = writes[0]
let profileCid: CID
@ -106,7 +104,7 @@ export default function (server: Server) {
.execute()
} else if (write.action === 'create') {
profileCid = write.cid
await dbTxn.indexRecord(uri, profileCid, updated, now)
await recordTxn.indexRecord(uri, profileCid, updated, now)
} else {
// should never hit this
throw new Error(

@ -12,13 +12,14 @@ export default function (server: Server) {
auth: ServerAuth.verifier,
handler: async ({ auth, input, res }) => {
const { subject, direction } = input.body
const { db } = locals.get(res)
const { db, services } = locals.get(res)
const requester = auth.credentials.did
const authStore = await locals.getAuthstore(res, requester)
const now = new Date().toISOString()
const voteUri = await db.transaction(async (dbTxn) => {
const repoTxn = services.repo(dbTxn)
const existingVotes = await dbTxn.db
.selectFrom('vote')
.select(['uri', 'direction'])
@ -67,8 +68,8 @@ export default function (server: Server) {
}
await Promise.all([
await repo.writeToRepo(dbTxn, requester, authStore, writes, now),
await repo.indexWrites(dbTxn, writes, now),
await repoTxn.writeToRepo(requester, authStore, writes, now),
await repoTxn.indexWrites(writes, now),
])
return create?.uri.toString()

@ -8,7 +8,7 @@ export default function (server: Server) {
auth: ServerAuth.verifier,
handler: async ({ input, auth, res }) => {
const { seenAt } = input.body
const { db } = locals.get(res)
const { db, services } = locals.get(res)
const requester = auth.credentials.did
let parsed: string
@ -18,7 +18,7 @@ export default function (server: Server) {
throw new InvalidRequestError('Invalid date')
}
const user = await db.getUser(requester)
const user = await services.actor(db).getUser(requester)
if (!user) {
throw new InvalidRequestError(`Could not find user: ${requester}`)
}

@ -2,14 +2,14 @@ import { InvalidRequestError } from '@atproto/xrpc-server'
import { PlcClient } from '@atproto/plc'
import * as crypto from '@atproto/crypto'
import * as handleLib from '@atproto/handle'
import { cidForData } from '@atproto/common'
import { Server, APP_BSKY_SYSTEM } from '../../../lexicon'
import * as locals from '../../../locals'
import { countAll } from '../../../db/util'
import { UserAlreadyExistsError } from '../../../db'
import { grantRefreshToken } from './util/auth'
import * as lex from '../../../lexicon/lexicons'
import * as repo from '../../../repo'
import { cidForData } from '@atproto/common'
import { UserAlreadyExistsError } from '../../../services/actor'
export default function (server: Server) {
server.com.atproto.server.getAccountsConfig(({ res }) => {
@ -36,7 +36,7 @@ export default function (server: Server) {
server.com.atproto.account.create(async ({ input, res }) => {
const { email, password, inviteCode, recoveryKey } = input.body
const { db, auth, config, keypair, logger } = locals.get(res)
const { db, services, auth, config, keypair, logger } = locals.get(res)
let handle: string
try {
@ -60,6 +60,8 @@ export default function (server: Server) {
const now = new Date().toISOString()
const result = await db.transaction(async (dbTxn) => {
const actorTxn = services.actor(dbTxn)
const repoTxn = services.repo(dbTxn)
if (config.inviteRequired) {
if (!inviteCode) {
throw new InvalidRequestError(
@ -93,7 +95,7 @@ export default function (server: Server) {
// Pre-register user before going out to PLC to get a real did
try {
await dbTxn.preregisterDid(handle, tempDid)
await actorTxn.preregisterDid(handle, tempDid)
} catch (err) {
if (err instanceof UserAlreadyExistsError) {
throw new InvalidRequestError(`Handle already taken: ${handle}`)
@ -101,7 +103,7 @@ export default function (server: Server) {
throw err
}
try {
await dbTxn.registerUser(email, handle, password)
await actorTxn.registerUser(email, handle, password)
} catch (err) {
if (err instanceof UserAlreadyExistsError) {
throw new InvalidRequestError(`Email already taken: ${email}`)
@ -133,7 +135,7 @@ export default function (server: Server) {
$type: lex.ids.AppBskySystemDeclaration,
actorType: APP_BSKY_SYSTEM.ActorUser,
}
await dbTxn.finalizeDid(handle, did, tempDid, declaration)
await actorTxn.finalizeDid(handle, did, tempDid, declaration)
if (config.inviteRequired && inviteCode) {
await dbTxn.db
.insertInto('invite_code_use')
@ -154,8 +156,8 @@ export default function (server: Server) {
// Setup repo root
const authStore = locals.getAuthstore(res, did)
await repo.createRepo(dbTxn, did, authStore, [write], now)
await repo.indexWrites(dbTxn, [write], now)
await repoTxn.createRepo(did, authStore, [write], now)
await repoTxn.indexWrites([write], now)
const declarationCid = await cidForData(declaration)
const access = auth.createAccessToken(did)

@ -1,20 +1,16 @@
import { Server } from '../../../lexicon'
import * as locals from '../../../locals'
import * as repo from '../../../repo'
import ServerAuth from '../../../auth'
export default function (server: Server) {
server.com.atproto.blob.upload({
auth: ServerAuth.verifier,
handler: async ({ input, res }) => {
const { db, blobstore } = locals.get(res)
const { db, services } = locals.get(res)
const cid = await repo.addUntetheredBlob(
db,
blobstore,
input.encoding,
input.body,
)
const cid = await services
.repo(db)
.blobs.addUntetheredBlob(input.encoding, input.body)
return {
encoding: 'application/json',

@ -4,7 +4,7 @@ import * as locals from '../../../locals'
export default function (server: Server) {
server.com.atproto.handle.resolve(async ({ params, res }) => {
const { db, config } = locals.get(res)
const { db, services, config } = locals.get(res)
const handle = params.handle
@ -19,7 +19,7 @@ export default function (server: Server) {
if (!supportedHandle) {
throw new InvalidRequestError('Not a supported handle domain')
}
const user = await db.getUser(handle)
const user = await services.actor(db).getUser(handle)
if (!user) {
throw new InvalidRequestError('Unable to resolve handle')
}

@ -7,10 +7,10 @@ import * as locals from '../../../locals'
export default function (server: Server) {
server.com.atproto.account.requestPasswordReset(async ({ input, res }) => {
const { db, mailer, config } = locals.get(res)
const { db, services, mailer, config } = locals.get(res)
const email = input.body.email.toLowerCase()
const user = await db.getUserByEmail(email)
const user = await services.actor(db).getUserByEmail(email)
if (user) {
// By signing with the password hash, this jwt becomes invalid once the user changes their password.
@ -29,7 +29,7 @@ export default function (server: Server) {
})
server.com.atproto.account.resetPassword(async ({ input, res }) => {
const { db, config } = locals.get(res)
const { db, services, config } = locals.get(res)
const { token, password } = input.body
const tokenBody = jwt.decode(token)
@ -42,7 +42,7 @@ export default function (server: Server) {
return createInvalidTokenError('Malformed token')
}
const user = await db.getUser(did)
const user = await services.actor(db).getUser(did)
if (!user) {
return createInvalidTokenError('Token could not be verified')
}
@ -59,7 +59,7 @@ export default function (server: Server) {
// Token had correct scope, was not expired, and referenced
// a user whose password has not changed since token issuance.
await db.updateUserPassword(user.handle, password)
await services.actor(db).updateUserPassword(user.handle, password)
})
}

@ -14,8 +14,8 @@ export default function (server: Server) {
server.com.atproto.repo.describe(async ({ params, res }) => {
const { user } = params
const { db, auth } = locals.get(res)
const userObj = await db.getUser(user)
const { db, auth, services } = locals.get(res)
const userObj = await services.actor(db).getUser(user)
if (userObj === null) {
throw new InvalidRequestError(`Could not find user: ${user}`)
}
@ -30,7 +30,9 @@ export default function (server: Server) {
const handle = didResolver.getHandle(didDoc)
const handleIsCorrect = handle === userObj.handle
const collections = await db.listCollectionsForDid(userObj.did)
const collections = await services
.record(db)
.listCollectionsForDid(userObj.did)
return {
encoding: 'application/json',
@ -47,20 +49,22 @@ export default function (server: Server) {
server.com.atproto.repo.listRecords(async ({ params, res }) => {
const { user, collection, limit, before, after, reverse } = params
const db = locals.db(res)
const did = await db.getDidForActor(user)
const { db, services } = locals.get(res)
const did = await services.actor(db).getDidForActor(user)
if (!did) {
throw new InvalidRequestError(`Could not find user: ${user}`)
}
const records = await db.listRecordsForCollection(
did,
collection,
limit || 50,
reverse || false,
before,
after,
)
const records = await services
.record(db)
.listRecordsForCollection(
did,
collection,
limit || 50,
reverse || false,
before,
after,
)
const lastRecord = records.at(-1)
const lastUri = lastRecord && new AtUri(lastRecord?.uri)
@ -77,16 +81,16 @@ export default function (server: Server) {
server.com.atproto.repo.getRecord(async ({ params, res }) => {
const { user, collection, rkey, cid } = params
const db = locals.db(res)
const { db, services } = locals.get(res)
const did = await db.getDidForActor(user)
const did = await services.actor(db).getDidForActor(user)
if (!did) {
throw new InvalidRequestError(`Could not find user: ${user}`)
}
const uri = new AtUri(`${did}/${collection}/${rkey}`)
const record = await db.getRecord(uri, cid || null)
const record = await services.record(db).getRecord(uri, cid || null)
if (!record) {
throw new InvalidRequestError(`Could not locate record: ${uri}`)
}
@ -101,9 +105,11 @@ export default function (server: Server) {
handler: async ({ input, auth, res }) => {
const tx = input.body
const { did, validate } = tx
const { db, blobstore } = locals.get(res)
const { db, services } = locals.get(res)
const requester = auth.credentials.did
const authorized = await db.isUserControlledRepo(did, requester)
const authorized = await services
.repo(db)
.isUserControlledRepo(did, requester)
if (!authorized) {
throw new AuthRequiredError()
}
@ -117,7 +123,9 @@ export default function (server: Server) {
for (const write of tx.writes) {
if (write.action === 'create' || write.action === 'update') {
try {
db.assertValidRecord(write.collection, write.value)
services
.record(db)
.assertValidRecord(write.collection, write.value)
} catch (e) {
throw new InvalidRequestError(
`Invalid ${write.collection} record: ${
@ -149,7 +157,8 @@ export default function (server: Server) {
await db.transaction(async (dbTxn) => {
const now = new Date().toISOString()
await repo.processWrites(dbTxn, did, authStore, blobstore, writes, now)
const repoTxn = services.repo(dbTxn)
await repoTxn.processWrites(did, authStore, writes, now)
})
},
})
@ -160,16 +169,18 @@ export default function (server: Server) {
const { did, collection, record } = input.body
const validate =
typeof input.body.validate === 'boolean' ? input.body.validate : true
const { db, blobstore } = locals.get(res)
const { db, services } = locals.get(res)
const requester = auth.credentials.did
const authorized = await db.isUserControlledRepo(did, requester)
const authorized = await services
.repo(db)
.isUserControlledRepo(did, requester)
if (!authorized) {
throw new AuthRequiredError()
}
if (validate) {
try {
db.assertValidRecord(collection, record)
services.record(db).assertValidRecord(collection, record)
} catch (e) {
throw new InvalidRequestError(
`Invalid ${collection} record: ${
@ -201,7 +212,8 @@ export default function (server: Server) {
})
await db.transaction(async (dbTxn) => {
await repo.processWrites(dbTxn, did, authStore, blobstore, [write], now)
const repoTxn = services.repo(dbTxn)
await repoTxn.processWrites(did, authStore, [write], now)
})
return {
@ -219,9 +231,11 @@ export default function (server: Server) {
auth: ServerAuth.verifier,
handler: async ({ input, auth, res }) => {
const { did, collection, rkey } = input.body
const { db, blobstore } = locals.get(res)
const { db, services } = locals.get(res)
const requester = auth.credentials.did
const authorized = await db.isUserControlledRepo(did, requester)
const authorized = await services
.repo(db)
.isUserControlledRepo(did, requester)
if (!authorized) {
throw new AuthRequiredError()
}
@ -236,7 +250,8 @@ export default function (server: Server) {
})
await db.transaction(async (dbTxn) => {
await repo.processWrites(dbTxn, did, authStore, blobstore, write, now)
const repoTxn = services.repo(dbTxn)
await repoTxn.processWrites(did, authStore, write, now)
})
},
})

@ -8,9 +8,9 @@ export default function (server: Server) {
server.com.atproto.session.get({
auth: ServerAuth.verifier,
handler: async ({ auth, res }) => {
const { db } = locals.get(res)
const { db, services } = locals.get(res)
const did = auth.credentials.did
const user = await db.getUser(did)
const user = await services.actor(db).getUser(did)
if (!user) {
throw new InvalidRequestError(
`Could not find user info for account: ${did}`,
@ -26,13 +26,15 @@ export default function (server: Server) {
server.com.atproto.session.create(async ({ input, res }) => {
const { password } = input.body
const handle = input.body.handle.toLowerCase()
const { db, auth } = locals.get(res)
const validPass = await db.verifyUserPassword(handle, password)
const { db, services, auth } = locals.get(res)
const validPass = await services
.actor(db)
.verifyUserPassword(handle, password)
if (!validPass) {
throw new AuthRequiredError('Invalid handle or password')
}
const user = await db.getUser(handle)
const user = await services.actor(db).getUser(handle)
if (!user) {
throw new InvalidRequestError(
`Could not find user info for account: ${handle}`,
@ -57,9 +59,9 @@ export default function (server: Server) {
server.com.atproto.session.refresh({
auth: ServerAuth.refreshVerifier,
handler: async ({ req, res, ...ctx }) => {
const { db, auth } = locals.get(res)
const { db, services, auth } = locals.get(res)
const did = ctx.auth.credentials.did
const user = await db.getUser(did)
const user = await services.actor(db).getUser(did)
if (!user) {
throw new InvalidRequestError(
`Could not find user info for account: ${did}`,

@ -8,8 +8,8 @@ import SqlBlockstore from '../../../sql-blockstore'
export default function (server: Server) {
server.com.atproto.sync.getRoot(async ({ params, res }) => {
const { did } = params
const db = locals.db(res)
const root = await db.getRepoRoot(did)
const { db, services } = locals.get(res)
const root = await services.repo(db).getRepoRoot(did)
if (root === null) {
throw new InvalidRequestError(`Could not find root for DID: ${did}`)
}
@ -21,8 +21,8 @@ export default function (server: Server) {
server.com.atproto.sync.getRepo(async ({ params, res }) => {
const { did, from = null } = params
const { db } = locals.get(res)
const repoRoot = await db.getRepoRoot(did)
const { db, services } = locals.get(res)
const repoRoot = await services.repo(db).getRepoRoot(did)
if (repoRoot === null) {
throw new InvalidRequestError(`Could not find repo for DID: ${did}`)
}

@ -21,10 +21,10 @@ import * as trend from './tables/trend'
import * as follow from './tables/follow'
import * as blob from './tables/blob'
import * as repoBlob from './tables/repo-blob'
import * as messageQueue from './message-queue/tables/message-queue'
import * as messageQueueCursor from './message-queue/tables/message-queue-cursor'
import * as sceneMemberCount from './message-queue/tables/scene-member-count'
import * as sceneVotesOnPost from './message-queue/tables/scene-votes-on-post'
import * as messageQueue from './tables/message-queue'
import * as messageQueueCursor from './tables/message-queue-cursor'
import * as sceneMemberCount from './tables/scene-member-count'
import * as sceneVotesOnPost from './tables/scene-votes-on-post'
export type DatabaseSchema = user.PartialDB &
didHandle.PartialDB &

@ -2,62 +2,18 @@ import assert from 'assert'
import { Kysely, SqliteDialect, PostgresDialect, Migrator } from 'kysely'
import SqliteDB from 'better-sqlite3'
import { Pool as PgPool, types as pgTypes } from 'pg'
import { ValidationError } from '@atproto/lexicon'
import * as Declaration from './records/declaration'
import * as Post from './records/post'
import * as Vote from './records/vote'
import * as Repost from './records/repost'
import * as Trend from './records/trend'
import * as Follow from './records/follow'
import * as Assertion from './records/assertion'
import * as Confirmation from './records/confirmation'
import * as Profile from './records/profile'
import { AtUri } from '@atproto/uri'
import * as common from '@atproto/common'
import { CID } from 'multiformats/cid'
import { dbLogger as log } from '../logger'
import { DatabaseSchema } from './database-schema'
import * as scrypt from './scrypt'
import { User } from './tables/user'
import { dummyDialect } from './util'
import * as migrations from './migrations'
import { CtxMigrationProvider } from './migrations/provider'
import { DidHandle } from './tables/did-handle'
import { Record as DeclarationRecord } from '../lexicon/types/app/bsky/system/declaration'
import { APP_BSKY_GRAPH } from '../lexicon'
import { MessageQueue } from './types'
export class Database {
migrator: Migrator
records: {
declaration: Declaration.PluginType
post: Post.PluginType
vote: Vote.PluginType
repost: Repost.PluginType
trend: Trend.PluginType
follow: Follow.PluginType
profile: Profile.PluginType
assertion: Assertion.PluginType
confirmation: Confirmation.PluginType
}
constructor(
public db: Kysely<DatabaseSchema>,
public dialect: Dialect,
public schema?: string,
public messageQueue?: MessageQueue,
) {
this.records = {
declaration: Declaration.makePlugin(db),
post: Post.makePlugin(db),
vote: Vote.makePlugin(db),
repost: Repost.makePlugin(db),
trend: Trend.makePlugin(db),
follow: Follow.makePlugin(db),
assertion: Assertion.makePlugin(db),
confirmation: Confirmation.makePlugin(db),
profile: Profile.makePlugin(db),
}
this.migrator = new Migrator({
db,
migrationTableSchema: schema,
@ -107,12 +63,7 @@ export class Database {
async transaction<T>(fn: (db: Database) => Promise<T>): Promise<T> {
return await this.db.transaction().execute((txn) => {
const dbTxn = new Database(
txn,
this.dialect,
this.schema,
this.messageQueue,
)
const dbTxn = new Database(txn, this.dialect, this.schema)
return fn(dbTxn)
})
}
@ -126,7 +77,6 @@ export class Database {
}
async close(): Promise<void> {
this.messageQueue?.destroy()
await this.db.destroy()
}
@ -143,333 +93,6 @@ export class Database {
}
return results
}
setMessageQueue(mq: MessageQueue) {
this.messageQueue = mq
}
async getRepoRoot(did: string, forUpdate?: boolean): Promise<CID | null> {
let builder = this.db
.selectFrom('repo_root')
.selectAll()
.where('did', '=', did)
if (forUpdate) {
this.assertTransaction()
if (this.dialect !== 'sqlite') {
// SELECT FOR UPDATE is not supported by sqlite, but sqlite txs are SERIALIZABLE so we don't actually need it
builder = builder.forUpdate()
}
}
const found = await builder.executeTakeFirst()
return found ? CID.parse(found.root) : null
}
async updateRepoRoot(
did: string,
root: CID,
prev: CID,
timestamp?: string,
): Promise<boolean> {
log.debug({ did, root: root.toString() }, 'updating repo root')
const res = await this.db
.updateTable('repo_root')
.set({
root: root.toString(),
indexedAt: timestamp || new Date().toISOString(),
})
.where('did', '=', did)
.where('root', '=', prev.toString())
.executeTakeFirst()
if (res.numUpdatedRows > 0) {
log.info({ did, root: root.toString() }, 'updated repo root')
return true
} else {
log.info(
{ did, root: root.toString() },
'failed to update repo root: misordered',
)
return false
}
}
async getUser(handleOrDid: string): Promise<(User & DidHandle) | null> {
let query = this.db
.selectFrom('user')
.innerJoin('did_handle', 'did_handle.handle', 'user.handle')
.selectAll()
if (handleOrDid.startsWith('did:')) {
query = query.where('did', '=', handleOrDid)
} else {
query = query.where('did_handle.handle', '=', handleOrDid.toLowerCase())
}
const found = await query.executeTakeFirst()
return found || null
}
async getUserByEmail(email: string): Promise<(User & DidHandle) | null> {
const found = await this.db
.selectFrom('user')
.innerJoin('did_handle', 'did_handle.handle', 'user.handle')
.selectAll()
.where('email', '=', email.toLowerCase())
.executeTakeFirst()
return found || null
}
async getDidForActor(handleOrDid: string): Promise<string | null> {
if (handleOrDid.startsWith('did:')) return handleOrDid
const found = await this.db
.selectFrom('did_handle')
.where('handle', '=', handleOrDid)
.select('did')
.executeTakeFirst()
return found ? found.did : null
}
async registerUser(email: string, handle: string, password: string) {
this.assertTransaction()
log.debug({ handle, email }, 'registering user')
const inserted = await this.db
.insertInto('user')
.values({
email: email.toLowerCase(),
handle: handle,
password: await scrypt.hash(password),
createdAt: new Date().toISOString(),
lastSeenNotifs: new Date().toISOString(),
})
.onConflict((oc) => oc.doNothing())
.returning('handle')
.executeTakeFirst()
if (!inserted) {
throw new UserAlreadyExistsError()
}
log.info({ handle, email }, 'registered user')
}
async preregisterDid(handle: string, tempDid: string) {
this.assertTransaction()
const inserted = await this.db
.insertInto('did_handle')
.values({
handle,
did: tempDid,
actorType: 'temp',
declarationCid: 'temp',
})
.onConflict((oc) => oc.doNothing())
.returning('handle')
.executeTakeFirst()
if (!inserted) {
throw new UserAlreadyExistsError()
}
log.info({ handle, tempDid }, 'pre-registered did')
}
async finalizeDid(
handle: string,
did: string,
tempDid: string,
declaration: DeclarationRecord,
) {
this.assertTransaction()
log.debug({ handle, did }, 'registering did-handle')
const declarationCid = await common.cidForData(declaration)
const updated = await this.db
.updateTable('did_handle')
.set({
did,
actorType: declaration.actorType,
declarationCid: declarationCid.toString(),
})
.where('handle', '=', handle)
.where('did', '=', tempDid)
.returningAll()
.executeTakeFirst()
if (!updated) {
throw new Error('DID could not be finalized')
}
log.info({ handle, did }, 'post-registered did-handle')
}
async updateUserPassword(handle: string, password: string) {
const hashedPassword = await scrypt.hash(password)
await this.db
.updateTable('user')
.set({ password: hashedPassword })
.where('handle', '=', handle)
.execute()
}
async verifyUserPassword(handle: string, password: string): Promise<boolean> {
const found = await this.db
.selectFrom('user')
.selectAll()
.where('handle', '=', handle)
.executeTakeFirst()
if (!found) return false
return scrypt.verify(password, found.password)
}
async isUserControlledRepo(
repoDid: string,
userDid: string | null,
): Promise<boolean> {
if (!userDid) return false
if (repoDid === userDid) return true
const found = await this.db
.selectFrom('did_handle')
.leftJoin('scene', 'scene.handle', 'did_handle.handle')
.where('did_handle.did', '=', repoDid)
.where('scene.owner', '=', userDid)
.select('scene.owner')
.executeTakeFirst()
return !!found
}
async getScenesForUser(userDid: string): Promise<string[]> {
const res = await this.db
.selectFrom('assertion')
.where('assertion.subjectDid', '=', userDid)
.where('assertion.assertion', '=', APP_BSKY_GRAPH.AssertMember)
.where('assertion.confirmUri', 'is not', null)
.select('assertion.creator as scene')
.execute()
return res.map((row) => row.scene)
}
assertValidRecord(collection: string, obj: unknown): void {
let table
try {
table = this.findTableForCollection(collection)
} catch (e) {
throw new ValidationError(`Schema not found`)
}
table.assertValidRecord(obj)
}
canIndexRecord(collection: string, obj: unknown): boolean {
const table = this.findTableForCollection(collection)
return table.matchesSchema(obj)
}
async indexRecord(uri: AtUri, cid: CID, obj: unknown, timestamp?: string) {
this.assertTransaction()
log.debug({ uri }, 'indexing record')
const record = {
uri: uri.toString(),
cid: cid.toString(),
did: uri.host,
collection: uri.collection,
rkey: uri.rkey,
}
if (!record.did.startsWith('did:')) {
throw new Error('Expected indexed URI to contain DID')
} else if (record.collection.length < 1) {
throw new Error('Expected indexed URI to contain a collection')
} else if (record.rkey.length < 1) {
throw new Error('Expected indexed URI to contain a record key')
}
await this.db.insertInto('record').values(record).execute()
const table = this.findTableForCollection(uri.collection)
const events = await table.insertRecord(uri, cid, obj, timestamp)
this.messageQueue && (await this.messageQueue.send(this, events))
log.info({ uri }, 'indexed record')
}
async deleteRecord(uri: AtUri, cascading = false) {
this.assertTransaction()
log.debug({ uri }, 'deleting indexed record')
const table = this.findTableForCollection(uri.collection)
const deleteQuery = this.db
.deleteFrom('record')
.where('uri', '=', uri.toString())
.execute()
const [events, _] = await Promise.all([
table.deleteRecord(uri, cascading),
deleteQuery,
])
this.messageQueue && (await this.messageQueue.send(this, events))
log.info({ uri }, 'deleted indexed record')
}
async listCollectionsForDid(did: string): Promise<string[]> {
const collections = await this.db
.selectFrom('record')
.select('collection')
.where('did', '=', did)
.execute()
return collections.map((row) => row.collection)
}
async listRecordsForCollection(
did: string,
collection: string,
limit: number,
reverse: boolean,
before?: string,
after?: string,
): Promise<{ uri: string; cid: string; value: object }[]> {
let builder = this.db
.selectFrom('record')
.innerJoin('ipld_block', 'ipld_block.cid', 'record.cid')
.where('record.did', '=', did)
.where('record.collection', '=', collection)
.orderBy('record.rkey', reverse ? 'asc' : 'desc')
.limit(limit)
.selectAll()
if (before !== undefined) {
builder = builder.where('record.rkey', '<', before)
}
if (after !== undefined) {
builder = builder.where('record.rkey', '>', after)
}
const res = await builder.execute()
return res.map((row) => {
return {
uri: row.uri,
cid: row.cid,
value: common.ipldBytesToRecord(row.content),
}
})
}
async getRecord(
uri: AtUri,
cid: string | null,
): Promise<{ uri: string; cid: string; value: object } | null> {
let builder = this.db
.selectFrom('record')
.innerJoin('ipld_block', 'ipld_block.cid', 'record.cid')
.selectAll()
.where('record.uri', '=', uri.toString())
if (cid) {
builder = builder.where('record.cid', '=', cid)
}
const record = await builder.executeTakeFirst()
if (!record) return null
return {
uri: record.uri,
cid: record.cid,
value: common.ipldBytesToRecord(record.content),
}
}
findTableForCollection(collection: string) {
const found = Object.values(this.records).find(
(plugin) => plugin.collection === collection,
)
if (!found) {
throw new Error('Could not find table for collection')
}
return found
}
}
export default Database
@ -478,5 +101,3 @@ export type Dialect = 'pg' | 'sqlite'
// Can use with typeof to get types for partial queries
export const dbType = new Kysely<DatabaseSchema>({ dialect: dummyDialect })
export class UserAlreadyExistsError extends Error {}

@ -1,32 +1,3 @@
import { AtUri } from '@atproto/uri'
import { CID } from 'multiformats/cid'
import { DynamicReferenceBuilder } from 'kysely/dist/cjs/dynamic/dynamic-reference-builder'
import { MessageOfType, Listenable } from '../stream/types'
import { Message } from '../stream/messages'
import Database from '.'
export type DbRecordPlugin<T> = {
collection: string
assertValidRecord: (obj: unknown) => void
matchesSchema: (obj: unknown) => obj is T
insert: (
uri: AtUri,
cid: CID,
obj: unknown,
timestamp?: string,
) => Promise<Message[]>
delete: (uri: AtUri) => Promise<Message[]>
}
export type Ref = DynamicReferenceBuilder<any>
export interface MessageQueue {
send(tx: Database, message: MessageOfType | MessageOfType[]): Promise<void>
listen<T extends string, M extends MessageOfType<T>>(
topic: T,
listenable: Listenable<M>,
): void
processNext(): Promise<void>
processAll(): Promise<void>
destroy(): void
}

@ -2,11 +2,13 @@ import { sql } from 'kysely'
import Database from '../../db'
import { Consumer } from '../types'
import { AddUpvote, sceneVotesOnPostTableUpdates } from '../messages'
import { ActorService } from '../../services/actor'
export default class extends Consumer<AddUpvote> {
async dispatch(ctx: { db: Database; message: AddUpvote }) {
const { db, message } = ctx
const userScenes = await db.getScenesForUser(message.user)
async dispatch(ctx: { message: AddUpvote; db: Database }) {
const { message, db } = ctx
const actorTxn = new ActorService(db)
const userScenes = await actorTxn.getScenesForUser(message.user)
if (userScenes.length < 1) return
const updated = await db.db
.updateTable('scene_votes_on_post')

@ -1,16 +1,18 @@
import { BlobStore } from '@atproto/repo'
import { DidableKey } from '@atproto/crypto'
import ServerAuth from '../../auth'
import { MessageQueue } from '../../db/types'
import AddMemberConsumer from './add-member'
import RemoveMemberConsumer from './remove-member'
import AddUpvoteConsumer from './add-upvote'
import SceneVotesOnPostConsumer from './scene-votes-on-post'
import RemoveUpvoteConsumer from './remove-upvote'
import CreateNotificationConsumer from './create-notification'
import DeleteNotificationsConsumer from './delete-notifications'
import SceneVotesOnPostConsumer from './scene-votes-on-post'
import { MessageQueue } from '../types'
export const listen = (
messageQueue: MessageQueue,
blobstore: BlobStore,
auth: ServerAuth,
keypair: DidableKey,
) => {
@ -22,7 +24,7 @@ export const listen = (
messageQueue.listen('add_upvote', new AddUpvoteConsumer())
messageQueue.listen(
'scene_votes_on_post__table_updates',
new SceneVotesOnPostConsumer(getAuthStore),
new SceneVotesOnPostConsumer(getAuthStore, messageQueue, blobstore),
)
messageQueue.listen('remove_upvote', new RemoveUpvoteConsumer())
messageQueue.listen('create_notification', new CreateNotificationConsumer())

@ -1,12 +1,14 @@
import { sql } from 'kysely'
import Database from '../../db'
import { ActorService } from '../../services/actor'
import { RemoveUpvote } from '../messages'
import { Consumer } from '../types'
export default class extends Consumer<RemoveUpvote> {
async dispatch(ctx: { db: Database; message: RemoveUpvote }) {
const { db, message } = ctx
const userScenes = await db.getScenesForUser(message.user)
const actorTxn = new ActorService(db)
const userScenes = await actorTxn.getScenesForUser(message.user)
if (userScenes.length === 0) return
await db.db
.updateTable('scene_votes_on_post')

@ -1,18 +1,24 @@
import { TID } from '@atproto/common'
import { AuthStore } from '@atproto/auth'
import { BlobStore } from '@atproto/repo'
import Database from '../../db'
import * as repo from '../../repo'
import * as lexicons from '../../lexicon/lexicons'
import { Consumer } from '../types'
import { RepoService } from '../../services/repo'
import { Consumer, MessageQueue } from '../types'
import { SceneVotesOnPostTableUpdates } from '../messages'
export default class extends Consumer<SceneVotesOnPostTableUpdates> {
constructor(private getAuthStore: GetAuthStoreFn) {
constructor(
private getAuthStore: GetAuthStoreFn,
private messageQueue: MessageQueue,
private blobstore: BlobStore,
) {
super()
}
async dispatch(ctx: { db: Database; message: SceneVotesOnPostTableUpdates }) {
const { db, message } = ctx
async dispatch(ctx: { message: SceneVotesOnPostTableUpdates; db: Database }) {
const { message, db } = ctx
const { dids: scenes, subject } = message
if (scenes.length === 0) return
const state = await db.db
@ -61,9 +67,11 @@ export default class extends Consumer<SceneVotesOnPostTableUpdates> {
.where('subject', '=', scene.subject)
.execute()
const repoTxn = new RepoService(db, this.messageQueue, this.blobstore)
await Promise.all([
repo.writeToRepo(db, scene.did, sceneAuth, writes, now),
repo.indexWrites(db, writes, now),
repoTxn.writeToRepo(scene.did, sceneAuth, writes, now),
repoTxn.indexWrites(writes, now),
setTrendPosted,
])
}),

@ -1,7 +1,6 @@
import Database from '..'
import { dbLogger as log } from '../../logger'
import { Listenable, Listener, MessageOfType } from '../../stream/types'
import { MessageQueue } from '../types'
import Database from '../db'
import { dbLogger as log } from '../logger'
import { MessageQueue, Listenable, Listener, MessageOfType } from './types'
export class SqlMessageQueue implements MessageQueue {
private cursorExists = false

@ -6,8 +6,8 @@ export type MessageOfType<T extends string = string> = {
}
export type Listener<M extends MessageOfType = MessageOfType> = (ctx: {
db: Database
message: M
db: Database
}) => Promise<void | MessageOfType[]>
export interface Listenable<M extends MessageOfType = MessageOfType> {
@ -25,3 +25,14 @@ export abstract class Consumer<M extends MessageOfType>
return this.dispatch.bind(this)
}
}
export interface MessageQueue {
send(tx: Database, message: MessageOfType | MessageOfType[]): Promise<void>
listen<T extends string, M extends MessageOfType<T>>(
topic: T,
listenable: Listenable<M>,
): void
processNext(): Promise<void>
processAll(): Promise<void>
destroy(): void
}

@ -12,16 +12,17 @@ import { DidResolver } from '@atproto/did-resolver'
import API, { health } from './api'
import Database from './db'
import ServerAuth from './auth'
import * as streamConsumers from './stream/consumers'
import * as streamConsumers from './event-stream/consumers'
import * as error from './error'
import { httpLogger, loggerMiddleware } from './logger'
import { ServerConfig, ServerConfigValues } from './config'
import { Locals } from './locals'
import { ServerMailer } from './mailer'
import { createTransport } from 'nodemailer'
import SqlMessageQueue from './db/message-queue'
import SqlMessageQueue from './event-stream/message-queue'
import { ImageUriBuilder } from './image/uri'
import { BlobDiskCache, ImageProcessingServer } from './image/server'
import { createServices } from './services'
export type { ServerConfigValues } from './config'
export { ServerConfig } from './config'
@ -45,8 +46,9 @@ const runServer = (
})
const messageQueue = new SqlMessageQueue('pds', db)
db.setMessageQueue(messageQueue)
streamConsumers.listen(messageQueue, auth, keypair)
streamConsumers.listen(messageQueue, blobstore, auth, keypair)
const services = createServices(db, messageQueue, blobstore)
const mailTransport =
config.emailSmtpUrl !== undefined
@ -87,6 +89,8 @@ const runServer = (
imgUriBuilder,
config,
mailer,
services,
messageQueue,
}
app.locals = locals

@ -9,6 +9,8 @@ import { ServerMailer } from './mailer'
import { App } from '.'
import { BlobStore } from '@atproto/repo'
import { ImageUriBuilder } from './image/uri'
import { Services } from './services'
import { MessageQueue } from './event-stream/types'
export type Locals = {
logger: pino.Logger
@ -19,6 +21,8 @@ export type Locals = {
imgUriBuilder: ImageUriBuilder
config: ServerConfig
mailer: ServerMailer
services: Services
messageQueue: MessageQueue
}
type HasLocals = App | Response
@ -87,6 +91,22 @@ export const imgUriBuilder = (res: HasLocals): ImageUriBuilder => {
return imgUriBuilder as ImageUriBuilder
}
export const services = (res: HasLocals): Services => {
const services = res.locals.services
if (!services) {
throw new Error('No Services object attached to server')
}
return services as Services
}
export const messageQueue = (res: HasLocals): MessageQueue => {
const messageQueue = res.locals.messageQueue
if (!messageQueue) {
throw new Error('No MessageQueue object attached to server')
}
return messageQueue as MessageQueue
}
export const getLocals = (res: HasLocals): Locals => {
return {
logger: logger(res),
@ -97,6 +117,8 @@ export const getLocals = (res: HasLocals): Locals => {
imgUriBuilder: imgUriBuilder(res),
config: config(res),
mailer: mailer(res),
services: services(res),
messageQueue: messageQueue(res),
}
}
export const get = getLocals

@ -1,4 +1,2 @@
export * from './blobs'
export * from './prepare'
export * from './process'
export * from './types'

@ -1,93 +0,0 @@
import { CID } from 'multiformats/cid'
import { BlobStore, Repo } from '@atproto/repo'
import * as auth from '@atproto/auth'
import Database from '../db'
import SqlBlockstore from '../sql-blockstore'
import { InvalidRequestError } from '@atproto/xrpc-server'
import { PreparedCreate, PreparedWrites } from './types'
import { processWriteBlobs } from './blobs'
export const createRepo = async (
dbTxn: Database,
did: string,
authStore: auth.AuthStore,
writes: PreparedCreate[],
now: string,
) => {
dbTxn.assertTransaction()
const blockstore = new SqlBlockstore(dbTxn, did, now)
const writeOps = writes.map((write) => write.op)
const repo = await Repo.create(blockstore, did, authStore, writeOps)
await dbTxn.db
.insertInto('repo_root')
.values({
did: did,
root: repo.cid.toString(),
indexedAt: now,
})
.execute()
}
export const processWrites = async (
dbTxn: Database,
did: string,
authStore: auth.AuthStore,
blobs: BlobStore,
writes: PreparedWrites,
now: string,
) => {
// make structural write to repo & send to indexing
// @TODO get commitCid first so we can do all db actions in tandem
const [commit] = await Promise.all([
writeToRepo(dbTxn, did, authStore, writes, now),
indexWrites(dbTxn, writes, now),
])
// make blobs permanent & associate w commit + recordUri in DB
await processWriteBlobs(dbTxn, blobs, did, commit, writes)
}
export const writeToRepo = async (
dbTxn: Database,
did: string,
authStore: auth.AuthStore,
writes: PreparedWrites,
now: string,
): Promise<CID> => {
dbTxn.assertTransaction()
const blockstore = new SqlBlockstore(dbTxn, did, now)
const currRoot = await dbTxn.getRepoRoot(did, true)
if (!currRoot) {
throw new InvalidRequestError(
`${did} is not a registered repo on this server`,
)
}
const writeOps = writes.map((write) => write.op)
const repo = await Repo.load(blockstore, currRoot)
const updated = await repo
.stageUpdate(writeOps)
.createCommit(authStore, async (prev, curr) => {
const success = await dbTxn.updateRepoRoot(did, curr, prev, now)
if (!success) {
throw new Error('Repo root update failed, could not linearize')
}
return null
})
return updated.cid
}
export const indexWrites = async (
dbTxn: Database,
writes: PreparedWrites,
now: string,
) => {
dbTxn.assertTransaction()
await Promise.all(
writes.map(async (write) => {
if (write.action === 'create') {
await dbTxn.indexRecord(write.uri, write.cid, write.op.value, now)
} else if (write.action === 'delete') {
await dbTxn.deleteRecord(write.uri)
}
}),
)
}

@ -0,0 +1,148 @@
import * as common from '@atproto/common'
import { dbLogger as log } from '../logger'
import Database from '../db'
import * as scrypt from '../db/scrypt'
import { User } from '../db/tables/user'
import { DidHandle } from '../db/tables/did-handle'
import { Record as DeclarationRecord } from '../lexicon/types/app/bsky/system/declaration'
import { APP_BSKY_GRAPH } from '../lexicon'
export class ActorService {
constructor(public db: Database) {}
static creator() {
return (db: Database) => new ActorService(db)
}
async getUser(handleOrDid: string): Promise<(User & DidHandle) | null> {
let query = this.db.db
.selectFrom('user')
.innerJoin('did_handle', 'did_handle.handle', 'user.handle')
.selectAll()
if (handleOrDid.startsWith('did:')) {
query = query.where('did', '=', handleOrDid)
} else {
query = query.where('did_handle.handle', '=', handleOrDid.toLowerCase())
}
const found = await query.executeTakeFirst()
return found || null
}
async getUserByEmail(email: string): Promise<(User & DidHandle) | null> {
const found = await this.db.db
.selectFrom('user')
.innerJoin('did_handle', 'did_handle.handle', 'user.handle')
.selectAll()
.where('email', '=', email.toLowerCase())
.executeTakeFirst()
return found || null
}
async getDidForActor(handleOrDid: string): Promise<string | null> {
if (handleOrDid.startsWith('did:')) return handleOrDid
const found = await this.db.db
.selectFrom('did_handle')
.where('handle', '=', handleOrDid)
.select('did')
.executeTakeFirst()
return found ? found.did : null
}
async registerUser(email: string, handle: string, password: string) {
this.db.assertTransaction()
log.debug({ handle, email }, 'registering user')
const inserted = await this.db.db
.insertInto('user')
.values({
email: email.toLowerCase(),
handle: handle,
password: await scrypt.hash(password),
createdAt: new Date().toISOString(),
lastSeenNotifs: new Date().toISOString(),
})
.onConflict((oc) => oc.doNothing())
.returning('handle')
.executeTakeFirst()
if (!inserted) {
throw new UserAlreadyExistsError()
}
log.info({ handle, email }, 'registered user')
}
async preregisterDid(handle: string, tempDid: string) {
this.db.assertTransaction()
const inserted = await this.db.db
.insertInto('did_handle')
.values({
handle,
did: tempDid,
actorType: 'temp',
declarationCid: 'temp',
})
.onConflict((oc) => oc.doNothing())
.returning('handle')
.executeTakeFirst()
if (!inserted) {
throw new UserAlreadyExistsError()
}
log.info({ handle, tempDid }, 'pre-registered did')
}
async finalizeDid(
handle: string,
did: string,
tempDid: string,
declaration: DeclarationRecord,
) {
this.db.assertTransaction()
log.debug({ handle, did }, 'registering did-handle')
const declarationCid = await common.cidForData(declaration)
const updated = await this.db.db
.updateTable('did_handle')
.set({
did,
actorType: declaration.actorType,
declarationCid: declarationCid.toString(),
})
.where('handle', '=', handle)
.where('did', '=', tempDid)
.returningAll()
.executeTakeFirst()
if (!updated) {
throw new Error('DID could not be finalized')
}
log.info({ handle, did }, 'post-registered did-handle')
}
async updateUserPassword(handle: string, password: string) {
const hashedPassword = await scrypt.hash(password)
await this.db.db
.updateTable('user')
.set({ password: hashedPassword })
.where('handle', '=', handle)
.execute()
}
async verifyUserPassword(handle: string, password: string): Promise<boolean> {
const found = await this.db.db
.selectFrom('user')
.selectAll()
.where('handle', '=', handle)
.executeTakeFirst()
if (!found) return false
return scrypt.verify(password, found.password)
}
async getScenesForUser(userDid: string): Promise<string[]> {
const res = await this.db.db
.selectFrom('assertion')
.where('assertion.subjectDid', '=', userDid)
.where('assertion.assertion', '=', APP_BSKY_GRAPH.AssertMember)
.where('assertion.confirmUri', 'is not', null)
.select('assertion.creator as scene')
.execute()
return res.map((row) => row.scene)
}
}
export class UserAlreadyExistsError extends Error {}

@ -0,0 +1,26 @@
import { BlobStore } from '@atproto/repo'
import Database from '../db'
import { MessageQueue } from '../event-stream/types'
import { ActorService } from './actor'
import { RecordService } from './record'
import { RepoService } from './repo'
export function createServices(
db: Database,
messageQueue: MessageQueue,
blobstore: BlobStore,
): Services {
return {
actor: ActorService.creator(),
record: RecordService.creator(messageQueue),
repo: RepoService.creator(messageQueue, blobstore),
}
}
export type Services = {
actor: FromDb<ActorService>
record: FromDb<RecordService>
repo: FromDb<RepoService>
}
type FromDb<T> = (db: Database) => T

@ -0,0 +1,181 @@
import { CID } from 'multiformats/cid'
import { ValidationError } from '@atproto/lexicon'
import { AtUri } from '@atproto/uri'
import * as common from '@atproto/common'
import { dbLogger as log } from '../../logger'
import Database from '../../db'
import * as Declaration from './plugins/declaration'
import * as Post from './plugins/post'
import * as Vote from './plugins/vote'
import * as Repost from './plugins/repost'
import * as Trend from './plugins/trend'
import * as Follow from './plugins/follow'
import * as Assertion from './plugins/assertion'
import * as Confirmation from './plugins/confirmation'
import * as Profile from './plugins/profile'
import { MessageQueue } from '../../event-stream/types'
export class RecordService {
records: {
declaration: Declaration.PluginType
post: Post.PluginType
vote: Vote.PluginType
repost: Repost.PluginType
trend: Trend.PluginType
follow: Follow.PluginType
profile: Profile.PluginType
assertion: Assertion.PluginType
confirmation: Confirmation.PluginType
}
constructor(public db: Database, public messageQueue: MessageQueue) {
this.records = {
declaration: Declaration.makePlugin(this.db.db),
post: Post.makePlugin(this.db.db),
vote: Vote.makePlugin(this.db.db),
repost: Repost.makePlugin(this.db.db),
trend: Trend.makePlugin(this.db.db),
follow: Follow.makePlugin(this.db.db),
assertion: Assertion.makePlugin(this.db.db),
confirmation: Confirmation.makePlugin(this.db.db),
profile: Profile.makePlugin(this.db.db),
}
}
static creator(messageQueue: MessageQueue) {
return (db: Database) => new RecordService(db, messageQueue)
}
assertValidRecord(collection: string, obj: unknown): void {
let table
try {
table = this.findTableForCollection(collection)
} catch (e) {
throw new ValidationError(`Schema not found`)
}
table.assertValidRecord(obj)
}
canIndexRecord(collection: string, obj: unknown): boolean {
const table = this.findTableForCollection(collection)
return table.matchesSchema(obj)
}
async indexRecord(uri: AtUri, cid: CID, obj: unknown, timestamp?: string) {
this.db.assertTransaction()
log.debug({ uri }, 'indexing record')
const record = {
uri: uri.toString(),
cid: cid.toString(),
did: uri.host,
collection: uri.collection,
rkey: uri.rkey,
}
if (!record.did.startsWith('did:')) {
throw new Error('Expected indexed URI to contain DID')
} else if (record.collection.length < 1) {
throw new Error('Expected indexed URI to contain a collection')
} else if (record.rkey.length < 1) {
throw new Error('Expected indexed URI to contain a record key')
}
await this.db.db.insertInto('record').values(record).execute()
const table = this.findTableForCollection(uri.collection)
const events = await table.insertRecord(uri, cid, obj, timestamp)
await this.messageQueue.send(this.db, events)
log.info({ uri }, 'indexed record')
}
async deleteRecord(uri: AtUri, cascading = false) {
this.db.assertTransaction()
log.debug({ uri }, 'deleting indexed record')
const table = this.findTableForCollection(uri.collection)
const deleteQuery = this.db.db
.deleteFrom('record')
.where('uri', '=', uri.toString())
.execute()
const [events, _] = await Promise.all([
table.deleteRecord(uri, cascading),
deleteQuery,
])
await this.messageQueue.send(this.db, events)
log.info({ uri }, 'deleted indexed record')
}
async listCollectionsForDid(did: string): Promise<string[]> {
const collections = await this.db.db
.selectFrom('record')
.select('collection')
.where('did', '=', did)
.execute()
return collections.map((row) => row.collection)
}
async listRecordsForCollection(
did: string,
collection: string,
limit: number,
reverse: boolean,
before?: string,
after?: string,
): Promise<{ uri: string; cid: string; value: object }[]> {
let builder = this.db.db
.selectFrom('record')
.innerJoin('ipld_block', 'ipld_block.cid', 'record.cid')
.where('record.did', '=', did)
.where('record.collection', '=', collection)
.orderBy('record.rkey', reverse ? 'asc' : 'desc')
.limit(limit)
.selectAll()
if (before !== undefined) {
builder = builder.where('record.rkey', '<', before)
}
if (after !== undefined) {
builder = builder.where('record.rkey', '>', after)
}
const res = await builder.execute()
return res.map((row) => {
return {
uri: row.uri,
cid: row.cid,
value: common.ipldBytesToRecord(row.content),
}
})
}
async getRecord(
uri: AtUri,
cid: string | null,
): Promise<{ uri: string; cid: string; value: object } | null> {
let builder = this.db.db
.selectFrom('record')
.innerJoin('ipld_block', 'ipld_block.cid', 'record.cid')
.selectAll()
.where('record.uri', '=', uri.toString())
if (cid) {
builder = builder.where('record.cid', '=', cid)
}
const record = await builder.executeTakeFirst()
if (!record) return null
return {
uri: record.uri,
cid: record.cid,
value: common.ipldBytesToRecord(record.content),
}
}
findTableForCollection(collection: string) {
const found = Object.values(this.records).find(
(plugin) => plugin.collection === collection,
)
if (!found) {
throw new Error('Could not find table for collection')
}
return found
}
}

@ -1,15 +1,15 @@
import { Kysely } from 'kysely'
import { AtUri } from '@atproto/uri'
import { CID } from 'multiformats/cid'
import * as Assertion from '../../lexicon/types/app/bsky/graph/assertion'
import { Assertion as IndexedAssertion } from '../tables/assertion'
import * as lex from '../../lexicon/lexicons'
import * as messages from '../../stream/messages'
import { Message } from '../../stream/messages'
import DatabaseSchema from '../database-schema'
import RecordProcessor from '../record-processor'
import * as Assertion from '../../../lexicon/types/app/bsky/graph/assertion'
import * as lex from '../../../lexicon/lexicons'
import * as messages from '../../../event-stream/messages'
import { Message } from '../../../event-stream/messages'
import DatabaseSchema from '../../../db/database-schema'
import RecordProcessor from '../processor'
const lexId = lex.ids.AppBskyGraphAssertion
type IndexedAssertion = DatabaseSchema['assertion']
const insertFn = async (
db: Kysely<DatabaseSchema>,

@ -1,16 +1,16 @@
import { Kysely } from 'kysely'
import { AtUri } from '@atproto/uri'
import { CID } from 'multiformats/cid'
import * as Confirmation from '../../lexicon/types/app/bsky/graph/confirmation'
import { Assertion as IndexedAssertion } from '../tables/assertion'
import * as lex from '../../lexicon/lexicons'
import * as messages from '../../stream/messages'
import { Message } from '../../stream/messages'
import { APP_BSKY_GRAPH } from '../../lexicon'
import DatabaseSchema from '../database-schema'
import RecordProcessor from '../record-processor'
import * as Confirmation from '../../../lexicon/types/app/bsky/graph/confirmation'
import * as lex from '../../../lexicon/lexicons'
import * as messages from '../../../event-stream/messages'
import { Message } from '../../../event-stream/messages'
import { APP_BSKY_GRAPH } from '../../../lexicon'
import DatabaseSchema from '../../../db/database-schema'
import RecordProcessor from '../processor'
const lexId = lex.ids.AppBskyGraphConfirmation
type IndexedAssertion = DatabaseSchema['assertion']
const insertFn = async (
db: Kysely<DatabaseSchema>,

@ -1,14 +1,14 @@
import { Kysely } from 'kysely'
import { AtUri } from '@atproto/uri'
import { CID } from 'multiformats/cid'
import * as Declaration from '../../lexicon/types/app/bsky/system/declaration'
import * as lex from '../../lexicon/lexicons'
import { DidHandle } from '../tables/did-handle'
import { Message } from '../../stream/messages'
import RecordProcessor from '../record-processor'
import DatabaseSchema from '../database-schema'
import * as Declaration from '../../../lexicon/types/app/bsky/system/declaration'
import * as lex from '../../../lexicon/lexicons'
import { Message } from '../../../event-stream/messages'
import DatabaseSchema from '../../../db/database-schema'
import RecordProcessor from '../processor'
const lexId = lex.ids.AppBskySystemDeclaration
type DidHandle = DatabaseSchema['did_handle']
const insertFn = async (
db: Kysely<DatabaseSchema>,

@ -1,15 +1,15 @@
import { Kysely } from 'kysely'
import { AtUri } from '@atproto/uri'
import { CID } from 'multiformats/cid'
import * as Follow from '../../lexicon/types/app/bsky/graph/follow'
import { Follow as IndexedFollow } from '../tables/follow'
import * as lex from '../../lexicon/lexicons'
import * as messages from '../../stream/messages'
import { Message } from '../../stream/messages'
import DatabaseSchema from '../database-schema'
import RecordProcessor from '../record-processor'
import * as Follow from '../../../lexicon/types/app/bsky/graph/follow'
import * as lex from '../../../lexicon/lexicons'
import * as messages from '../../../event-stream/messages'
import { Message } from '../../../event-stream/messages'
import DatabaseSchema from '../../../db/database-schema'
import RecordProcessor from '../processor'
const lexId = lex.ids.AppBskyGraphFollow
type IndexedFollow = DatabaseSchema['follow']
const insertFn = async (
db: Kysely<DatabaseSchema>,

@ -1,20 +1,19 @@
import { Kysely } from 'kysely'
import { AtUri } from '@atproto/uri'
import { CID } from 'multiformats/cid'
import { Record as PostRecord } from '../../lexicon/types/app/bsky/feed/post'
import { Main as ImagesEmbedFragment } from '../../lexicon/types/app/bsky/embed/images'
import { Main as ExternalEmbedFragment } from '../../lexicon/types/app/bsky/embed/external'
import { Post } from '../tables/post'
import { PostEntity } from '../tables/post-entity'
import { PostEmbedImage } from '../tables/post-embed-image'
import { PostEmbedExternal } from '../tables/post-embed-external'
import * as lex from '../../lexicon/lexicons'
import * as messages from '../../stream/messages'
import { Message } from '../../stream/messages'
import DatabaseSchema from '../database-schema'
import RecordProcessor from '../record-processor'
import { imgUriBuilder } from '../../locals'
import { Record as PostRecord } from '../../../lexicon/types/app/bsky/feed/post'
import { Main as ImagesEmbedFragment } from '../../../lexicon/types/app/bsky/embed/images'
import { Main as ExternalEmbedFragment } from '../../../lexicon/types/app/bsky/embed/external'
import * as lex from '../../../lexicon/lexicons'
import * as messages from '../../../event-stream/messages'
import { Message } from '../../../event-stream/messages'
import DatabaseSchema from '../../../db/database-schema'
import RecordProcessor from '../processor'
type Post = DatabaseSchema['post']
type PostEntity = DatabaseSchema['post_entity']
type PostEmbedImage = DatabaseSchema['post_embed_image']
type PostEmbedExternal = DatabaseSchema['post_embed_external']
type IndexedPost = {
post: Post
entities: PostEntity[]

@ -1,14 +1,14 @@
import { Kysely } from 'kysely'
import { AtUri } from '@atproto/uri'
import { CID } from 'multiformats/cid'
import * as Profile from '../../lexicon/types/app/bsky/actor/profile'
import { Profile as IndexedProfile } from '../tables/profile'
import * as lex from '../../lexicon/lexicons'
import { Message } from '../../stream/messages'
import DatabaseSchema from '../database-schema'
import RecordProcessor from '../record-processor'
import * as Profile from '../../../lexicon/types/app/bsky/actor/profile'
import * as lex from '../../../lexicon/lexicons'
import { Message } from '../../../event-stream/messages'
import DatabaseSchema from '../../../db/database-schema'
import RecordProcessor from '../processor'
const lexId = lex.ids.AppBskyActorProfile
type IndexedProfile = DatabaseSchema['profile']
const insertFn = async (
db: Kysely<DatabaseSchema>,

@ -1,15 +1,15 @@
import { Kysely } from 'kysely'
import { AtUri } from '@atproto/uri'
import * as Repost from '../../lexicon/types/app/bsky/feed/repost'
import { Repost as IndexedRepost } from '../tables/repost'
import * as lex from '../../lexicon/lexicons'
import * as Repost from '../../../lexicon/types/app/bsky/feed/repost'
import * as lex from '../../../lexicon/lexicons'
import { CID } from 'multiformats/cid'
import * as messages from '../../stream/messages'
import { Message } from '../../stream/messages'
import { DatabaseSchema } from '../database-schema'
import RecordProcessor from '../record-processor'
import * as messages from '../../../event-stream/messages'
import { Message } from '../../../event-stream/messages'
import { DatabaseSchema } from '../../../db/database-schema'
import RecordProcessor from '../processor'
const lexId = lex.ids.AppBskyFeedRepost
type IndexedRepost = DatabaseSchema['repost']
const insertFn = async (
db: Kysely<DatabaseSchema>,

@ -1,15 +1,15 @@
import { Kysely } from 'kysely'
import { AtUri } from '@atproto/uri'
import * as Trend from '../../lexicon/types/app/bsky/feed/trend'
import { Trend as IndexedTrend } from '../tables/trend'
import * as lex from '../../lexicon/lexicons'
import { CID } from 'multiformats/cid'
import * as messages from '../../stream/messages'
import { Message } from '../../stream/messages'
import DatabaseSchema from '../database-schema'
import RecordProcessor from '../record-processor'
import { AtUri } from '@atproto/uri'
import * as Trend from '../../../lexicon/types/app/bsky/feed/trend'
import * as lex from '../../../lexicon/lexicons'
import * as messages from '../../../event-stream/messages'
import { Message } from '../../../event-stream/messages'
import DatabaseSchema from '../../../db/database-schema'
import RecordProcessor from '../processor'
const lexId = lex.ids.AppBskyFeedTrend
type IndexedTrend = DatabaseSchema['trend']
const insertFn = async (
db: Kysely<DatabaseSchema>,

@ -1,15 +1,15 @@
import { Kysely } from 'kysely'
import { AtUri } from '@atproto/uri'
import { CID } from 'multiformats/cid'
import * as Vote from '../../lexicon/types/app/bsky/feed/vote'
import { Vote as IndexedVote } from '../tables/vote'
import * as lex from '../../lexicon/lexicons'
import * as messages from '../../stream/messages'
import { Message } from '../../stream/messages'
import { DatabaseSchema } from '../database-schema'
import RecordProcessor from '../record-processor'
import * as Vote from '../../../lexicon/types/app/bsky/feed/vote'
import * as lex from '../../../lexicon/lexicons'
import * as messages from '../../../event-stream/messages'
import { Message } from '../../../event-stream/messages'
import { DatabaseSchema } from '../../../db/database-schema'
import RecordProcessor from '../processor'
const lexId = lex.ids.AppBskyFeedVote
type IndexedVote = DatabaseSchema['vote']
const insertFn = async (
db: Kysely<DatabaseSchema>,

@ -1,10 +1,10 @@
import { AtUri } from '@atproto/uri'
import * as common from '@atproto/common'
import { Kysely } from 'kysely'
import { CID } from 'multiformats/cid'
import { DatabaseSchema } from './database-schema'
import { Message } from '../stream/messages'
import { lexicons } from '../lexicon/lexicons'
import { AtUri } from '@atproto/uri'
import * as common from '@atproto/common'
import DatabaseSchema from '../../db/database-schema'
import { Message } from '../../event-stream/messages'
import { lexicons } from '../../lexicon/lexicons'
type RecordProcessorParams<T, S> = {
lexId: string

@ -1,74 +1,131 @@
import stream from 'stream'
import { CID } from 'multiformats/cid'
import { BlobStore } from '@atproto/repo'
import bytes from 'bytes'
import Database from '../db'
import { fromStream as fileTypeFromStream } from 'file-type'
import { BlobStore } from '@atproto/repo'
import { AtUri } from '@atproto/uri'
import { sha256Stream } from '@atproto/crypto'
import { cloneStream, sha256RawToCid, streamSize } from '@atproto/common'
import { InvalidRequestError } from '@atproto/xrpc-server'
import { AtUri } from '@atproto/uri'
import { BlobRef, PreparedWrites } from './types'
import { Blob as BlobTable } from '../db/tables/blob'
import * as img from '../image'
import { sha256Stream } from '@atproto/crypto'
import { fromStream as fileTypeFromStream } from 'file-type'
import { BlobRef, PreparedWrites } from '../../repo/types'
import Database from '../../db'
import { Blob as BlobTable } from '../../db/tables/blob'
import * as img from '../../image'
export const addUntetheredBlob = async (
dbTxn: Database,
blobstore: BlobStore,
mimeType: string,
blobStream: stream.Readable,
): Promise<CID> => {
const [tempKey, size, sha256, imgInfo, fileType] = await Promise.all([
blobstore.putTemp(cloneStream(blobStream)),
streamSize(cloneStream(blobStream)),
sha256Stream(cloneStream(blobStream)),
img.maybeGetInfo(cloneStream(blobStream)),
fileTypeFromStream(blobStream),
])
export class RepoBlobs {
constructor(public db: Database, public blobstore: BlobStore) {}
const cid = sha256RawToCid(sha256)
async addUntetheredBlob(
mimeType: string,
blobStream: stream.Readable,
): Promise<CID> {
const [tempKey, size, sha256, imgInfo, fileType] = await Promise.all([
this.blobstore.putTemp(cloneStream(blobStream)),
streamSize(cloneStream(blobStream)),
sha256Stream(cloneStream(blobStream)),
img.maybeGetInfo(cloneStream(blobStream)),
fileTypeFromStream(blobStream),
])
await dbTxn.db
.insertInto('blob')
.values({
cid: cid.toString(),
mimeType: fileType?.mime || mimeType,
size,
tempKey,
width: imgInfo?.width || null,
height: imgInfo?.height || null,
createdAt: new Date().toISOString(),
})
.onConflict((oc) =>
oc
.column('cid')
.doUpdateSet({ tempKey })
.where('blob.tempKey', 'is not', null),
)
.execute()
return cid
}
const cid = sha256RawToCid(sha256)
export const processWriteBlobs = async (
dbTxn: Database,
blobstore: BlobStore,
did: string,
commit: CID,
writes: PreparedWrites,
) => {
const blobPromises: Promise<void>[] = []
for (const write of writes) {
if (write.action === 'create' || write.action === 'update') {
for (const blob of write.blobs) {
blobPromises.push(verifyBlobAndMakePermanent(dbTxn, blobstore, blob))
blobPromises.push(associateBlob(dbTxn, blob, write.uri, commit, did))
await this.db.db
.insertInto('blob')
.values({
cid: cid.toString(),
mimeType: fileType?.mime || mimeType,
size,
tempKey,
width: imgInfo?.width || null,
height: imgInfo?.height || null,
createdAt: new Date().toISOString(),
})
.onConflict((oc) =>
oc
.column('cid')
.doUpdateSet({ tempKey })
.where('blob.tempKey', 'is not', null),
)
.execute()
return cid
}
async processWriteBlobs(did: string, commit: CID, writes: PreparedWrites) {
const blobPromises: Promise<void>[] = []
for (const write of writes) {
if (write.action === 'create' || write.action === 'update') {
for (const blob of write.blobs) {
blobPromises.push(this.verifyBlobAndMakePermanent(blob))
blobPromises.push(this.associateBlob(blob, write.uri, commit, did))
}
}
}
await Promise.all(blobPromises)
}
async verifyBlobAndMakePermanent(blob: BlobRef): Promise<void> {
const found = await this.db.db
.selectFrom('blob')
.selectAll()
.where('cid', '=', blob.cid.toString())
.executeTakeFirst()
if (!found) {
throw new InvalidRequestError(
`Could not find blob: ${blob.cid.toString()}`,
'BlobNotFound',
)
}
if (found.tempKey) {
verifyBlob(blob, found)
await this.blobstore.makePermanent(found.tempKey, blob.cid)
await this.db.db
.updateTable('blob')
.set({ tempKey: null })
.where('tempKey', '=', found.tempKey)
.execute()
}
}
async associateBlob(
blob: BlobRef,
recordUri: AtUri,
commit: CID,
did: string,
): Promise<void> {
await this.db.db
.insertInto('repo_blob')
.values({
cid: blob.cid.toString(),
recordUri: recordUri.toString(),
commit: commit.toString(),
did,
})
.onConflict((oc) => oc.doNothing())
.execute()
}
await Promise.all(blobPromises)
}
export const verifyBlob = (blob: BlobRef, found: BlobTable) => {
export class CidNotFound extends Error {
cid: CID
constructor(cid: CID) {
super(`cid not found: ${cid.toString()}`)
this.cid = cid
}
}
function acceptedMime(mime: string, accepted: string[]): boolean {
if (accepted.includes('*/*')) return true
const globs = accepted.filter((a) => a.endsWith('/*'))
for (const glob of globs) {
const [start] = glob.split('/')
if (mime.startsWith(`${start}/`)) {
return true
}
}
return accepted.includes(mime)
}
function verifyBlob(blob: BlobRef, found: BlobTable) {
const throwInvalid = (msg: string, errName = 'InvalidBlob') => {
throw new InvalidRequestError(msg, errName)
}
@ -144,69 +201,3 @@ export const verifyBlob = (blob: BlobRef, found: BlobTable) => {
}
}
}
const acceptedMime = (mime: string, accepted: string[]): boolean => {
if (accepted.includes('*/*')) return true
const globs = accepted.filter((a) => a.endsWith('/*'))
for (const glob of globs) {
const [start] = glob.split('/')
if (mime.startsWith(`${start}/`)) {
return true
}
}
return accepted.includes(mime)
}
export const verifyBlobAndMakePermanent = async (
dbTxn: Database,
blobstore: BlobStore,
blob: BlobRef,
): Promise<void> => {
const found = await dbTxn.db
.selectFrom('blob')
.selectAll()
.where('cid', '=', blob.cid.toString())
.executeTakeFirst()
if (!found) {
throw new InvalidRequestError(
`Could not find blob: ${blob.cid.toString()}`,
'BlobNotFound',
)
}
if (found.tempKey) {
verifyBlob(blob, found)
await blobstore.makePermanent(found.tempKey, blob.cid)
await dbTxn.db
.updateTable('blob')
.set({ tempKey: null })
.where('tempKey', '=', found.tempKey)
.execute()
}
}
export const associateBlob = async (
dbTxn: Database,
blob: BlobRef,
recordUri: AtUri,
commit: CID,
did: string,
): Promise<void> => {
await dbTxn.db
.insertInto('repo_blob')
.values({
cid: blob.cid.toString(),
recordUri: recordUri.toString(),
commit: commit.toString(),
did,
})
.onConflict((oc) => oc.doNothing())
.execute()
}
export class CidNotFound extends Error {
cid: CID
constructor(cid: CID) {
super(`cid not found: ${cid.toString()}`)
this.cid = cid
}
}

@ -0,0 +1,165 @@
import { CID } from 'multiformats/cid'
import * as auth from '@atproto/auth'
import { BlobStore, Repo } from '@atproto/repo'
import { InvalidRequestError } from '@atproto/xrpc-server'
import Database from '../../db'
import { dbLogger as log } from '../../logger'
import { MessageQueue } from '../../event-stream/types'
import SqlBlockstore from '../../sql-blockstore'
import { PreparedCreate, PreparedWrites } from '../../repo/types'
import { RecordService } from '../record'
import { RepoBlobs } from './blobs'
export class RepoService {
blobs: RepoBlobs
constructor(
public db: Database,
public messageQueue: MessageQueue,
public blobstore: BlobStore,
) {
this.blobs = new RepoBlobs(db, blobstore)
}
static creator(messageQueue: MessageQueue, blobstore: BlobStore) {
return (db: Database) => new RepoService(db, messageQueue, blobstore)
}
async getRepoRoot(did: string, forUpdate?: boolean): Promise<CID | null> {
let builder = this.db.db
.selectFrom('repo_root')
.selectAll()
.where('did', '=', did)
if (forUpdate) {
this.db.assertTransaction()
if (this.db.dialect !== 'sqlite') {
// SELECT FOR UPDATE is not supported by sqlite, but sqlite txs are SERIALIZABLE so we don't actually need it
builder = builder.forUpdate()
}
}
const found = await builder.executeTakeFirst()
return found ? CID.parse(found.root) : null
}
async updateRepoRoot(
did: string,
root: CID,
prev: CID,
timestamp?: string,
): Promise<boolean> {
log.debug({ did, root: root.toString() }, 'updating repo root')
const res = await this.db.db
.updateTable('repo_root')
.set({
root: root.toString(),
indexedAt: timestamp || new Date().toISOString(),
})
.where('did', '=', did)
.where('root', '=', prev.toString())
.executeTakeFirst()
if (res.numUpdatedRows > 0) {
log.info({ did, root: root.toString() }, 'updated repo root')
return true
} else {
log.info(
{ did, root: root.toString() },
'failed to update repo root: misordered',
)
return false
}
}
async isUserControlledRepo(
repoDid: string,
userDid: string | null,
): Promise<boolean> {
if (!userDid) return false
if (repoDid === userDid) return true
const found = await this.db.db
.selectFrom('did_handle')
.leftJoin('scene', 'scene.handle', 'did_handle.handle')
.where('did_handle.did', '=', repoDid)
.where('scene.owner', '=', userDid)
.select('scene.owner')
.executeTakeFirst()
return !!found
}
async createRepo(
did: string,
authStore: auth.AuthStore,
writes: PreparedCreate[],
now: string,
) {
this.db.assertTransaction()
const blockstore = new SqlBlockstore(this.db, did, now)
const writeOps = writes.map((write) => write.op)
const repo = await Repo.create(blockstore, did, authStore, writeOps)
await this.db.db
.insertInto('repo_root')
.values({
did: did,
root: repo.cid.toString(),
indexedAt: now,
})
.execute()
}
async processWrites(
did: string,
authStore: auth.AuthStore,
writes: PreparedWrites,
now: string,
) {
// make structural write to repo & send to indexing
// @TODO get commitCid first so we can do all db actions in tandem
const [commit] = await Promise.all([
this.writeToRepo(did, authStore, writes, now),
this.indexWrites(writes, now),
])
// make blobs permanent & associate w commit + recordUri in DB
await this.blobs.processWriteBlobs(did, commit, writes)
}
async writeToRepo(
did: string,
authStore: auth.AuthStore,
writes: PreparedWrites,
now: string,
): Promise<CID> {
this.db.assertTransaction()
const blockstore = new SqlBlockstore(this.db, did, now)
const currRoot = await this.getRepoRoot(did, true)
if (!currRoot) {
throw new InvalidRequestError(
`${did} is not a registered repo on this server`,
)
}
const writeOps = writes.map((write) => write.op)
const repo = await Repo.load(blockstore, currRoot)
const updated = await repo
.stageUpdate(writeOps)
.createCommit(authStore, async (prev, curr) => {
const success = await this.updateRepoRoot(did, curr, prev, now)
if (!success) {
throw new Error('Repo root update failed, could not linearize')
}
return null
})
return updated.cid
}
async indexWrites(writes: PreparedWrites, now: string) {
this.db.assertTransaction()
const recordTxn = new RecordService(this.db, this.messageQueue)
await Promise.all(
writes.map(async (write) => {
if (write.action === 'create') {
await recordTxn.indexRecord(write.uri, write.cid, write.op.value, now)
} else if (write.action === 'delete') {
await recordTxn.deleteRecord(write.uri)
}
}),
)
}
}

@ -5,6 +5,7 @@ import path from 'path'
import * as crypto from '@atproto/crypto'
import * as plc from '@atproto/plc'
import { AtUri } from '@atproto/uri'
import { randomStr } from '@atproto/crypto'
import { CID } from 'multiformats/cid'
import * as uint8arrays from 'uint8arrays'
import server, {
@ -16,7 +17,9 @@ import server, {
import * as GetAuthorFeed from '../src/lexicon/types/app/bsky/feed/getAuthorFeed'
import * as GetTimeline from '../src/lexicon/types/app/bsky/feed/getTimeline'
import DiskBlobStore from '../src/storage/disk-blobstore'
import { randomStr } from '@atproto/crypto'
import * as locals from '../src/locals'
import { MessageQueue } from '../src/event-stream/types'
import { Services } from '../src/services'
const ADMIN_PASSWORD = 'admin-pass'
@ -27,6 +30,8 @@ export type TestServerInfo = {
serverKey: string
app: App
db: Database
messageQueue: MessageQueue
services: Services
blobstore: DiskBlobStore | MemoryBlobStore
close: CloseFn
}
@ -99,6 +104,7 @@ export const runTestServer = async (
const { app, listener } = server(db, blobstore, keypair, config)
const pdsPort = (listener.address() as AddressInfo).port
const { services, messageQueue } = locals.get(app)
return {
url: `http://localhost:${pdsPort}`,
@ -106,8 +112,11 @@ export const runTestServer = async (
serverKey: keypair.did(),
app,
db,
services,
messageQueue,
blobstore,
close: async () => {
await messageQueue.destroy()
await Promise.all([
db.close(),
closeServer(listener),

@ -5,17 +5,20 @@ import * as lex from '../src/lexicon/lexicons'
import { cidForData, TID, valueToIpldBytes } from '@atproto/common'
import { CID } from 'multiformats/cid'
import { APP_BSKY_GRAPH } from '../src/lexicon'
import { Services } from '../src/services'
describe('duplicate record', () => {
let close: CloseFn
let did: string
let db: Database
let services: Services
beforeAll(async () => {
const server = await runTestServer({
dbPostgresSchema: 'duplicates',
})
db = server.db
services = server.services
close = server.close
did = 'did:example:alice'
})
@ -66,7 +69,7 @@ describe('duplicate record', () => {
}
const uri = AtUri.make(did, coll, TID.nextStr())
const cid = await putBlock(tx, repost)
await tx.indexRecord(uri, cid, repost)
await services.record(tx).indexRecord(uri, cid, repost)
uris.push(uri)
}
})
@ -75,14 +78,14 @@ describe('duplicate record', () => {
expect(count).toBe(1)
await db.transaction(async (tx) => {
await tx.deleteRecord(uris[0], false)
await services.record(tx).deleteRecord(uris[0], false)
})
count = await countRecords(db, 'repost')
expect(count).toBe(1)
await db.transaction(async (tx) => {
await tx.deleteRecord(uris[1], true)
await services.record(tx).deleteRecord(uris[1], true)
})
count = await countRecords(db, 'repost')
@ -106,7 +109,7 @@ describe('duplicate record', () => {
}
const uri = AtUri.make(did, coll, TID.nextStr())
const cid = await putBlock(tx, trend)
await tx.indexRecord(uri, cid, trend)
await services.record(tx).indexRecord(uri, cid, trend)
uris.push(uri)
}
})
@ -115,14 +118,14 @@ describe('duplicate record', () => {
expect(count).toBe(1)
await db.transaction(async (tx) => {
await tx.deleteRecord(uris[0], false)
await services.record(tx).deleteRecord(uris[0], false)
})
count = await countRecords(db, 'trend')
expect(count).toBe(1)
await db.transaction(async (tx) => {
await tx.deleteRecord(uris[1], true)
await services.record(tx).deleteRecord(uris[1], true)
})
count = await countRecords(db, 'trend')
@ -148,7 +151,7 @@ describe('duplicate record', () => {
}
const uri = AtUri.make(did, coll, TID.nextStr())
const cid = await putBlock(tx, vote)
await tx.indexRecord(uri, cid, vote)
await services.record(tx).indexRecord(uri, cid, vote)
uris.push(uri)
}
})
@ -157,7 +160,7 @@ describe('duplicate record', () => {
expect(count).toBe(1)
await db.transaction(async (tx) => {
await tx.deleteRecord(uris[0], false)
await services.record(tx).deleteRecord(uris[0], false)
})
count = await countRecords(db, 'vote')
@ -171,7 +174,7 @@ describe('duplicate record', () => {
expect(got?.direction === 'down')
await db.transaction(async (tx) => {
await tx.deleteRecord(uris[1], true)
await services.record(tx).deleteRecord(uris[1], true)
})
count = await countRecords(db, 'vote')
@ -194,7 +197,7 @@ describe('duplicate record', () => {
}
const uri = AtUri.make(did, coll, TID.nextStr())
const cid = await putBlock(tx, follow)
await tx.indexRecord(uri, cid, follow)
await services.record(tx).indexRecord(uri, cid, follow)
uris.push(uri)
}
})
@ -203,14 +206,14 @@ describe('duplicate record', () => {
expect(count).toBe(1)
await db.transaction(async (tx) => {
await tx.deleteRecord(uris[0], false)
await services.record(tx).deleteRecord(uris[0], false)
})
count = await countRecords(db, 'follow')
expect(count).toBe(1)
await db.transaction(async (tx) => {
await tx.deleteRecord(uris[1], true)
await services.record(tx).deleteRecord(uris[1], true)
})
count = await countRecords(db, 'follow')
@ -236,7 +239,7 @@ describe('duplicate record', () => {
}
const uri = AtUri.make(did, coll, TID.nextStr())
const cid = await putBlock(tx, assertion)
await tx.indexRecord(uri, cid, assertion)
await services.record(tx).indexRecord(uri, cid, assertion)
assertUris.push(uri)
assertCids.push(cid)
}
@ -261,7 +264,7 @@ describe('duplicate record', () => {
}
const uri = AtUri.make(did, coll, TID.nextStr())
const cid = await putBlock(tx, follow)
await tx.indexRecord(uri, cid, follow)
await services.record(tx).indexRecord(uri, cid, follow)
confirmUris.push(uri)
confirmCids.push(cid)
}
@ -279,7 +282,7 @@ describe('duplicate record', () => {
expect(count).toBe(1)
await db.transaction(async (tx) => {
await tx.deleteRecord(confirmUris[0], false)
await services.record(tx).deleteRecord(confirmUris[0], false)
})
count = await countRecords(db, 'assertion')
@ -288,7 +291,7 @@ describe('duplicate record', () => {
expect(assertion?.confirmUri).toEqual(confirmUris[1].toString())
await db.transaction(async (tx) => {
await tx.deleteRecord(confirmUris[1], true)
await services.record(tx).deleteRecord(confirmUris[1], true)
})
count = await countRecords(db, 'assertion')
@ -297,14 +300,14 @@ describe('duplicate record', () => {
expect(assertion?.confirmUri).toBeNull()
await db.transaction(async (tx) => {
await tx.deleteRecord(assertUris[0], false)
await services.record(tx).deleteRecord(assertUris[0], false)
})
count = await countRecords(db, 'assertion')
expect(count).toBe(1)
await db.transaction(async (tx) => {
await tx.deleteRecord(assertUris[1], false)
await services.record(tx).deleteRecord(assertUris[1], false)
})
count = await countRecords(db, 'assertion')

@ -1,16 +1,18 @@
import { AtUri } from '@atproto/uri'
import AtpApi, { ServiceClient as AtpServiceClient } from '@atproto/api'
import { Database } from '../src'
import { runTestServer, CloseFn } from './_util'
import { SeedClient } from './seeds/client'
import usersSeed from './seeds/users'
import scenesSeed from './seeds/scenes'
import { AtUri } from '@atproto/uri'
import { MessageQueue } from '../src/event-stream/types'
describe('db', () => {
let close: CloseFn
let client: AtpServiceClient
let sc: SeedClient
let db: Database
let messageQueue: MessageQueue
let uri1: AtUri, uri2: AtUri
@ -20,11 +22,12 @@ describe('db', () => {
})
close = server.close
db = server.db
messageQueue = server.messageQueue
client = AtpApi.service(server.url)
sc = new SeedClient(client)
await usersSeed(sc)
await scenesSeed(sc)
await db.messageQueue?.processAll()
await messageQueue.processAll()
const post1 = await sc.post(sc.dids.alice, 'test1')
const post2 = await sc.post(sc.dids.bob, 'test2')
@ -54,28 +57,28 @@ describe('db', () => {
})
it('handles vote increments', async () => {
await db.messageQueue?.send(db, {
await messageQueue.send(db, {
type: 'add_upvote',
user: sc.dids.alice,
subject: uri1.toString(),
})
await db.messageQueue?.send(db, {
await messageQueue.send(db, {
type: 'add_upvote',
user: sc.dids.bob,
subject: uri1.toString(),
})
await db.messageQueue?.send(db, {
await messageQueue.send(db, {
type: 'add_upvote',
user: sc.dids.carol,
subject: uri1.toString(),
})
await db.messageQueue?.send(db, {
await messageQueue.send(db, {
type: 'add_upvote',
user: sc.dids.alice,
subject: uri2.toString(),
})
await db.messageQueue?.processAll()
await messageQueue.processAll()
const res = await db.db
.selectFrom('scene_votes_on_post')

@ -3,9 +3,13 @@ import { Database } from '../../src'
import { cidForData, TID } from '@atproto/common'
import * as lex from '../../src/lexicon/lexicons'
import { APP_BSKY_GRAPH } from '../../src/lexicon'
import SqlMessageQueue from '../../src/event-stream/message-queue'
import { RecordService } from '../../src/services/record'
import { MessageQueue } from '../../src/event-stream/types'
describe('duplicate record', () => {
let db: Database
let messageQueue: MessageQueue
beforeAll(async () => {
if (process.env.DB_POSTGRES_URL) {
@ -16,10 +20,12 @@ describe('duplicate record', () => {
} else {
db = Database.memory()
}
messageQueue = new SqlMessageQueue('pds', db)
await db.migrator.migrateTo('_20221021T162202001Z')
})
afterAll(async () => {
await messageQueue.destroy()
await db.close()
})
@ -29,9 +35,10 @@ describe('duplicate record', () => {
const collection = record.$type
const cid = await cidForData(record)
await db.transaction(async (tx) => {
const recordTx = new RecordService(tx, messageQueue)
for (let i = 0; i < times; i++) {
const uri = AtUri.make(did, collection, TID.nextStr())
await tx.indexRecord(uri, cid, record)
await recordTx.indexRecord(uri, cid, record)
}
})
}

@ -1,4 +1,4 @@
import { MessageQueue } from '../../src/db/types'
import { MessageQueue } from '../../src/event-stream/types'
import { SeedClient } from './client'
import scenesSeed from './scenes'
import usersSeed from './users'

@ -21,7 +21,7 @@ describe('pds author feed views', () => {
close = server.close
client = AtpApi.service(server.url)
sc = new SeedClient(client)
await basicSeed(sc, server.db.messageQueue)
await basicSeed(sc, server.messageQueue)
alice = sc.dids.alice
bob = sc.dids.bob
carol = sc.dids.carol

@ -23,7 +23,7 @@ describe('pds notification views', () => {
app = server.app
client = AtpApi.service(server.url)
sc = new SeedClient(client)
await basicSeed(sc, server.db.messageQueue)
await basicSeed(sc, server.messageQueue)
alice = sc.dids.alice
})

@ -12,7 +12,6 @@ describe('pds profile views', () => {
// account dids, for convenience
let alice: string
let bob: string
let carol: string
let dan: string
let scene: string
@ -26,7 +25,6 @@ describe('pds profile views', () => {
await basicSeed(sc)
alice = sc.dids.alice
bob = sc.dids.bob
carol = sc.dids.carol
dan = sc.dids.dan
scene = sc.scenes['scene.test'].did
})

@ -1,11 +1,8 @@
import AtpApi, { ServiceClient as AtpServiceClient } from '@atproto/api'
import { runTestServer, forSnapshot, CloseFn, paginateAll } from '../_util'
import { runTestServer, CloseFn } from '../_util'
import { SeedClient } from '../seeds/client'
import usersBulkSeed from '../seeds/users-bulk'
import { App } from '../../src'
describe('pds user search views', () => {
let app: App
let client: AtpServiceClient
let close: CloseFn
let sc: SeedClient
@ -16,7 +13,6 @@ describe('pds user search views', () => {
dbPostgresSchema: 'views_suggestions',
})
close = server.close
app = server.app
client = AtpApi.service(server.url)
sc = new SeedClient(client)
const users = [
@ -72,6 +68,7 @@ describe('pds user search views', () => {
{ limit: 3 },
{ headers },
)
// @TODO test contents of result
})
it('paginates', async () => {

@ -29,7 +29,7 @@ describe('timeline views', () => {
close = server.close
client = AtpApi.service(server.url)
sc = new SeedClient(client)
await basicSeed(sc, server.db.messageQueue)
await basicSeed(sc, server.messageQueue)
alice = sc.dids.alice
bob = sc.dids.bob
carol = sc.dids.carol