Skip to content

Commit 8fcb9d7

Browse files
committed
[_722] fix segfault and hung threads on SIGINT during parallel get
1 parent c8ac7ba commit 8fcb9d7

File tree

3 files changed

+173
-17
lines changed

3 files changed

+173
-17
lines changed

irods/parallel.py

Lines changed: 48 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,22 @@
99
import concurrent.futures
1010
import threading
1111
import multiprocessing
12+
import weakref
1213

1314
from irods.data_object import iRODSDataObject
1415
from irods.exception import DataObjectDoesNotExist
1516
import irods.keywords as kw
1617
from queue import Queue, Full, Empty
1718

1819

20+
transfer_managers = weakref.WeakKeyDictionary()
21+
22+
23+
def abort_asynchronous_transfers():
24+
for mgr in transfer_managers:
25+
mgr.quit()
26+
27+
1928
logger = logging.getLogger(__name__)
2029
_nullh = logging.NullHandler()
2130
logger.addHandler(_nullh)
@@ -90,9 +99,11 @@ def __init__(
9099
for future in self._futures:
91100
future.add_done_callback(self)
92101
else:
93-
self.__invoke_done_callback()
102+
self.__invoke_futures_done_logic()
103+
return
94104

95105
self.progress = [0, 0]
106+
96107
if (progress_Queue) and (total is not None):
97108
self.progress[1] = total
98109

@@ -111,7 +122,7 @@ def _progress(Q, this): # - thread to update progress indicator
111122

112123
self._progress_fn = _progress
113124
self._progress_thread = threading.Thread(
114-
target=self._progress_fn, args=(progress_Queue, self)
125+
target=self._progress_fn, args=(progress_Queue, self), daemon=True
115126
)
116127
self._progress_thread.start()
117128

@@ -152,11 +163,13 @@ def __call__(
152163
with self._lock:
153164
self._futures_done[future] = future.result()
154165
if len(self._futures) == len(self._futures_done):
155-
self.__invoke_done_callback()
166+
self.__invoke_futures_done_logic(
167+
skip_user_callback=(None in self._futures_done.values())
168+
)
156169

157-
def __invoke_done_callback(self):
170+
def __invoke_futures_done_logic(self, skip_user_callback=False):
158171
try:
159-
if callable(self.done_callback):
172+
if not skip_user_callback and callable(self.done_callback):
160173
self.done_callback(self)
161174
finally:
162175
self.keep.pop("mgr", None)
@@ -239,6 +252,9 @@ def _copy_part(src, dst, length, queueObject, debug_info, mgr, updatables=()):
239252
bytecount = 0
240253
accum = 0
241254
while True and bytecount < length:
255+
if mgr._quit:
256+
bytecount = None
257+
break
242258
buf = src.read(min(COPY_BUF_SIZE, length - bytecount))
243259
buf_len = len(buf)
244260
if 0 == buf_len:
@@ -274,11 +290,16 @@ class _Multipart_close_manager:
274290
"""
275291

276292
def __init__(self, initial_io_, exit_barrier_):
293+
self._quit = False
277294
self.exit_barrier = exit_barrier_
278295
self.initial_io = initial_io_
279296
self.__lock = threading.Lock()
280297
self.aux = []
281298

299+
def quit(self):
300+
self._quit = True
301+
self.exit_barrier.abort()
302+
282303
def __contains__(self, Io):
283304
with self.__lock:
284305
return Io is self.initial_io or Io in self.aux
@@ -303,8 +324,12 @@ def remove_io(self, Io):
303324
Io.close()
304325
self.aux.remove(Io)
305326
is_initial = False
306-
self.exit_barrier.wait()
307-
if is_initial:
327+
broken = False
328+
try:
329+
self.exit_barrier.wait()
330+
except threading.BrokenBarrierError:
331+
broken = True
332+
if is_initial and not (broken or self._quit):
308333
self.finalize()
309334

310335
def finalize(self):
@@ -439,13 +464,19 @@ def bytes_range_for_thread(i, num_threads, total_bytes, chunk):
439464
Io = File = None
440465

441466
if Operation.isNonBlocking():
442-
if queueLength:
443-
return futures, queueObject, mgr
444-
else:
445-
return futures
467+
return futures, queueObject, mgr
446468
else:
447-
bytecounts = [f.result() for f in futures]
448-
return sum(bytecounts), total_size
469+
bytes_transferred = 0
470+
try:
471+
bytecounts = [f.result() for f in futures]
472+
if None not in bytecounts:
473+
bytes_transferred = sum(bytecounts)
474+
except KeyboardInterrupt:
475+
if any(not f.done() for f in futures):
476+
# Induce any threads still alive to quit the transfer and exit.
477+
mgr.quit()
478+
raise
479+
return bytes_transferred, total_size
449480

450481

451482
def io_main(session, Data, opr_, fname, R="", **kwopt):
@@ -558,10 +589,10 @@ def io_main(session, Data, opr_, fname, R="", **kwopt):
558589

559590
if Operation.isNonBlocking():
560591

561-
if queueLength > 0:
562-
(futures, chunk_notify_queue, mgr) = retval
563-
else:
564-
futures = retval
592+
(futures, chunk_notify_queue, mgr) = retval
593+
transfer_managers[mgr] = None
594+
595+
if queueLength <= 0:
565596
chunk_notify_queue = total_bytes = None
566597

567598
return AsyncNotify(

irods/test/data_obj_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2955,6 +2955,13 @@ def test_replica_truncate__issue_534(self):
29552955
if data_objs.exists(data_path):
29562956
data_objs.unlink(data_path, force=True)
29572957

2958+
def test_handling_of_termination_signals_during_multithread_get__issue_722(self):
2959+
from irods.test.modules.test_signal_handling_in_multithread_get import (
2960+
test as test__issue_722,
2961+
)
2962+
2963+
test__issue_722(self)
2964+
29582965

29592966
if __name__ == "__main__":
29602967
# let the tests find the parent irods lib
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import os
2+
import re
3+
import signal
4+
import subprocess
5+
import sys
6+
import tempfile
7+
import time
8+
9+
import irods
10+
import irods.helpers
11+
from irods.test import modules as test_modules
12+
13+
OBJECT_SIZE = 2 * 1024**3
14+
OBJECT_NAME = "data_get_issue__722"
15+
LOCAL_TEMPFILE_NAME = "data_object_for_issue_722.dat"
16+
17+
18+
_clock_polling_interval = max(0.01, time.clock_getres(time.CLOCK_BOOTTIME))
19+
20+
21+
def wait_till_true(function, timeout=None):
22+
start_time = time.clock_gettime_ns(time.CLOCK_BOOTTIME)
23+
while not (truth_value := function()):
24+
if (
25+
timeout is not None
26+
and (time.clock_gettime_ns(time.CLOCK_BOOTTIME) - start_time) * 1e-9
27+
> timeout
28+
):
29+
break
30+
time.sleep(_clock_polling_interval)
31+
return truth_value
32+
33+
34+
def test(test_case, signal_names=("SIGTERM", "SIGINT")):
35+
"""Creates a child process executing a long get() and ensures the process can be
36+
terminated using SIGINT or SIGTERM.
37+
"""
38+
program = os.path.join(test_modules.__path__[0], os.path.basename(__file__))
39+
40+
for signal_name in signal_names:
41+
# Call into this same module as a command. This will initiate another Python process that
42+
# performs a lengthy data object "get" operation (see the main body of the script, below.)
43+
process = subprocess.Popen(
44+
[sys.executable, program],
45+
stderr=subprocess.PIPE,
46+
stdout=subprocess.PIPE,
47+
text=True,
48+
)
49+
50+
# Wait for download process to reach the point of spawning data transfer threads. In Python 3.9+ versions
51+
# of the concurrent.futures module, these are nondaemon threads and will block the exit of the main thread
52+
# unless measures are taken (#722).
53+
localfile = process.stdout.readline().strip()
54+
test_case.assertTrue(
55+
wait_till_true(
56+
lambda: os.path.exists(localfile)
57+
and os.stat(localfile).st_size > OBJECT_SIZE // 2
58+
),
59+
"Parallel download from data_objects.get() probably experienced a fatal error before spawning auxiliary data transfer threads.",
60+
)
61+
62+
signal_message_info = f"While testing signal {signal_name}"
63+
sig = getattr(signal, signal_name)
64+
65+
# Interrupt the subprocess with the given signal.
66+
process.send_signal(sig)
67+
# Assert that this signal is what killed the subprocess, rather than a timed out process "wait" or a natural exit
68+
# due to misproper or incomplete handling of the signal.
69+
try:
70+
test_case.assertEqual(
71+
process.wait(timeout=15),
72+
-sig,
73+
"{signal_message_info}: unexpected subprocess return code.",
74+
)
75+
except subprocess.TimeoutExpired as timeout_exc:
76+
test_case.fail(
77+
f"{signal_message_info}: subprocess timed out before terminating. "
78+
"Non-daemon thread(s) probably prevented subprocess's main thread from exiting."
79+
)
80+
# Assert that in the case of SIGINT, the process registered a KeyboardInterrupt.
81+
if sig == signal.SIGINT:
82+
test_case.assertTrue(
83+
re.search("KeyboardInterrupt", process.stderr.read()),
84+
"{signal_message_info}: Expected 'KeyboardInterrupt' in log output.",
85+
)
86+
87+
88+
if __name__ == "__main__":
89+
# These lines are run only if the module is launched as a process.
90+
session = irods.helpers.make_session()
91+
hc = irods.helpers.home_collection(session)
92+
TESTFILE_FILL = b"_" * (1024 * 1024)
93+
object_path = f"{hc}/{OBJECT_NAME}"
94+
95+
# Create the object to be downloaded.
96+
with session.data_objects.open(object_path, "w") as f:
97+
for y in range(OBJECT_SIZE // len(TESTFILE_FILL)):
98+
f.write(TESTFILE_FILL)
99+
local_path = None
100+
# Establish where (ie absolute path) to place the downloaded file, i.e. the get() target.
101+
try:
102+
with tempfile.NamedTemporaryFile(
103+
prefix="local_file_issue_722.dat", delete=True
104+
) as t:
105+
local_path = t.name
106+
107+
# Tell the parent process the name of the local file being "get"ted (got) from iRODS
108+
print(local_path)
109+
sys.stdout.flush()
110+
111+
# "get" the object
112+
session.data_objects.get(object_path, local_path)
113+
finally:
114+
# Clean up, whether or not the download succeeded.
115+
if local_path is not None and os.path.exists(local_path):
116+
os.unlink(local_path)
117+
if session.data_objects.exists(object_path):
118+
session.data_objects.unlink(object_path, force=True)

0 commit comments

Comments
 (0)