1- use core:: fmt;
2-
31use crate :: {
42 pubsub:: { In , JsonSink , Listener , Out } ,
53 types:: InboundData ,
4+ HandlerCtx , TaskSet ,
65} ;
6+ use core:: fmt;
77use serde_json:: value:: RawValue ;
8- use tokio:: {
9- select,
10- sync:: { mpsc, oneshot, watch} ,
11- task:: JoinHandle ,
12- } ;
8+ use tokio:: { pin, select, sync:: mpsc, task:: JoinHandle } ;
139use tokio_stream:: StreamExt ;
10+ use tokio_util:: sync:: WaitForCancellationFutureOwned ;
1411use tracing:: { debug, debug_span, error, instrument, trace, Instrument } ;
1512
1613/// Default notification buffer size per task.
@@ -19,18 +16,6 @@ pub const DEFAULT_NOTIFICATION_BUFFER_PER_CLIENT: usize = 16;
1916/// Type alias for identifying connections.
2017pub type ConnectionId = u64 ;
2118
22- /// Holds the shutdown signal for some server.
23- #[ derive( Debug ) ]
24- pub struct ServerShutdown {
25- pub ( crate ) _shutdown : watch:: Sender < ( ) > ,
26- }
27-
28- impl From < watch:: Sender < ( ) > > for ServerShutdown {
29- fn from ( sender : watch:: Sender < ( ) > ) -> Self {
30- Self { _shutdown : sender }
31- }
32- }
33-
3419/// The `ListenerTask` listens for new connections, and spawns `RouteTask`s for
3520/// each.
3621pub ( crate ) struct ListenerTask < T : Listener > {
@@ -67,16 +52,17 @@ where
6752 }
6853
6954 /// Spawn the future produced by [`Self::task_future`].
70- pub ( crate ) fn spawn ( self ) -> JoinHandle < ( ) > {
55+ pub ( crate ) fn spawn ( self ) -> JoinHandle < Option < ( ) > > {
56+ let tasks = self . manager . root_tasks . clone ( ) ;
7157 let future = self . task_future ( ) ;
72- tokio :: spawn ( future)
58+ tasks . spawn_cancellable ( future)
7359 }
7460}
7561
7662/// The `ConnectionManager` provides connections with IDs, and handles spawning
7763/// the [`RouteTask`] for each connection.
7864pub ( crate ) struct ConnectionManager {
79- pub ( crate ) shutdown : watch :: Receiver < ( ) > ,
65+ pub ( crate ) root_tasks : TaskSet ,
8066
8167 pub ( crate ) next_id : ConnectionId ,
8268
@@ -107,19 +93,18 @@ impl ConnectionManager {
10793 ) -> ( RouteTask < T > , WriteTask < T > ) {
10894 let ( tx, rx) = mpsc:: channel ( self . notification_buffer_per_task ) ;
10995
110- let ( gone_tx , gone_rx ) = oneshot :: channel ( ) ;
96+ let tasks = self . root_tasks . child ( ) ;
11197
11298 let rt = RouteTask {
11399 router : self . router ( ) ,
114100 conn_id,
115101 write_task : tx,
116102 requests,
117- gone : gone_tx ,
103+ tasks : tasks . clone ( ) ,
118104 } ;
119105
120106 let wt = WriteTask {
121- shutdown : self . shutdown . clone ( ) ,
122- gone : gone_rx,
107+ tasks,
123108 conn_id,
124109 json : rx,
125110 connection,
@@ -156,8 +141,8 @@ struct RouteTask<T: crate::pubsub::Listener> {
156141 pub ( crate ) write_task : mpsc:: Sender < Box < RawValue > > ,
157142 /// Stream of requests.
158143 pub ( crate ) requests : In < T > ,
159- /// Sender to the [`WriteTask`], to notify it that this task is done.
160- pub ( crate ) gone : oneshot :: Sender < ( ) > ,
144+ /// The task set for this connection
145+ pub ( crate ) tasks : TaskSet ,
161146}
162147
163148impl < T : crate :: pubsub:: Listener > fmt:: Debug for RouteTask < T > {
@@ -179,18 +164,27 @@ where
179164 /// to handle the request, and given a sender to the [`WriteTask`]. This
180165 /// ensures that requests can be handled concurrently.
181166 #[ instrument( name = "RouteTask" , skip( self ) , fields( conn_id = self . conn_id) ) ]
182- pub async fn task_future ( self ) {
167+ pub async fn task_future ( self , cancel : WaitForCancellationFutureOwned ) {
183168 let RouteTask {
184169 router,
185170 mut requests,
186171 write_task,
187- gone ,
172+ tasks ,
188173 ..
189174 } = self ;
190175
176+ // The write task is responsible for waiting for its children
177+ let children = tasks. child ( ) ;
178+
179+ pin ! ( cancel) ;
180+
191181 loop {
192182 select ! {
193183 biased;
184+ _ = & mut cancel => {
185+ debug!( "RouteTask cancelled" ) ;
186+ break ;
187+ }
194188 _ = write_task. closed( ) => {
195189 debug!( "WriteTask has gone away" ) ;
196190 break ;
@@ -208,7 +202,11 @@ where
208202
209203 let span = debug_span!( "pubsub request handling" , reqs = reqs. len( ) ) ;
210204
211- let ctx = write_task. clone( ) . into( ) ;
205+ let ctx =
206+ HandlerCtx :: new(
207+ Some ( write_task. clone( ) ) ,
208+ children. clone( ) ,
209+ ) ;
212210
213211 let fut = router. handle_request_batch( ctx, reqs) ;
214212 let write_task = write_task. clone( ) ;
@@ -223,7 +221,7 @@ where
223221 } ;
224222
225223 // Run the future in a new task.
226- tokio :: spawn (
224+ children . spawn_cancellable (
227225 async move {
228226 // Send the response to the write task.
229227 // we don't care if the receiver has gone away,
@@ -239,27 +237,23 @@ where
239237 }
240238 }
241239 }
242- // No funny business. Drop the gone signal.
243- drop ( gone) ;
240+ children. shutdown ( ) . await ;
244241 }
245242
246243 /// Spawn the future produced by [`Self::task_future`].
247244 pub ( crate ) fn spawn ( self ) -> tokio:: task:: JoinHandle < ( ) > {
248- let future = self . task_future ( ) ;
249- tokio:: spawn ( future)
245+ let tasks = self . tasks . clone ( ) ;
246+
247+ let future = move |cancel| self . task_future ( cancel) ;
248+
249+ tasks. spawn_graceful ( future)
250250 }
251251}
252252
253253/// The Write Task is responsible for writing JSON to the outbound connection.
254254struct WriteTask < T : Listener > {
255- /// Shutdown signal.
256- ///
257- /// Shutdowns bubble back up to [`RouteTask`] when the write task is
258- /// dropped, via the closed `json` channel.
259- pub ( crate ) shutdown : watch:: Receiver < ( ) > ,
260-
261- /// Signal that the connection has gone away.
262- pub ( crate ) gone : oneshot:: Receiver < ( ) > ,
255+ /// Task set
256+ pub ( crate ) tasks : TaskSet ,
263257
264258 /// ID of the connection.
265259 pub ( crate ) conn_id : ConnectionId ,
@@ -281,25 +275,23 @@ impl<T: Listener> WriteTask<T> {
281275 /// channel, and acts on them. It handles JSON messages, and going away
282276 /// instructions. It also listens for the global shutdown signal from the
283277 /// [`ServerShutdown`] struct.
278+ ///
279+ /// [`ServerShutdown`]: crate::pubsub::ServerShutdown
284280 #[ instrument( skip( self ) , fields( conn_id = self . conn_id) ) ]
285281 pub ( crate ) async fn task_future ( self ) {
286282 let WriteTask {
287- mut shutdown,
288- mut gone,
283+ tasks,
289284 mut json,
290285 mut connection,
291286 ..
292287 } = self ;
293- shutdown . mark_unchanged ( ) ;
288+
294289 loop {
295290 select ! {
296291 biased;
297- _ = & mut gone => {
298- debug!( "Connection has gone away" ) ;
299- break ;
300- }
301- _ = shutdown. changed( ) => {
302- debug!( "shutdown signal received" ) ;
292+
293+ _ = tasks. cancelled( ) => {
294+ debug!( "Shutdown signal received" ) ;
303295 break ;
304296 }
305297 json = json. recv( ) => {
@@ -317,7 +309,9 @@ impl<T: Listener> WriteTask<T> {
317309 }
318310
319311 /// Spawn the future produced by [`Self::task_future`].
320- pub ( crate ) fn spawn ( self ) -> JoinHandle < ( ) > {
321- tokio:: spawn ( self . task_future ( ) )
312+ pub ( crate ) fn spawn ( self ) -> tokio:: task:: JoinHandle < Option < ( ) > > {
313+ let tasks = self . tasks . clone ( ) ;
314+ let future = self . task_future ( ) ;
315+ tasks. spawn_cancellable ( future)
322316 }
323317}
0 commit comments