Skip to content
This repository was archived by the owner on May 17, 2025. It is now read-only.

Commit 3983767

Browse files
authored
feat: provide the connectionId in the context at all times (#28)
- Bonus eslint updates - bonus refactor out some methods - remove all type warnings and errors - Fix spelling on some types BREAKING CHANGE: Changed the SubscribePseudoIterable name to correct it's spelling
1 parent 65b3973 commit 3983767

27 files changed

+394
-395
lines changed

.eslintrc.js

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,24 @@ module.exports = {
1616
'@typescript-eslint',
1717
],
1818
rules: {
19-
indent: [ 'error', 2 ],
20-
'linebreak-style': [ 'error', 'unix' ],
21-
quotes: [ 'error', 'single' ],
22-
semi: 'off',
23-
'@typescript-eslint/semi': ['error', 'never'],
24-
'quote-props': ['error', 'as-needed'],
25-
'no-param-reassign': 'error',
26-
'comma-dangle': ['error', 'always-multiline'],
27-
'space-infix-ops': ['error'],
28-
'no-multi-spaces': ['error'],
29-
'no-unused-vars': 'off',
30-
'@typescript-eslint/no-unused-vars': ['error', { argsIgnorePattern: '^_' }],
3119
'@typescript-eslint/member-delimiter-style': ['error', {
3220
multiline: { delimiter: 'none' },
3321
singleline: { delimiter: 'comma', requireLast: false },
3422
multilineDetection: 'last-member',
3523
}],
24+
'@typescript-eslint/no-unused-vars': ['error', { argsIgnorePattern: '^_' }],
25+
'@typescript-eslint/semi': ['error', 'never'],
26+
'array-bracket-spacing': ['error', 'never', { singleValue: false }],
27+
'comma-dangle': ['error', 'always-multiline'],
28+
'linebreak-style': ['error', 'unix'],
29+
'no-multi-spaces': ['error'],
30+
'no-param-reassign': 'error',
31+
'no-unused-vars': 'off',
32+
'object-curly-spacing': ['error', 'always'],
33+
'quote-props': ['error', 'as-needed'],
34+
'space-infix-ops': ['error'],
35+
indent: ['error', 2],
36+
quotes: ['error', 'single'],
37+
semi: 'off',
3638
},
3739
}

lib/gateway.ts

Lines changed: 62 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -12,81 +12,79 @@ import { subscribe } from './messages/subscribe'
1212
import { connection_init } from './messages/connection_init'
1313
import { pong } from './messages/pong'
1414

