Skip to content

Commit 76fa487

Browse files
committed
Type-safe react stub streaming calls
1 parent 9e2f4e7 commit 76fa487

File tree

6 files changed

+192
-33
lines changed

6 files changed

+192
-33
lines changed

packages/agents/src/client.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,14 @@ export type AgentClientOptions<State = unknown> = Omit<
2727
/**
2828
* Options for streaming RPC calls
2929
*/
30-
export type StreamOptions = {
30+
export type StreamOptions<
31+
OnChunkT extends unknown | SerializableValue = unknown,
32+
OnDoneT extends unknown | SerializableValue = unknown,
33+
> = {
3134
/** Called when a chunk of data is received */
32-
onChunk?: (chunk: unknown) => void;
35+
onChunk?: (chunk: OnChunkT) => void;
3336
/** Called when the stream ends */
34-
onDone?: (finalChunk: unknown) => void;
37+
onDone?: (finalChunk: OnDoneT) => void;
3538
/** Called when an error occurs */
3639
onError?: (error: string) => void;
3740
};

packages/agents/src/index.ts

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import type { Client } from "@modelcontextprotocol/sdk/client/index.js";
2525
import type { SSEClientTransportOptions } from "@modelcontextprotocol/sdk/client/sse.js";
2626

2727
import { camelCaseToKebabCase } from "./client";
28-
import type { MCPClientConnection } from "./mcp/client-connection";
28+
import type { SerializableValue } from "./serializable";
2929

3030
export type { Connection, ConnectionContext, WSMessage } from "partyserver";
3131

@@ -1205,7 +1205,10 @@ export async function getAgentByName<Env, T extends Agent<Env>>(
12051205
/**
12061206
* A wrapper for streaming responses in callable methods
12071207
*/
1208-
export class StreamingResponse {
1208+
export class StreamingResponse<
1209+
OnChunkT extends SerializableValue | unknown = unknown,
1210+
OnDoneT extends SerializableValue | unknown = unknown,
1211+
> {
12091212
private _connection: Connection;
12101213
private _id: string;
12111214
private _closed = false;
@@ -1219,7 +1222,7 @@ export class StreamingResponse {
12191222
* Send a chunk of data to the client
12201223
* @param chunk The data to send
12211224
*/
1222-
send(chunk: unknown) {
1225+
send(chunk: OnChunkT) {
12231226
if (this._closed) {
12241227
throw new Error("StreamingResponse is already closed");
12251228
}
@@ -1237,7 +1240,7 @@ export class StreamingResponse {
12371240
* End the stream and send the final chunk (if any)
12381241
* @param finalChunk Optional final chunk of data to send
12391242
*/
1240-
end(finalChunk?: unknown) {
1243+
end(finalChunk?: OnDoneT) {
12411244
if (this._closed) {
12421245
throw new Error("StreamingResponse is already closed");
12431246
}

packages/agents/src/react.tsx

Lines changed: 91 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@ import { usePartySocket } from "partysocket/react";
33
import { useCallback, useRef } from "react";
44
import type { MCPServersState, RPCRequest, RPCResponse, Agent } from "./";
55
import type { StreamOptions } from "./client";
6-
import type { Method, RPCMethod } from "./serializable";
6+
import type {
7+
AllSerializableValues,
8+
SerializableReturnValue,
9+
SerializableValue,
10+
} from "./serializable";
711

812
/**
913
* Convert a camelCase string to a kebab-case string
@@ -44,17 +48,68 @@ export type UseAgentOptions<State = unknown> = Omit<
4448
onMcpUpdate?: (mcpServers: MCPServersState) => void;
4549
};
4650

51+
// biome-ignore lint: suppressions/parse
52+
type Method = (...args: any[]) => any;
53+
54+
type NonStreamingRPCMethod<T extends Method> =
55+
AllSerializableValues<Parameters<T>> extends true
56+
? ReturnType<T> extends SerializableReturnValue
57+
? T
58+
: never
59+
: never;
60+
61+
interface StreamingResponse<
62+
OnChunkT extends SerializableValue | unknown = unknown,
63+
OnDoneT extends SerializableValue | unknown = unknown,
64+
> {
65+
send(chunk: OnChunkT): void;
66+
end(finalChunk?: OnDoneT): void;
67+
}
68+
69+
type StreamingRPCMethod<T extends Method> = T extends (
70+
arg: infer A,
71+
...rest: infer R
72+
) => void | Promise<void>
73+
? A extends StreamingResponse<SerializableValue, SerializableValue>
74+
? AllSerializableValues<R> extends true
75+
? T
76+
: never
77+
: never
78+
: never;
79+
80+
type RPCMethod<T extends Method> =
81+
T extends NonStreamingRPCMethod<T>
82+
? NonStreamingRPCMethod<T>
83+
: T extends StreamingRPCMethod<T>
84+
? StreamingRPCMethod<T>
85+
: never;
86+
87+
type RPCMethods<T> = {
88+
[K in keyof T as T[K] extends Method ? K : never]: T[K] extends Method
89+
? RPCMethod<T[K]>
90+
: never;
91+
};
92+
4793
type AllOptional<T> = T extends [infer A, ...infer R]
4894
? undefined extends A
4995
? AllOptional<R>
5096
: false
5197
: true; // no params means optional by default
5298

53-
type RPCMethods<T> = {
54-
[K in keyof T as T[K] extends RPCMethod<T[K]> ? K : never]: RPCMethod<T[K]>;
55-
};
99+
type StreamOptionsFrom<StreamingResponseT> =
100+
StreamingResponseT extends StreamingResponse<
101+
infer T extends SerializableValue,
102+
infer U extends SerializableValue
103+
>
104+
? StreamOptions<T, U>
105+
: never;
56106

57-
type OptionalParametersMethod<T extends RPCMethod> =
107+
type RestParameters<T extends Method> =
108+
Parameters<StreamingRPCMethod<T>> extends [unknown, ...infer Rest]
109+
? Rest
110+
: never;
111+
112+
type OptionalParametersMethod<T extends Method> =
58113
AllOptional<Parameters<T>> extends true ? T : never;
59114

60115
// all methods of the Agent, excluding the ones that are declared in the base Agent class
@@ -74,6 +129,14 @@ type RequiredAgentMethods<T> = Omit<
74129
keyof OptionalAgentMethods<T>
75130
>;
76131

132+
type StreamingAgentMethods<T> = {
133+
[K in keyof AgentMethods<T> as AgentMethods<T>[K] extends StreamingRPCMethod<
134+
AgentMethods<T>[K]
135+
>
136+
? K
137+
: never]: StreamingRPCMethod<AgentMethods<T>[K]>;
138+
};
139+
77140
type AgentPromiseReturnType<T, K extends keyof AgentMethods<T>> =
78141
// biome-ignore lint: suppressions/parse
79142
ReturnType<AgentMethods<T>[K]> extends Promise<any>
@@ -96,7 +159,18 @@ type RequiredArgsAgentMethodCall<AgentT> = <
96159
streamOptions?: StreamOptions
97160
) => AgentPromiseReturnType<AgentT, K>;
98161

99-
type AgentMethodCall<AgentT> = OptionalArgsAgentMethodCall<AgentT> &
162+
type StreamingAgentMethodCall<AgentT> = <
163+
K extends keyof StreamingAgentMethods<AgentT>,
164+
>(
165+
method: K,
166+
args: RestParameters<StreamingAgentMethods<AgentT>[K]>,
167+
streamOptions: StreamOptionsFrom<
168+
Parameters<StreamingAgentMethods<AgentT>[K]>[0]
169+
>
170+
) => void;
171+
172+
type AgentMethodCall<AgentT> = StreamingAgentMethodCall<AgentT> &
173+
OptionalArgsAgentMethodCall<AgentT> &
100174
RequiredArgsAgentMethodCall<AgentT>;
101175

102176
type UntypedAgentMethodCall = <T = unknown>(
@@ -106,9 +180,16 @@ type UntypedAgentMethodCall = <T = unknown>(
106180
) => Promise<T>;
107181

108182
type AgentStub<T> = {
109-
[K in keyof AgentMethods<T>]: (
110-
...args: Parameters<AgentMethods<T>[K]>
111-
) => AgentPromiseReturnType<AgentMethods<T>, K>;
183+
[K in keyof AgentMethods<T>]: AgentMethods<T>[K] extends StreamingRPCMethod<
184+
AgentMethods<T>[K]
185+
>
186+
? (
187+
options: StreamOptionsFrom<Parameters<AgentMethods<T>[K]>[0]>,
188+
...args: RestParameters<AgentMethods<T>[K]>
189+
) => void
190+
: (
191+
...args: Parameters<AgentMethods<T>[K]>
192+
) => AgentPromiseReturnType<AgentMethods<T>, K>;
112193
};
113194

114195
// we neet to use Method instead of RPCMethod here for retro-compatibility
@@ -150,7 +231,7 @@ export function useAgent<State>(
150231
agent: string;
151232
name: string;
152233
setState: (state: State) => void;
153-
call: UntypedAgentMethodCall | AgentMethodCall<unknown>;
234+
call: UntypedAgentMethodCall;
154235
stub: UntypedAgentStub;
155236
} {
156237
const agentNamespace = camelCaseToKebabCase(options.agent);

packages/agents/src/serializable.ts

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,8 @@ export type SerializableReturnValue =
1313
| Promise<SerializableValue>
1414
| Promise<void>;
1515

16-
type AllSerializableValues<A> = A extends [infer First, ...infer Rest]
16+
export type AllSerializableValues<A> = A extends [infer First, ...infer Rest]
1717
? First extends SerializableValue
1818
? AllSerializableValues<Rest>
1919
: false
2020
: true; // no params means serializable by default
21-
22-
// biome-ignore lint: suspicious/noExplicitAny
23-
export type Method = (...args: any[]) => any;
24-
25-
export type RPCMethod<T = Method> = T extends Method
26-
? T extends (...arg: infer A) => infer R
27-
? AllSerializableValues<A> extends true
28-
? R extends SerializableReturnValue
29-
? T
30-
: never
31-
: never
32-
: never
33-
: never;

packages/agents/src/tests-d/example-stub.test-d.ts

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import type { env } from "cloudflare:workers";
2-
import { unstable_callable as callable, Agent } from "..";
2+
import {
3+
unstable_callable as callable,
4+
Agent,
5+
type StreamingResponse,
6+
} from "..";
37
import { useAgent } from "../react.tsx";
8+
import type { StreamOptions } from "../client.ts";
49

510
class MyAgent extends Agent<typeof env, {}> {
611
@callable()
@@ -17,6 +22,42 @@ class MyAgent extends Agent<typeof env, {}> {
1722
nonRpc(): void {
1823
// do something
1924
}
25+
26+
@callable({ streaming: true })
27+
performStream(
28+
options: StreamingResponse<number, boolean>,
29+
other: string
30+
): void {
31+
// do something
32+
}
33+
34+
// TODO should fail, first argument is not a streamOptions
35+
@callable({ streaming: true })
36+
performStreamFirstArgNotStreamOptions(
37+
other: string,
38+
options: StreamingResponse<number, boolean>
39+
): void {
40+
// do something
41+
}
42+
43+
// TODO should fail, should be marked as streaming
44+
@callable()
45+
performStreamFail(options: StreamingResponse): void {
46+
// do something
47+
}
48+
49+
// TODO should fail, has no streamOptions
50+
@callable({ streaming: true })
51+
async performFail(task: string): Promise<string> {
52+
// do something
53+
return "";
54+
}
55+
56+
@callable({ streaming: true })
57+
performStreamUnserializable(options: StreamingResponse<Date>): void {
58+
// @ts-expect-error parameter is not serializable
59+
options.onDone(new Date());
60+
}
2061
}
2162

2263
const { stub } = useAgent<MyAgent, {}>({ agent: "my-agent" });
@@ -38,9 +79,26 @@ await stub.nonRpc();
3879
// @ts-expect-error nonSerializable is not serializable
3980
await stub.nonSerializable("hello", new Date());
4081

82+
const streamOptions: StreamOptions<number, boolean> = {};
83+
84+
// biome-ignore lint: suspicious/noConfusingVoidType
85+
stub.performStream(streamOptions, "hello") satisfies void;
86+
87+
// @ts-expect-error there's no 2nd argument
88+
stub.performStream(streamOptions, "hello", 1);
89+
90+
const invalidStreamOptions: StreamOptions<string, boolean> = {};
91+
92+
// @ts-expect-error streamOptions must be of type StreamOptions<number, boolean>
93+
stub.performStream(invalidStreamOptions, "hello");
94+
95+
// @ts-expect-error first argument is not a streamOptions
96+
stub.performStreamFirstArgNotStreamOptions("hello", streamOptions);
97+
4198
const { stub: stub2 } = useAgent<Omit<MyAgent, "nonRpc">, {}>({
4299
agent: "my-agent",
43100
});
101+
44102
stub2.sayHello();
45103
// @ts-expect-error nonRpc excluded from useAgent
46104
stub2.nonRpc();

packages/agents/src/tests-d/example.test-d.ts

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import type { env } from "cloudflare:workers";
2-
import { unstable_callable as callable, Agent } from "..";
2+
import {
3+
unstable_callable as callable,
4+
Agent,
5+
type StreamingResponse,
6+
} from "..";
37
import { useAgent } from "../react.tsx";
8+
import type { StreamOptions } from "../client.ts";
49

510
class MyAgent extends Agent<typeof env, {}> {
611
@callable()
@@ -17,6 +22,16 @@ class MyAgent extends Agent<typeof env, {}> {
1722
nonRpc(): void {
1823
// do something
1924
}
25+
26+
@callable({ streaming: true })
27+
performStream(
28+
response: StreamingResponse<number, boolean>,
29+
other: string
30+
): void {
31+
response.send(1);
32+
response.send(2);
33+
response.end(true);
34+
}
2035
}
2136

2237
const agent = useAgent<MyAgent, {}>({ agent: "my-agent" });
@@ -38,6 +53,18 @@ await agent.call("nonRpc");
3853
// @ts-expect-error nonSerializable is not serializable
3954
await agent.call("nonSerializable", ["hello", new Date()]);
4055

56+
const streamOptions: StreamOptions<number, boolean> = {};
57+
58+
agent.call("performStream", ["hello"], streamOptions);
59+
60+
// @ts-expect-error there's no second parameter
61+
agent.call("performStream", ["a", 1], streamOptions);
62+
63+
const invalidStreamOptions: StreamOptions<string, boolean> = {};
64+
65+
// @ts-expect-error streamOptions must be of type StreamOptions<number, boolean>
66+
agent.call("performStream", ["a", 1], invalidStreamOptions);
67+
4168
const agent2 = useAgent<Omit<MyAgent, "nonRpc">, {}>({ agent: "my-agent" });
4269
agent2.call("sayHello");
4370
// @ts-expect-error nonRpc excluded from useAgent

0 commit comments

Comments
 (0)