99import concurrent .futures
1010import threading
1111import multiprocessing
12+ import weakref
1213
1314from irods .data_object import iRODSDataObject
1415from irods .exception import DataObjectDoesNotExist
1516import irods .keywords as kw
1617from queue import Queue , Full , Empty
1718
1819
20+ transfer_managers = weakref .WeakKeyDictionary ()
21+
22+
23+ def abort_asynchronous_transfers ():
24+ for mgr in transfer_managers :
25+ mgr .quit ()
26+
27+
1928logger = logging .getLogger (__name__ )
2029_nullh = logging .NullHandler ()
2130logger .addHandler (_nullh )
@@ -90,9 +99,11 @@ def __init__(
9099 for future in self ._futures :
91100 future .add_done_callback (self )
92101 else :
93- self .__invoke_done_callback ()
102+ self .__invoke_futures_done_logic ()
103+ return
94104
95105 self .progress = [0 , 0 ]
106+
96107 if (progress_Queue ) and (total is not None ):
97108 self .progress [1 ] = total
98109
@@ -111,7 +122,7 @@ def _progress(Q, this): # - thread to update progress indicator
111122
112123 self ._progress_fn = _progress
113124 self ._progress_thread = threading .Thread (
114- target = self ._progress_fn , args = (progress_Queue , self )
125+ target = self ._progress_fn , args = (progress_Queue , self ), daemon = True
115126 )
116127 self ._progress_thread .start ()
117128
@@ -152,11 +163,13 @@ def __call__(
152163 with self ._lock :
153164 self ._futures_done [future ] = future .result ()
154165 if len (self ._futures ) == len (self ._futures_done ):
155- self .__invoke_done_callback ()
166+ self .__invoke_futures_done_logic (
167+ skip_user_callback = (None in self ._futures_done .values ())
168+ )
156169
157- def __invoke_done_callback (self ):
170+ def __invoke_futures_done_logic (self , skip_user_callback = False ):
158171 try :
159- if callable (self .done_callback ):
172+ if not skip_user_callback and callable (self .done_callback ):
160173 self .done_callback (self )
161174 finally :
162175 self .keep .pop ("mgr" , None )
@@ -239,6 +252,9 @@ def _copy_part(src, dst, length, queueObject, debug_info, mgr, updatables=()):
239252 bytecount = 0
240253 accum = 0
241254 while True and bytecount < length :
255+ if mgr ._quit :
256+ bytecount = None
257+ break
242258 buf = src .read (min (COPY_BUF_SIZE , length - bytecount ))
243259 buf_len = len (buf )
244260 if 0 == buf_len :
@@ -274,11 +290,16 @@ class _Multipart_close_manager:
274290 """
275291
276292 def __init__ (self , initial_io_ , exit_barrier_ ):
293+ self ._quit = False
277294 self .exit_barrier = exit_barrier_
278295 self .initial_io = initial_io_
279296 self .__lock = threading .Lock ()
280297 self .aux = []
281298
299+ def quit (self ):
300+ self ._quit = True
301+ self .exit_barrier .abort ()
302+
282303 def __contains__ (self , Io ):
283304 with self .__lock :
284305 return Io is self .initial_io or Io in self .aux
@@ -303,8 +324,12 @@ def remove_io(self, Io):
303324 Io .close ()
304325 self .aux .remove (Io )
305326 is_initial = False
306- self .exit_barrier .wait ()
307- if is_initial :
327+ broken = False
328+ try :
329+ self .exit_barrier .wait ()
330+ except threading .BrokenBarrierError :
331+ broken = True
332+ if is_initial and not (broken or self ._quit ):
308333 self .finalize ()
309334
310335 def finalize (self ):
@@ -439,13 +464,19 @@ def bytes_range_for_thread(i, num_threads, total_bytes, chunk):
439464 Io = File = None
440465
441466 if Operation .isNonBlocking ():
442- if queueLength :
443- return futures , queueObject , mgr
444- else :
445- return futures
467+ return futures , queueObject , mgr
446468 else :
447- bytecounts = [f .result () for f in futures ]
448- return sum (bytecounts ), total_size
469+ bytes_transferred = 0
470+ try :
471+ bytecounts = [f .result () for f in futures ]
472+ if None not in bytecounts :
473+ bytes_transferred = sum (bytecounts )
474+ except KeyboardInterrupt :
475+ if any (not f .done () for f in futures ):
476+ # Induce any threads still alive to quit the transfer and exit.
477+ mgr .quit ()
478+ raise
479+ return bytes_transferred , total_size
449480
450481
451482def io_main (session , Data , opr_ , fname , R = "" , ** kwopt ):
@@ -558,10 +589,10 @@ def io_main(session, Data, opr_, fname, R="", **kwopt):
558589
559590 if Operation .isNonBlocking ():
560591
561- if queueLength > 0 :
562- ( futures , chunk_notify_queue , mgr ) = retval
563- else :
564- futures = retval
592+ ( futures , chunk_notify_queue , mgr ) = retval
593+ transfer_managers [ mgr ] = None
594+
595+ if queueLength <= 0 :
565596 chunk_notify_queue = total_bytes = None
566597
567598 return AsyncNotify (
0 commit comments