15-
export const handleGatewayEvent =
16-
(server: ServerClosure): ApiGatewayHandler<APIGatewayWebSocketEvent, WebsocketResponse> =>
17-
async (event) => {
18-
if (!event.requestContext) {
19-
return {
20-
statusCode: 200,
21-
body: '',
22-
}
23-
}
24-
25-
if (event.requestContext.eventType === 'CONNECT') {
26-
await server.onConnect?.({ event })
27-
return {
28-
statusCode: 200,
29-
headers: {
30-
'Sec-WebSocket-Protocol': GRAPHQL_TRANSPORT_WS_PROTOCOL,
31-
},
32-
body: '',
33-
}
34-
}
35-
36-
if (event.requestContext.eventType === 'MESSAGE') {
37-
const message = JSON.parse(event.body!)
15+
export const handleGatewayEvent = (server: ServerClosure): ApiGatewayHandler<APIGatewayWebSocketEvent, WebsocketResponse> => async (event) => {
16+
if (!event.requestContext) {
17+
return {
18+
statusCode: 200,
19+
body: '',
20+
}
21+
}
3822

39-
if (message.type === MessageType.ConnectionInit) {
40-
await connection_init({ server, event, message })
41-
return {
42-
statusCode: 200,
43-
body: '',
44-
}
45-
}
23+
if (event.requestContext.eventType === 'CONNECT') {
24+
await server.onConnect?.({ event })
25+
return {
26+
statusCode: 200,
27+
headers: {
28+
'Sec-WebSocket-Protocol': GRAPHQL_TRANSPORT_WS_PROTOCOL,
29+
},
30+
body: '',
31+
}
32+
}
4633

47-
if (message.type === MessageType.Subscribe) {
48-
await subscribe({ server, event, message })
49-
return {
50-
statusCode: 200,
51-
body: '',
52-
}
53-
}
34+
if (event.requestContext.eventType === 'MESSAGE') {
35+
const message = event.body === null ? null : JSON.parse(event.body)
5436

55-
if (message.type === MessageType.Complete) {
56-
await complete({ server, event, message })
57-
return {
58-
statusCode: 200,
59-
body: '',
60-
}
61-
}
37+
if (message.type === MessageType.ConnectionInit) {
38+
await connection_init({ server, event, message })
39+
return {
40+
statusCode: 200,
41+
body: '',
42+
}
43+
}
6244

63-
if (message.type === MessageType.Ping) {
64-
await ping({ server, event, message })
65-
return {
66-
statusCode: 200,
67-
body: '',
68-
}
69-
}
45+
if (message.type === MessageType.Subscribe) {
46+
await subscribe({ server, event, message })
47+
return {
48+
statusCode: 200,
49+
body: '',
50+
}
51+
}
7052

71-
if (message.type === MessageType.Pong) {
72-
await pong({ server, event, message })
73-
return {
74-
statusCode: 200,
75-
body: '',
76-
}
77-
}
53+
if (message.type === MessageType.Complete) {
54+
await complete({ server, event, message })
55+
return {
56+
statusCode: 200,
57+
body: '',
7858
}
59+
}
7960

80-
if (event.requestContext.eventType === 'DISCONNECT') {
81-
await disconnect({ server, event, message: null })
82-
return {
83-
statusCode: 200,
84-
body: '',
85-
}
61+
if (message.type === MessageType.Ping) {
62+
await ping({ server, event, message })
63+
return {
64+
statusCode: 200,
65+
body: '',
8666
}
67+
}
8768

69+
if (message.type === MessageType.Pong) {
70+
await pong({ server, event, message })
8871
return {
8972
statusCode: 200,
9073
body: '',
9174
}
9275
}
76+
}
77+
78+
if (event.requestContext.eventType === 'DISCONNECT') {
79+
await disconnect({ server, event, message: null })
80+
return {
81+
statusCode: 200,
82+
body: '',
83+
}
84+
}
85+
86+
return {
87+
statusCode: 200,
88+
body: '',
89+
}
90+
}

lib/messages/complete.ts

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,53 +2,47 @@ import AggregateError from 'aggregate-error'
22
import { parse } from 'graphql'
33
import { CompleteMessage } from 'graphql-ws'
44
import { buildExecutionContext } from 'graphql/execution/execute'
5-
import { SubscribePsuedoIterable, MessageHandler } from '../types'
6-
import { deleteConnection } from '../utils/aws'
7-
import { constructContext, getResolverAndArgs } from '../utils/graphql'
5+
import { collect } from 'streaming-iterables'
6+
import { SubscribePseudoIterable, MessageHandler } from '../types'
7+
import { deleteConnection } from '../utils/deleteConnection'
8+
import { constructContext } from '../utils/constructContext'
9+
import { getResolverAndArgs } from '../utils/getResolverAndArgs'
810
import { isArray } from '../utils/isArray'
911

1012
/** Handler function for 'complete' message. */
1113
export const complete: MessageHandler<CompleteMessage> =
1214
async ({ server, event, message }) => {
1315
try {
14-
const topicSubscriptions = server.mapper.query(server.model.Subscription, {
15-
id: `${event.requestContext.connectionId!}|${message.id}`,
16-
})
17-
let deletions = [] as Promise<any>[]
18-
for await (const entity of topicSubscriptions) {
19-
deletions = [
20-
...deletions,
21-
(async () => {
22-
// only call onComplete per subscription
23-
if (deletions.length === 0) {
24-
const execContext = buildExecutionContext(
25-
server.schema,
26-
parse(entity.subscription.query),
27-
undefined,
28-
await constructContext(server)(entity),
29-
entity.subscription.variables,
30-
entity.subscription.operationName,
31-
undefined,
32-
)
33-
34-
if (isArray(execContext)) {
35-
throw new AggregateError(execContext)
36-
}
16+
const topicSubscriptions = await collect(server.mapper.query(server.model.Subscription, {
17+
id: `${event.requestContext.connectionId}|${message.id}`,
18+
}))
19+
if (topicSubscriptions.length === 0) {
20+
return
21+
}
22+
// only call onComplete on the first one as any others are duplicates
23+
const sub = topicSubscriptions[0]
24+
const execContext = buildExecutionContext(
25+
server.schema,
26+
parse(sub.subscription.query),
27+
undefined,
28+
await constructContext({ server, connectionParams: sub.connectionParams, connectionId: sub.connectionId }),
29+
sub.subscription.variables,
30+
sub.subscription.operationName,
31+
undefined,
32+
)
3733

38-
const [field, root, args, context, info] = getResolverAndArgs(server)(execContext)
34+
if (isArray(execContext)) {
35+
throw new AggregateError(execContext)
36+
}
3937

40-
const onComplete = (field?.subscribe as SubscribePsuedoIterable)?.onComplete
41-
if (onComplete) {
42-
await onComplete(root, args, context, info)
43-
}
44-
}
38+
const [field, root, args, context, info] = getResolverAndArgs(server)(execContext)
4539

46-
await server.mapper.delete(entity)
47-
})(),
48-
]
40+
const onComplete = (field?.subscribe as SubscribePseudoIterable)?.onComplete
41+
if (onComplete) {
42+
await onComplete(root, args, context, info)
4943
}
5044

51-
await Promise.all(deletions)
45+
await Promise.all(topicSubscriptions.map(sub => server.mapper.delete(sub)))
5246
} catch (err) {
5347
await server.onError?.(err, { event, message })
5448
await deleteConnection(server)(event.requestContext)

lib/messages/connection_init.ts

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import { StepFunctions } from 'aws-sdk'
22
import { ConnectionInitMessage, MessageType } from 'graphql-ws'
33
import { StateFunctionInput, MessageHandler } from '../types'
4-
import { deleteConnection, sendMessage } from '../utils/aws'
4+
import { sendMessage } from '../utils/sendMessage'
5+
import { deleteConnection } from '../utils/deleteConnection'
56

67
/** Handler function for 'connection_init' message. */
78
export const connection_init: MessageHandler<ConnectionInitMessage> =
@@ -15,10 +16,10 @@ export const connection_init: MessageHandler<ConnectionInitMessage> =
1516
await new StepFunctions()
1617
.startExecution({
1718
stateMachineArn: server.pingpong.machine,
18-
name: event.requestContext.connectionId!,
19+
name: event.requestContext.connectionId,
1920
input: JSON.stringify({
20-
connectionId: event.requestContext.connectionId!,
21-
domainName: event.requestContext.domainName!,
21+
connectionId: event.requestContext.connectionId,
22+
domainName: event.requestContext.domainName,
2223
stage: event.requestContext.stage,
2324
state: 'PING',
2425
choice: 'WAIT',
@@ -30,7 +31,7 @@ export const connection_init: MessageHandler<ConnectionInitMessage> =
3031

3132
// Write to persistence
3233
const connection = Object.assign(new server.model.Connection(), {
33-
id: event.requestContext.connectionId!,
34+
id: event.requestContext.connectionId,
3435
requestContext: event.requestContext,
3536
payload: res,
3637
})

lib/messages/disconnect.ts

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,40 +2,43 @@ import AggregateError from 'aggregate-error'
22
import { parse } from 'graphql'
33
import { equals } from '@aws/dynamodb-expressions'
44
import { buildExecutionContext } from 'graphql/execution/execute'
5-
import { constructContext, getResolverAndArgs } from '../utils/graphql'
6-
import { SubscribePsuedoIterable, MessageHandler } from '../types'
5+
import { constructContext } from '../utils/constructContext'
6+
import { getResolverAndArgs } from '../utils/getResolverAndArgs'
7+
import { SubscribePseudoIterable, MessageHandler } from '../types'
78
import { isArray } from '../utils/isArray'
9+
import { collect } from 'streaming-iterables'
10+
import { Connection } from '../model/Connection'
811

912
/** Handler function for 'disconnect' message. */
1013
export const disconnect: MessageHandler<null> =
1114
async ({ server, event }) => {
1215
try {
1316
await server.onDisconnect?.({ event })
1417

15-
const entities = server.mapper.query(
18+
const topicSubscriptions = await collect(server.mapper.query(
1619
server.model.Subscription,
1720
{
1821
connectionId: equals(event.requestContext.connectionId),
1922
},
2023
{ indexName: 'ConnectionIndex' },
21-
)
24+
))
2225

2326
const completed = {} as Record<string, boolean>
24-
const deletions = [] as Promise<any>[]
25-
for await (const entity of entities) {
27+
const deletions = [] as Promise<void|Connection>[]
28+
for (const sub of topicSubscriptions) {
2629
deletions.push(
2730
(async () => {
2831
// only call onComplete per subscription
29-
if (!completed[entity.subscriptionId]) {
30-
completed[entity.subscriptionId] = true
32+
if (!completed[sub.subscriptionId]) {
33+
completed[sub.subscriptionId] = true
3134

3235
const execContext = buildExecutionContext(
3336
server.schema,
34-
parse(entity.subscription.query),
37+
parse(sub.subscription.query),
3538
undefined,
36-
await constructContext(server)(entity),
37-
entity.subscription.variables,
38-
entity.subscription.operationName,
39+
await constructContext({ server, connectionParams: sub.connectionParams, connectionId: sub.connectionId }),
40+
sub.subscription.variables,
41+
sub.subscription.operationName,
3942
undefined,
4043
)
4144

@@ -46,13 +49,13 @@ export const disconnect: MessageHandler<null> =
4649

4750
const [field, root, args, context, info] = getResolverAndArgs(server)(execContext)
4851

49-
const onComplete = (field?.subscribe as SubscribePsuedoIterable)?.onComplete
52+
const onComplete = (field?.subscribe as SubscribePseudoIterable)?.onComplete
5053
if (onComplete) {
5154
await onComplete(root, args, context, info)
5255
}
5356
}
5457

55-
await server.mapper.delete(entity)
58+
await server.mapper.delete(sub)
5659
})(),
5760
)
5861
}
@@ -63,7 +66,7 @@ export const disconnect: MessageHandler<null> =
6366
// Delete connection
6467
server.mapper.delete(
6568
Object.assign(new server.model.Connection(), {
66-
id: event.requestContext.connectionId!,
69+
id: event.requestContext.connectionId,
6770
}),
6871
),
6972
])

0 commit comments

Comments
 (0)