atproto/packages/xrpc-server/tests/subscriptions.test.ts
Matthieu Sieben 7f26b17652
Add OAuth tests (#2874)
* Improve error message when using invalid client_id during code exchange

* Extract SPA example OAuth client in own package

* wip

* remove dependency on get-port

* Properly configure jest to only transpile "get-port" from node_modules

https://jestjs.io/docs/configuration#transformignorepatterns-arraystring

* Use dynamically assigned port number during tests

* use puppeteer to run tests

* remove login input "id" attribute

* code style

* add missing declaration

* tidy

* headless

* remove get-port dependency

* fix tests/proxied/admin.test.ts

* fix tests

* Allow unsecure oauth providers through configuration

* transpile "lande" during ozone tests

* Cache Puppeteer browser binaries

* Use puppeteer cache during all workflow steps

* remove use of set-output

* use get-port in xrpc-server tests

* Renamed to allowHttp

* tidy

* tidy
2024-10-18 15:40:05 +02:00

406 lines
11 KiB
TypeScript

import * as http from 'node:http'
import { AddressInfo } from 'node:net'
import { WebSocket, WebSocketServer, createWebSocketStream } from 'ws'
import getPort from 'get-port'
import { wait } from '@atproto/common'
import { LexiconDoc } from '@atproto/lexicon'
import { byFrame, MessageFrame, ErrorFrame, Frame, Subscription } from '../src'
import {
createServer,
closeServer,
createBasicAuth,
basicAuthHeaders,
} from './_util'
import * as xrpcServer from '../src'
/**
 * Lexicon documents registered with the xrpc server under test
 * (`xrpcServer.createServer(LEXICONS)` in the suite below).
 */
const LEXICONS: LexiconDoc[] = [
  // Subscription with a required integer `countdown` param; every message is
  // a plain object with a required integer `count`.
  {
    lexicon: 1,
    id: 'io.example.streamOne',
    defs: {
      main: {
        type: 'subscription',
        parameters: {
          type: 'params',
          required: ['countdown'],
          properties: {
            countdown: { type: 'integer' },
          },
        },
        message: {
          schema: {
            type: 'object',
            required: ['count'],
            properties: { count: { type: 'integer' } },
          },
        },
      },
    },
  },
  // Same params as streamOne, but the message schema is a union of the local
  // `#even` and `#odd` defs (both shaped `{ count: integer }`), used to check
  // how union members are tagged on the wire.
  {
    lexicon: 1,
    id: 'io.example.streamTwo',
    defs: {
      main: {
        type: 'subscription',
        parameters: {
          type: 'params',
          required: ['countdown'],
          properties: {
            countdown: { type: 'integer' },
          },
        },
        message: {
          schema: {
            type: 'union',
            refs: ['#even', '#odd'],
          },
        },
      },
      even: {
        type: 'object',
        required: ['count'],
        properties: { count: { type: 'integer' } },
      },
      odd: {
        type: 'object',
        required: ['count'],
        properties: { count: { type: 'integer' } },
      },
    },
  },
  // Subscription with no declared params or message schema; the suite pairs
  // it with a basic-auth verifier to test auth resolution and rejection.
  {
    lexicon: 1,
    id: 'io.example.streamAuth',
    defs: {
      main: {
        type: 'subscription',
      },
    },
  },
]
/**
 * End-to-end tests for XRPC subscriptions: the server side (`streamMethod`
 * handlers served over WebSockets) and the client-side `Subscription`
 * consumer, including skipping, reconnection, abort and heartbeat behavior.
 */
describe('Subscriptions', () => {
  let s: http.Server
  const server = xrpcServer.createServer(LEXICONS)
  const lex = server.lex

  // Counts down from `countdown` to 0, one `{ count }` message per step.
  server.streamMethod('io.example.streamOne', async function* ({ params }) {
    const countdown = Number(params.countdown ?? 0)
    for (let i = countdown; i >= 0; i--) {
      await wait(0)
      yield { count: i }
    }
  })

  // Same countdown, but each message is tagged as a union member: even counts
  // use the short local ref ('#even'), odd counts the fully-qualified ref.
  // A final message carries a $type from an unrelated NSID. Each message is
  // preceded by a 200ms delay.
  server.streamMethod('io.example.streamTwo', async function* ({ params }) {
    const countdown = Number(params.countdown ?? 0)
    for (let i = countdown; i >= 0; i--) {
      await wait(200)
      yield {
        $type: i % 2 === 0 ? '#even' : 'io.example.streamTwo#odd',
        count: i,
      }
    }
    yield {
      $type: 'io.example.otherNsid#done',
    }
  })

  // Yields whatever the basic-auth verifier resolved, unchanged.
  server.streamMethod('io.example.streamAuth', {
    auth: createBasicAuth({ username: 'admin', password: 'password' }),
    handler: async function* ({ auth }) {
      yield auth
    },
  })

  // Port the suite's HTTP server is bound to; assigned in beforeAll.
  let port: number

  beforeAll(async () => {
    s = await createServer(server)
    port = (s.address() as AddressInfo).port
  })

  afterAll(async () => {
    if (s) await closeServer(s)
  })

  it('streams messages', async () => {
    const ws = new WebSocket(
      `ws://localhost:${port}/xrpc/io.example.streamOne?countdown=5`,
    )
    const frames: Frame[] = []
    for await (const frame of byFrame(ws)) {
      frames.push(frame)
    }
    expect(frames).toEqual([
      new MessageFrame({ count: 5 }),
      new MessageFrame({ count: 4 }),
      new MessageFrame({ count: 3 }),
      new MessageFrame({ count: 2 }),
      new MessageFrame({ count: 1 }),
      new MessageFrame({ count: 0 }),
    ])
  })

  it('streams messages in a union', async () => {
    const ws = new WebSocket(
      `ws://localhost:${port}/xrpc/io.example.streamTwo?countdown=5`,
    )
    const frames: Frame[] = []
    for await (const frame of byFrame(ws)) {
      frames.push(frame)
    }
    // Both '#even' and 'io.example.streamTwo#odd' are expected to arrive in
    // the short '#…' form; the foreign NSID is kept fully qualified.
    expect(frames).toEqual([
      new MessageFrame({ count: 5 }, { type: '#odd' }),
      new MessageFrame({ count: 4 }, { type: '#even' }),
      new MessageFrame({ count: 3 }, { type: '#odd' }),
      new MessageFrame({ count: 2 }, { type: '#even' }),
      new MessageFrame({ count: 1 }, { type: '#odd' }),
      new MessageFrame({ count: 0 }, { type: '#even' }),
      new MessageFrame({}, { type: 'io.example.otherNsid#done' }),
    ])
  })

  it('resolves auth into handler', async () => {
    const ws = new WebSocket(
      `ws://localhost:${port}/xrpc/io.example.streamAuth`,
      {
        headers: basicAuthHeaders({
          username: 'admin',
          password: 'password',
        }),
      },
    )
    const frames: Frame[] = []
    for await (const frame of byFrame(ws)) {
      frames.push(frame)
    }
    expect(frames).toEqual([
      new MessageFrame({
        credentials: {
          username: 'admin',
        },
        artifacts: {
          original: 'YWRtaW46cGFzc3dvcmQ=',
        },
      }),
    ])
  })

  it('errors immediately on bad parameter', async () => {
    const ws = new WebSocket(`ws://localhost:${port}/xrpc/io.example.streamOne`)
    const frames: Frame[] = []
    for await (const frame of byFrame(ws)) {
      frames.push(frame)
    }
    expect(frames).toEqual([
      new ErrorFrame({
        error: 'InvalidRequest',
        message: 'Error: Params must have the property "countdown"',
      }),
    ])
  })

  it('errors immediately on bad auth', async () => {
    const ws = new WebSocket(
      `ws://localhost:${port}/xrpc/io.example.streamAuth`,
      {
        headers: basicAuthHeaders({
          username: 'bad',
          password: 'wrong',
        }),
      },
    )
    const frames: Frame[] = []
    for await (const frame of byFrame(ws)) {
      frames.push(frame)
    }
    expect(frames).toEqual([
      new ErrorFrame({
        error: 'AuthenticationRequired',
        message: 'Authentication Required',
      }),
    ])
  })

  it('does not websocket upgrade at bad endpoint', async () => {
    const ws = new WebSocket(`ws://localhost:${port}/xrpc/does.not.exist`)
    // Reading the raw stream surfaces the connection reset for an unknown method.
    const drainStream = async () => {
      for await (const bytes of createWebSocketStream(ws)) {
        void bytes // drain
      }
    }
    await expect(drainStream).rejects.toHaveProperty('code', 'ECONNRESET')
  })

  describe('Subscription consumer', () => {
    it('receives messages w/ skips', async () => {
      const sub = new Subscription({
        service: `ws://localhost:${port}`,
        method: 'io.example.streamOne',
        getParams: () => ({ countdown: 5 }),
        // Returning undefined from validate drops the message: only odd counts
        // and 0 are kept.
        validate: (obj) => {
          const result = lex.assertValidXrpcMessage<{ count: number }>(
            'io.example.streamOne',
            obj,
          )
          if (!result.count || result.count % 2) {
            return result
          }
        },
      })
      const messages: { count: number }[] = []
      for await (const msg of sub) {
        messages.push(msg)
      }
      expect(messages).toEqual([
        { count: 5 },
        { count: 3 },
        { count: 1 },
        { count: 0 },
      ])
    })

    it('reconnects w/ param update', async () => {
      let countdown = 10
      let reconnects = 0
      const sub = new Subscription({
        service: `ws://localhost:${port}`,
        method: 'io.example.streamOne',
        onReconnectError: () => reconnects++,
        // Read on every (re)connect, so a reconnect resumes from the last count.
        getParams: () => ({ countdown }),
        validate: (obj) => {
          return lex.assertValidXrpcMessage<{ count: number }>(
            'io.example.streamOne',
            obj,
          )
        },
      })
      let disconnected = false
      for await (const msg of sub) {
        expect(msg.count).toBeGreaterThanOrEqual(countdown - 1) // No skips
        countdown = Math.min(countdown, msg.count) // Only allow forward movement
        if (msg.count <= 6 && !disconnected) {
          disconnected = true
          // Forcefully drop every server-side connection once to trigger a reconnect.
          server.subscriptions.forEach(({ wss }) => {
            wss.clients.forEach((c) => c.terminate())
          })
        }
      }
      expect(countdown).toEqual(0)
      expect(reconnects).toBeGreaterThan(0)
    })

    it('aborts with signal', async () => {
      const abortController = new AbortController()
      const sub = new Subscription({
        service: `ws://localhost:${port}`,
        method: 'io.example.streamOne',
        signal: abortController.signal,
        getParams: () => ({ countdown: 10 }),
        validate: (obj) => {
          return lex.assertValidXrpcMessage<{ count: number }>(
            'io.example.streamOne',
            obj,
          )
        },
      })
      let error: unknown
      let disconnected = false
      const messages: { count: number }[] = []
      try {
        for await (const msg of sub) {
          messages.push(msg)
          if (msg.count <= 6 && !disconnected) {
            disconnected = true
            abortController.abort(new Error('Oops!'))
          }
        }
      } catch (err) {
        error = err
      }
      // The abort reason is expected to be rethrown by the iterator.
      expect(error).toEqual(new Error('Oops!'))
      expect(messages).toEqual([
        { count: 10 },
        { count: 9 },
        { count: 8 },
        { count: 7 },
        { count: 6 },
      ])
    })
  })

  it('uses a heartbeat to reconnect if a connection is dropped', async () => {
    // We run a bare WebSocket server. On the first connection it pauses for
    // longer than the heartbeat interval, so it never reads pings and never
    // answers with "pong"s. On the second connection it sends one message
    // frame and closes. Distinct names avoid shadowing the suite's `port`
    // and `server`.
    const wssPort = await getPort()
    const wss = new WebSocketServer({ port: wssPort })
    // Only connect once the server is actually accepting connections.
    await new Promise<void>((resolve, reject) => {
      wss.once('listening', () => resolve())
      wss.once('error', reject)
    })

    let firstConnection = true
    let firstWasClosed = false
    wss.on('connection', async (socket) => {
      if (firstConnection) {
        firstConnection = false
        // Registered up front so a close during the pause is not missed.
        socket.on('close', () => {
          firstWasClosed = true
        })
        socket.pause()
        await wait(600)
        // The client should have dropped this connection by now, so this
        // frame is not expected to be delivered. A send failure is tolerated
        // instead of being thrown as an uncaught exception from the callback.
        const frame = new ErrorFrame({
          error: 'AuthenticationRequired',
          message: 'Authentication Required',
        })
        socket.send(frame.toBytes(), { binary: true }, () => {
          socket.close(xrpcServer.CloseCode.Normal)
        })
      } else {
        const frame = new MessageFrame({ count: 1 })
        socket.send(frame.toBytes(), { binary: true }, (err) => {
          if (err) throw err
          socket.close(xrpcServer.CloseCode.Normal)
        })
      }
    })

    try {
      const subscription = new Subscription({
        service: `ws://localhost:${wssPort}`,
        method: '',
        heartbeatIntervalMs: 500,
        validate: (obj) => {
          return lex.assertValidXrpcMessage<{ count: number }>(
            'io.example.streamOne',
            obj,
          )
        },
      })
      const messages: { count: number }[] = []
      for await (const msg of subscription) {
        messages.push(msg)
      }
      expect(messages).toEqual([{ count: 1 }])
      expect(firstWasClosed).toBe(true)
    } finally {
      // Tear the ad-hoc server down even if an assertion above failed.
      // Lingering clients are terminated first because the underlying HTTP
      // server only finishes closing once its connections have ended.
      wss.clients.forEach((client) => client.terminate())
      await new Promise<void>((resolve) => wss.close(() => resolve()))
    }
  })
})