@@ -17,6 +17,7 @@ import NIO
1717import NIOConcurrencyHelpers
1818import NIOFoundationCompat
1919import NIOHTTP1
20+ import NIOHTTPCompression
2021import NIOSSL
2122
2223extension HTTPClient {
@@ -486,22 +487,31 @@ extension URL {
486487extension HTTPClient {
487488 /// Response execution context. Will be created by the library and could be used for obtaining
488489 /// `EventLoopFuture<Response>` of the execution or cancellation of the execution.
489- public final class Task < Response> {
490+ public final class Task < Response> : TaskProtocol {
490491 /// The `EventLoop` the delegate will be executed on.
491492 public let eventLoop : EventLoop
492493
493494 let promise : EventLoopPromise < Response >
494- var channel : Channel ?
495- private var cancelled : Bool
496- private let lock : Lock
495+ var completion : EventLoopFuture < Void >
496+ var connection : ConnectionPool . Connection ?
497+ var cancelled : Bool
498+ let lock : Lock
499+ let id = UUID ( )
497500
498501 init ( eventLoop: EventLoop ) {
499502 self . eventLoop = eventLoop
500503 self . promise = eventLoop. makePromise ( )
504+ self . completion = self . promise. futureResult. map { _ in }
501505 self . cancelled = false
502506 self . lock = Lock ( )
503507 }
504508
509+ static func failedTask( eventLoop: EventLoop , error: Error ) -> Task < Response > {
510+ let task = self . init ( eventLoop: eventLoop)
511+ task. promise. fail ( error)
512+ return task
513+ }
514+
505515 /// `EventLoopFuture` for the response returned by this request.
506516 public var futureResult : EventLoopFuture < Response > {
507517 return self . promise. futureResult
@@ -520,28 +530,74 @@ extension HTTPClient {
520530 let channel : Channel ? = self . lock. withLock {
521531 if !cancelled {
522532 cancelled = true
523- return self . channel
533+ return self . connection? . channel
534+ } else {
535+ return nil
524536 }
525- return nil
526537 }
527538 channel? . triggerUserOutboundEvent ( TaskCancelEvent ( ) , promise: nil )
528539 }
529540
530541 @discardableResult
531- func setChannel ( _ channel : Channel ) -> Channel {
542+ func setConnection ( _ connection : ConnectionPool . Connection ) -> ConnectionPool . Connection {
532543 return self . lock. withLock {
533- self . channel = channel
534- return channel
544+ self . connection = connection
545+ if self . cancelled {
546+ connection. channel. triggerUserOutboundEvent ( TaskCancelEvent ( ) , promise: nil )
547+ }
548+ return connection
549+ }
550+ }
551+
552+ func succeed< Delegate: HTTPClientResponseDelegate > ( promise: EventLoopPromise < Response > ? , with value: Response , delegateType: Delegate . Type ) {
553+ self . releaseAssociatedConnection ( delegateType: delegateType) . whenSuccess {
554+ promise? . succeed ( value)
555+ }
556+ }
557+
558+ func fail< Delegate: HTTPClientResponseDelegate > ( with error: Error , delegateType: Delegate . Type ) {
559+ if let connection = self . connection {
560+ connection. close ( ) . whenComplete { _ in
561+ self . releaseAssociatedConnection ( delegateType: delegateType) . whenComplete { _ in
562+ self . promise. fail ( error)
563+ }
564+ }
565+ }
566+ }
567+
568+ func releaseAssociatedConnection< Delegate: HTTPClientResponseDelegate > ( delegateType: Delegate . Type ) -> EventLoopFuture < Void > {
569+ if let connection = self . connection {
570+ return connection. removeHandler ( NIOHTTPResponseDecompressor . self) . flatMap {
571+ connection. removeHandler ( IdleStateHandler . self)
572+ } . flatMap {
573+ connection. removeHandler ( TaskHandler< Delegate> . self )
574+ } . map {
575+ connection. release ( )
576+ } . flatMapError { error in
577+ fatalError ( " Couldn't remove taskHandler: \( error) " )
578+ }
579+
580+ } else {
581+ // TODO: This seems only reached in some internal unit test
582+ // Maybe there could be a better handling in the future to make
583+ // it an error outside of testing contexts
584+ return self . eventLoop. makeSucceededFuture ( ( ) )
535585 }
536586 }
537587 }
538588}
539589
540590internal struct TaskCancelEvent { }
541591
592+ internal protocol TaskProtocol {
593+ func cancel( )
594+ var id : UUID { get }
595+ var completion : EventLoopFuture < Void > { get }
596+ }
597+
542598// MARK: - TaskHandler
543599
544- internal class TaskHandler < Delegate: HTTPClientResponseDelegate > {
600+ internal class TaskHandler < Delegate: HTTPClientResponseDelegate > : RemovableChannelHandler {
545601 enum State {
546602 case idle
547603 case sent
@@ -581,7 +637,7 @@ extension TaskHandler {
581637 _ body: @escaping ( HTTPClient . Task < Delegate . Response > , Err ) -> Void ) {
582638 func doIt( ) {
583639 body ( self . task, error)
584- self . task. promise . fail ( error)
640+ self . task. fail ( with : error, delegateType : Delegate . self )
585641 }
586642
587643 if self . task. eventLoop. inEventLoop {
@@ -621,13 +677,14 @@ extension TaskHandler {
621677 }
622678
623679 func callOutToDelegate< Response> ( promise: EventLoopPromise < Response > ? = nil ,
624- _ body: @escaping ( HTTPClient . Task < Delegate . Response > ) throws -> Response ) {
680+ _ body: @escaping ( HTTPClient . Task < Delegate . Response > ) throws -> Response ) where Response == Delegate . Response {
625681 func doIt( ) {
626682 do {
627683 let result = try body ( self . task)
628- promise? . succeed ( result)
684+
685+ self . task. succeed ( promise: promise, with: result, delegateType: Delegate . self)
629686 } catch {
630- promise ? . fail ( error)
687+ self . task . fail ( with : error, delegateType : Delegate . self )
631688 }
632689 }
633690
@@ -641,7 +698,7 @@ extension TaskHandler {
641698 }
642699
643700 func callOutToDelegate< Response> ( channelEventLoop: EventLoop ,
644- _ body: @escaping ( HTTPClient . Task < Delegate . Response > ) throws -> Response ) -> EventLoopFuture < Response > {
701+ _ body: @escaping ( HTTPClient . Task < Delegate . Response > ) throws -> Response ) -> EventLoopFuture < Response > where Response == Delegate . Response {
645702 let promise = channelEventLoop. makePromise ( of: Response . self)
646703 self . callOutToDelegate ( promise: promise, body)
647704 return promise. futureResult
@@ -678,8 +735,6 @@ extension TaskHandler: ChannelDuplexHandler {
678735 headers. add ( name: " Host " , value: request. host)
679736 }
680737
681- headers. add ( name: " Connection " , value: " close " )
682-
683738 do {
684739 try headers. validate ( body: request. body)
685740 } catch {
@@ -702,16 +757,10 @@ extension TaskHandler: ChannelDuplexHandler {
702757 context. eventLoop. assertInEventLoop ( )
703758 self . state = . sent
704759 self . callOutToDelegateFireAndForget ( self . delegate. didSendRequest)
705-
706- let channel = context. channel
707- self . task. futureResult. whenComplete { _ in
708- channel. close ( promise: nil )
709- }
710760 } . flatMapErrorThrowing { error in
711761 context. eventLoop. assertInEventLoop ( )
712762 self . state = . end
713763 self . failTaskAndNotifyDelegate ( error: error, self . delegate. didReceiveError)
714- context. close ( promise: nil )
715764 throw error
716765 } . cascade ( to: promise)
717766 }
@@ -742,6 +791,16 @@ extension TaskHandler: ChannelDuplexHandler {
742791 let response = self . unwrapInboundIn ( data)
743792 switch response {
744793 case . head( let head) :
794+ if !head. isKeepAlive {
795+ self . task. lock. withLock {
796+ if let connection = self . task. connection {
797+ connection. isClosing = true
798+ } else {
799+ preconditionFailure ( " There should always be a connection at this point " )
800+ }
801+ }
802+ }
803+
745804 if let redirectURL = redirectHandler? . redirectTarget ( status: head. status, headers: head. headers) {
746805 self . state = . redirected( head, redirectURL)
747806 } else {
@@ -768,8 +827,9 @@ extension TaskHandler: ChannelDuplexHandler {
768827 switch self . state {
769828 case . redirected( let head, let redirectURL) :
770829 self . state = . end
771- self . redirectHandler? . redirect ( status: head. status, to: redirectURL, promise: self . task. promise)
772- context. close ( promise: nil )
830+ self . task. releaseAssociatedConnection ( delegateType: Delegate . self) . whenSuccess {
831+ self . redirectHandler? . redirect ( status: head. status, to: redirectURL, promise: self . task. promise)
832+ }
773833 default :
774834 self . state = . end
775835 self . callOutToDelegate ( promise: self . task. promise, self . delegate. didFinishRequest)
@@ -845,6 +905,13 @@ extension TaskHandler: ChannelDuplexHandler {
845905 self . failTaskAndNotifyDelegate ( error: error, self . delegate. didReceiveError)
846906 }
847907 }
908+
909+ func handlerAdded( context: ChannelHandlerContext ) {
910+ guard context. channel. isActive else {
911+ self . failTaskAndNotifyDelegate ( error: HTTPClientError . remoteConnectionClosed, self . delegate. didReceiveError)
912+ return
913+ }
914+ }
848915}
849916
850917// MARK: - RedirectHandler
@@ -931,9 +998,13 @@ internal struct RedirectHandler<ResponseType> {
931998 do {
932999 var newRequest = try HTTPClient . Request ( url: redirectURL, method: method, headers: headers, body: body)
9331000 newRequest. redirectState = nextState
934- return self . execute ( newRequest) . futureResult. cascade ( to: promise)
1001+ self . execute ( newRequest) . futureResult. whenComplete { result in
1002+ promise. futureResult. eventLoop. execute {
1003+ promise. completeWith ( result)
1004+ }
1005+ }
9351006 } catch {
936- return promise. fail ( error)
1007+ promise. fail ( error)
9371008 }
9381009 }
9391010}
0 commit comments