3939import java .nio .channels .SocketChannel ;
4040import java .nio .channels .WritePendingException ;
4141import java .util .Iterator ;
42- import java .util .concurrent .CancellationException ;
42+ import java .util .concurrent .ConcurrentHashMap ;
4343import java .util .concurrent .ConcurrentLinkedQueue ;
4444import java .util .concurrent .CountDownLatch ;
4545import java .util .concurrent .ExecutorService ;
@@ -100,19 +100,15 @@ class RegisteredSocket {
100100 /** Bitwise union of pending operation to be registered in the selector */
101101 final AtomicInteger pendingOps = new AtomicInteger ();
102102
103- RegisteredSocket (TlsChannel tlsChannel , SocketChannel socketChannel )
104- throws ClosedChannelException {
103+ RegisteredSocket (TlsChannel tlsChannel , SocketChannel socketChannel ) {
105104 this .tlsChannel = tlsChannel ;
106105 this .socketChannel = socketChannel ;
107106 }
108107
109108 public void close () {
110- doCancelRead (this , null );
111- doCancelWrite (this , null );
112109 if (key != null ) {
113110 key .cancel ();
114111 }
115- currentRegistrations .getAndDecrement ();
116112 /*
117113 * Actual de-registration from the selector will happen asynchronously.
118114 */
@@ -195,8 +191,7 @@ private enum Shutdown {
195191 private LongAdder cancelledReads = new LongAdder ();
196192 private LongAdder cancelledWrites = new LongAdder ();
197193
198- // used for synchronization
199- private AtomicInteger currentRegistrations = new AtomicInteger ();
194+ private final ConcurrentHashMap <RegisteredSocket , Boolean > registrations = new ConcurrentHashMap <>();
200195
201196 private LongAdder currentReads = new LongAdder ();
202197 private LongAdder currentWrites = new LongAdder ();
@@ -232,13 +227,11 @@ public AsynchronousTlsChannelGroup() {
232227 this (Runtime .getRuntime ().availableProcessors ());
233228 }
234229
235- RegisteredSocket registerSocket (TlsChannel reader , SocketChannel socketChannel )
236- throws ClosedChannelException {
230+ RegisteredSocket registerSocket (TlsChannel reader , SocketChannel socketChannel ) {
237231 if (shutdown != Shutdown .No ) {
238232 throw new ShutdownChannelGroupException ();
239233 }
240234 RegisteredSocket socket = new RegisteredSocket (reader , socketChannel );
241- currentRegistrations .getAndIncrement ();
242235 pendingRegistrations .add (socket );
243236 selector .wakeup ();
244237 return socket ;
@@ -247,18 +240,13 @@ RegisteredSocket registerSocket(TlsChannel reader, SocketChannel socketChannel)
247240 boolean doCancelRead (RegisteredSocket socket , ReadOperation op ) {
248241 socket .readLock .lock ();
249242 try {
250- // a null op means cancel any operation
251- if (op != null && socket .readOperation == op || op == null && socket .readOperation != null ) {
252- if (op == null ) {
253- socket .readOperation .onFailure .accept (new CancellationException ());
254- }
255- socket .readOperation = null ;
256- cancelledReads .increment ();
257- currentReads .decrement ();
258- return true ;
259- } else {
243+ if (op != socket .readOperation ) {
260244 return false ;
261245 }
246+ socket .readOperation = null ;
247+ cancelledReads .increment ();
248+ currentReads .decrement ();
249+ return true ;
262250 } finally {
263251 socket .readLock .unlock ();
264252 }
@@ -267,18 +255,13 @@ boolean doCancelRead(RegisteredSocket socket, ReadOperation op) {
267255 boolean doCancelWrite (RegisteredSocket socket , WriteOperation op ) {
268256 socket .writeLock .lock ();
269257 try {
270- // a null op means cancel any operation
271- if (op != null && socket .writeOperation == op || op == null && socket .writeOperation != null ) {
272- if (op == null ) {
273- socket .writeOperation .onFailure .accept (new CancellationException ());
274- }
275- socket .writeOperation = null ;
276- cancelledWrites .increment ();
277- currentWrites .decrement ();
278- return true ;
279- } else {
258+ if (op != socket .writeOperation ) {
280259 return false ;
281260 }
261+ socket .writeOperation = null ;
262+ cancelledWrites .increment ();
263+ currentWrites .decrement ();
264+ return true ;
282265 } finally {
283266 socket .writeLock .unlock ();
284267 }
@@ -295,13 +278,23 @@ ReadOperation startRead(
295278 checkTerminated ();
296279 Util .assertTrue (buffer .hasRemaining ());
297280 waitForSocketRegistration (socket );
298- ReadOperation op ;
299281 socket .readLock .lock ();
300282 try {
301283 if (socket .readOperation != null ) {
302284 throw new ReadPendingException ();
303285 }
304- op = new ReadOperation (buffer , onSuccess , onFailure );
286+ ReadOperation op = new ReadOperation (buffer , onSuccess , onFailure );
287+
288+ startedReads .increment ();
289+ currentReads .increment ();
290+
291+ if (!registrations .containsKey (socket )) {
292+ op .onFailure .accept (new ClosedChannelException ());
293+ failedReads .increment ();
294+ currentReads .decrement ();
295+ return op ;
296+ }
297+
305298 /*
306299 * we do not try to outsmart the TLS state machine and register for both IO operations for each new socket
307300 * operation
@@ -324,9 +317,7 @@ ReadOperation startRead(
324317 socket .readLock .unlock ();
325318 }
326319 selector .wakeup ();
327- startedReads .increment ();
328- currentReads .increment ();
329- return op ;
320+ return socket .readOperation ;
330321 }
331322
332323 WriteOperation startWrite (
@@ -340,13 +331,23 @@ WriteOperation startWrite(
340331 checkTerminated ();
341332 Util .assertTrue (buffer .hasRemaining ());
342333 waitForSocketRegistration (socket );
343- WriteOperation op ;
344334 socket .writeLock .lock ();
345335 try {
346336 if (socket .writeOperation != null ) {
347337 throw new WritePendingException ();
348338 }
349- op = new WriteOperation (buffer , onSuccess , onFailure );
339+ WriteOperation op = new WriteOperation (buffer , onSuccess , onFailure );
340+
341+ startedWrites .increment ();
342+ currentWrites .increment ();
343+
344+ if (!registrations .containsKey (socket )) {
345+ op .onFailure .accept (new ClosedChannelException ());
346+ failedWrites .increment ();
347+ currentWrites .decrement ();
348+ return op ;
349+ }
350+
350351 /*
351352 * we do not try to outsmart the TLS state machine and register for both IO operations for each new socket
352353 * operation
@@ -369,9 +370,7 @@ WriteOperation startWrite(
369370 socket .writeLock .unlock ();
370371 }
371372 selector .wakeup ();
372- startedWrites .increment ();
373- currentWrites .increment ();
374- return op ;
373+ return socket .writeOperation ;
375374 }
376375
377376 private void checkTerminated () {
@@ -391,8 +390,11 @@ private void waitForSocketRegistration(RegisteredSocket socket) {
391390 private void loop () {
392391 try {
393392 while (shutdown == Shutdown .No
394- || shutdown == Shutdown .Wait && currentRegistrations .intValue () > 0 ) {
395- int c = selector .select (); // block
393+ || shutdown == Shutdown .Wait
394+ && (!pendingRegistrations .isEmpty () || !registrations .isEmpty ())) {
395+ // most state-changing operations will wake the selector up, however, asynchronous closings
396+ // of the channels won't, so we have to timeout to allow checking those cases
397+ int c = selector .select (100 ); // block
396398 selectionCount .increment ();
397399 // avoid unnecessary creation of iterator object
398400 if (c > 0 ) {
@@ -413,24 +415,20 @@ private void loop() {
413415 }
414416 registerPendingSockets ();
415417 processPendingInterests ();
418+ checkClosings ();
416419 }
417420 } catch (Throwable e ) {
418421 LOGGER .error ("error in selector loop" , e );
419422 } finally {
420423 executor .shutdown ();
421424 // use shutdownNow to stop delayed tasks
422425 timeoutExecutor .shutdownNow ();
423- if (shutdown == Shutdown .Immediate ) {
424- for (SelectionKey key : selector .keys ()) {
425- RegisteredSocket socket = (RegisteredSocket ) key .attachment ();
426- socket .close ();
427- }
428- }
429426 try {
430427 selector .close ();
431428 } catch (IOException e ) {
432429 LOGGER .warn ("error closing selector: " + e .getMessage ());
433430 }
431+ checkClosings ();
434432 }
435433 }
436434
@@ -606,14 +604,67 @@ private long readHandlingTasks(RegisteredSocket socket, ReadOperation op) throws
606604 }
607605 }
608606
609- private void registerPendingSockets () throws ClosedChannelException {
607+ private void registerPendingSockets () {
610608 RegisteredSocket socket ;
611609 while ((socket = pendingRegistrations .poll ()) != null ) {
612- socket .key = socket .socketChannel .register (selector , 0 , socket );
613- if (LOGGER .isTraceEnabled ()) {
614- LOGGER .trace ("registered key: " + socket .key );
610+ try {
611+ socket .key = socket .socketChannel .register (selector , 0 , socket );
612+ registrations .put (socket , true );
613+ } catch (ClosedChannelException e ) {
614+ // can happen when channels are closed right after creation
615+ } finally {
616+ // decrement the count of the latch even in case of exceptions, so the waiting thread
617+ // is unlocked; it will have to check the result, though
618+ socket .registered .countDown ();
619+ }
620+ }
621+ }
622+
623+ /**
624+ * Channels that are closed asynchronously are silently removed from selectors. This method will
625+ * check them using the internal catalog and do the proper cleanup.
626+ */
627+ private void checkClosings () {
628+ for (RegisteredSocket socket : registrations .keySet ()) {
629+ if (!socket .key .isValid () || shutdown == Shutdown .Immediate ) {
630+ registrations .remove (socket );
631+ failCurrentRead (socket );
632+ failCurrentWrite (socket );
615633 }
616- socket .registered .countDown ();
634+ }
635+ }
636+
637+ private void failCurrentRead (RegisteredSocket socket ) {
638+ socket .readLock .lock ();
639+ try {
640+ if (socket .readOperation != null ) {
641+ socket .readOperation .onFailure .accept (new ClosedChannelException ());
642+ if (socket .readOperation .timeoutFuture != null ) {
643+ socket .readOperation .timeoutFuture .cancel (false );
644+ }
645+ socket .readOperation = null ;
646+ failedReads .increment ();
647+ currentReads .decrement ();
648+ }
649+ } finally {
650+ socket .readLock .unlock ();
651+ }
652+ }
653+
654+ private void failCurrentWrite (RegisteredSocket socket ) {
655+ socket .writeLock .lock ();
656+ try {
657+ if (socket .writeOperation != null ) {
658+ socket .writeOperation .onFailure .accept (new ClosedChannelException ());
659+ if (socket .writeOperation .timeoutFuture != null ) {
660+ socket .writeOperation .timeoutFuture .cancel (false );
661+ }
662+ socket .writeOperation = null ;
663+ failedWrites .increment ();
664+ currentWrites .decrement ();
665+ }
666+ } finally {
667+ socket .writeLock .unlock ();
617668 }
618669 }
619670
@@ -769,6 +820,6 @@ public long getCurrentWriteCount() {
769820 * @return number of sockets
770821 */
771822 public long getCurrentRegistrationCount () {
772- return currentRegistrations . longValue ();
823+ return registrations . mappingCount ();
773824 }
774825}
0 commit comments