@@ -575,17 +575,17 @@ def connect_check_health(
575575 return
576576 try :
577577 if retry_socket_connect :
578- sock = self .retry .call_with_retry (
579- lambda : self ._connect (), lambda error : self .disconnect (error )
578+ self .retry .call_with_retry (
579+ lambda : self ._connect (),
580+ lambda error : self .disconnect (error ),
580581 )
581582 else :
582- sock = self ._connect ()
583+ self ._connect ()
583584 except socket .timeout :
584585 raise TimeoutError ("Timeout connecting to server" )
585586 except OSError as e :
586587 raise ConnectionError (self ._error_message (e ))
587588
588- self ._sock = sock
589589 try :
590590 if self .redis_connect_func is None :
591591 # Use the default on_connect function
@@ -608,7 +608,7 @@ def connect_check_health(
608608 callback (self )
609609
610610 @abstractmethod
611- def _connect (self ):
611+ def _connect (self ) -> None :
612612 pass
613613
614614 @abstractmethod
@@ -626,6 +626,12 @@ def on_connect_check_health(self, check_health: bool = True):
626626 self ._parser .on_connect (self )
627627 parser = self ._parser
628628
629+ if check_health :
630+ self .retry .call_with_retry (
631+ lambda : self ._send_ping (),
632+ lambda error : self .disconnect (error ),
633+ )
634+
629635 auth_args = None
630636 # if credential provider or username and/or password are set, authenticate
631637 if self .credential_provider or (self .username or self .password ):
@@ -680,7 +686,7 @@ def on_connect_check_health(self, check_health: bool = True):
680686 # update cluster exception classes
681687 self ._parser .EXCEPTION_CLASSES = parser .EXCEPTION_CLASSES
682688 self ._parser .on_connect (self )
683- self .send_command ("HELLO" , self .protocol , check_health = check_health )
689+ self .send_command ("HELLO" , self .protocol , check_health = False )
684690 self .handshake_metadata = self .read_response ()
685691 if (
686692 self .handshake_metadata .get (b"proto" ) != self .protocol
@@ -711,7 +717,7 @@ def on_connect_check_health(self, check_health: bool = True):
711717 "ON" ,
712718 "moving-endpoint-type" ,
713719 endpoint_type .value ,
714- check_health = check_health ,
720+ check_health = False ,
715721 )
716722 response = self .read_response ()
717723 if str_if_bytes (response ) != "OK" :
@@ -737,7 +743,7 @@ def on_connect_check_health(self, check_health: bool = True):
737743 "CLIENT" ,
738744 "SETNAME" ,
739745 self .client_name ,
740- check_health = check_health ,
746+ check_health = False ,
741747 )
742748 if str_if_bytes (self .read_response ()) != "OK" :
743749 raise ConnectionError ("Error setting client name" )
@@ -750,7 +756,7 @@ def on_connect_check_health(self, check_health: bool = True):
750756 "SETINFO" ,
751757 "LIB-NAME" ,
752758 self .lib_name ,
753- check_health = check_health ,
759+ check_health = False ,
754760 )
755761 self .read_response ()
756762 except ResponseError :
@@ -763,15 +769,15 @@ def on_connect_check_health(self, check_health: bool = True):
763769 "SETINFO" ,
764770 "LIB-VER" ,
765771 self .lib_version ,
766- check_health = check_health ,
772+ check_health = False ,
767773 )
768774 self .read_response ()
769775 except ResponseError :
770776 pass
771777
772778 # if a database is specified, switch to it
773779 if self .db :
774- self .send_command ("SELECT" , self .db , check_health = check_health )
780+ self .send_command ("SELECT" , self .db , check_health = False )
775781 if str_if_bytes (self .read_response ()) != "OK" :
776782 raise ConnectionError ("Invalid Database" )
777783
@@ -800,8 +806,15 @@ def disconnect(self, *args):
800806 def _send_ping (self ):
801807 """Send PING, expect PONG in return"""
802808 self .send_command ("PING" , check_health = False )
803- if str_if_bytes (self .read_response ()) != "PONG" :
804- raise ConnectionError ("Bad response from PING health check" )
809+ try :
810+ # Do not disconnect on error here, since we want to keep the connection in case of AuthenticationError
811+ # since we are raising ConnectionError in all other cases and ping_failed already disconnects,
812+ # connection reload is already handled
813+ if str_if_bytes (self .read_response (disconnect_on_error = False )) != "PONG" :
814+ raise ConnectionError ("Bad response from PING health check" )
815+ except AuthenticationError :
816+ # if we get an authentication error, the server is healthy
817+ pass
805818
806819 def _ping_failed (self , error ):
807820 """Function to call when PING fails"""
@@ -1097,7 +1110,7 @@ def repr_pieces(self):
10971110 pieces .append (("client_name" , self .client_name ))
10981111 return pieces
10991112
1100- def _connect (self ):
1113+ def _connect (self ) -> None :
11011114 "Create a TCP socket connection"
11021115 # we want to mimic what socket.create_connection does to support
11031116 # ipv4/ipv6, but we want to set options prior to calling
@@ -1128,7 +1141,8 @@ def _connect(self):
11281141
11291142 # set the socket_timeout now that we're connected
11301143 sock .settimeout (self .socket_timeout )
1131- return sock
1144+ self ._sock = sock
1145+ return
11321146
11331147 except OSError as _ :
11341148 err = _
@@ -1448,15 +1462,15 @@ def __init__(
14481462 self .ssl_ciphers = ssl_ciphers
14491463 super ().__init__ (** kwargs )
14501464
1451- def _connect (self ):
1465+ def _connect (self ) -> None :
14521466 """
14531467 Wrap the socket with SSL support, handling potential errors.
14541468 """
1455- sock = super ()._connect ()
1469+ super ()._connect ()
14561470 try :
1457- return self ._wrap_socket_with_ssl (sock )
1471+ self . _sock = self ._wrap_socket_with_ssl (self . _sock )
14581472 except (OSError , RedisError ):
1459- sock .close ()
1473+ self . _sock .close ()
14601474 raise
14611475
14621476 def _wrap_socket_with_ssl (self , sock ):
@@ -1559,7 +1573,7 @@ def repr_pieces(self):
15591573 pieces .append (("client_name" , self .client_name ))
15601574 return pieces
15611575
1562- def _connect (self ):
1576+ def _connect (self ) -> None :
15631577 "Create a Unix domain socket connection"
15641578 sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
15651579 sock .settimeout (self .socket_connect_timeout )
@@ -1574,7 +1588,7 @@ def _connect(self):
15741588 sock .close ()
15751589 raise
15761590 sock .settimeout (self .socket_timeout )
1577- return sock
1591+ self . _sock = sock
15781592
15791593 def _host_error (self ):
15801594 return self .path
0 commit comments