2929# limitations under the License.
3030
3131import asyncio
32+ from asyncio import CancelledError
3233from collections import OrderedDict
3334from typing import Union
3435
@@ -57,7 +58,8 @@ def connection_made(self, transport: asyncio.WriteTransport) -> None:
5758 try :
5859 self .__send_handshake (transport , self ._conn )
5960 except Exception as e :
60- self ._handshake_fut .set_exception (e )
61+ if not self ._handshake_fut .done ():
62+ self ._handshake_fut .set_exception (e )
6163
6264 def data_received (self , data : bytes ) -> None :
6365 self ._buffer += data
@@ -67,7 +69,7 @@ def data_received(self, data: bytes) -> None:
6769 if not self ._handshake_fut .done ():
6870 hs_response = self .__parse_handshake (packet , self ._conn .client )
6971 self ._handshake_fut .set_result (hs_response )
70- else :
72+ elif not self . _handshake_fut . cancelled () or not self . _handshake_fut . exception () :
7173 self ._conn .process_message (packet )
7274 self ._buffer = self ._buffer [packet_sz :len (self ._buffer )]
7375
@@ -203,7 +205,8 @@ async def _connect(self):
203205 def process_connection_lost (self , err , reconnect = False ):
204206 self .failed = True
205207 for _ , fut in self ._pending_reqs .items ():
206- fut .set_exception (err )
208+ if not fut .done ():
209+ fut .set_exception (err )
207210 self ._pending_reqs .clear ()
208211
209212 if self ._transport_closed_fut and not self ._transport_closed_fut .done ():
@@ -215,8 +218,11 @@ def process_connection_lost(self, err, reconnect=False):
215218
216219 def process_message (self , data ):
217220 req_id = int .from_bytes (data [4 :12 ], byteorder = PROTOCOL_BYTE_ORDER , signed = True )
218- if req_id in self ._pending_reqs :
219- self ._pending_reqs [req_id ].set_result (data )
221+
222+ req_fut = self ._pending_reqs .get (req_id )
223+ if req_fut :
224+ if not req_fut .done ():
225+ req_fut .set_result (data )
220226 del self ._pending_reqs [req_id ]
221227
222228 async def _connect_version (self ) -> Union [dict , OrderedDict ]:
0 commit comments