@@ -332,6 +332,7 @@ def __init__(
332332 encoding : str = "utf-8" ,
333333 encoding_errors : str = "strict" ,
334334 decode_responses : bool = False ,
335+ check_server_ready : bool = False ,
335336 parser_class = DefaultParser ,
336337 socket_read_size : int = 65536 ,
337338 health_check_interval : int = 0 ,
@@ -408,6 +409,7 @@ def __init__(
408409 self .redis_connect_func = redis_connect_func
409410 self .encoder = Encoder (encoding , encoding_errors , decode_responses )
410411 self .handshake_metadata = None
412+ self .check_server_ready = check_server_ready
411413 self ._sock = None
412414 self ._socket_read_size = socket_read_size
413415 self ._connect_callbacks = []
@@ -571,17 +573,17 @@ def connect_check_health(
571573 return
572574 try :
573575 if retry_socket_connect :
574- sock = self .retry .call_with_retry (
575- lambda : self ._connect (), lambda error : self .disconnect (error )
576+ self .retry .call_with_retry (
577+ lambda : self ._connect_check_server_ready (),
578+ lambda error : self .disconnect (error ),
576579 )
577580 else :
578- sock = self ._connect ()
581+ self ._connect_check_server_ready ()
579582 except socket .timeout :
580583 raise TimeoutError ("Timeout connecting to server" )
581584 except OSError as e :
582585 raise ConnectionError (self ._error_message (e ))
583586
584- self ._sock = sock
585587 try :
586588 if self .redis_connect_func is None :
587589 # Use the default on_connect function
@@ -603,8 +605,27 @@ def connect_check_health(
603605 if callback :
604606 callback (self )
605607
608+ def _connect_check_server_ready (self ):
609+ self ._connect ()
610+
611+ # Doing handshake since connect and send operations work even when Redis is not ready
612+ if self .check_server_ready :
613+ try :
614+ self .send_command ("PING" , check_health = False )
615+
616+ response = str_if_bytes (self ._sock .recv (1024 ))
617+ if not (response .startswith ("+PONG" ) or response .startswith ("-NOAUTH" )):
618+ raise ResponseError (f"Invalid PING response: { response } " )
619+ except (ConnectionResetError , ResponseError ) as err :
620+ try :
621+ self ._sock .shutdown (socket .SHUT_RDWR ) # ensure a clean close
622+ except OSError :
623+ pass
624+ self ._sock .close ()
625+ raise ConnectionError (self ._error_message (err ))
626+
606627 @abstractmethod
607- def _connect (self ):
628+ def _connect (self ) -> None :
608629 pass
609630
610631 @abstractmethod
@@ -1083,7 +1104,7 @@ def repr_pieces(self):
10831104 pieces .append (("client_name" , self .client_name ))
10841105 return pieces
10851106
1086- def _connect (self ):
1107+ def _connect (self ) -> None :
10871108 "Create a TCP socket connection"
10881109 # we want to mimic what socket.create_connection does to support
10891110 # ipv4/ipv6, but we want to set options prior to calling
@@ -1114,7 +1135,8 @@ def _connect(self):
11141135
11151136 # set the socket_timeout now that we're connected
11161137 sock .settimeout (self .socket_timeout )
1117- return sock
1138+ self ._sock = sock
1139+ return
11181140
11191141 except OSError as _ :
11201142 err = _
@@ -1427,15 +1449,15 @@ def __init__(
14271449 self .ssl_ciphers = ssl_ciphers
14281450 super ().__init__ (** kwargs )
14291451
1430- def _connect (self ):
1452+ def _connect (self ) -> None :
14311453 """
14321454 Wrap the socket with SSL support, handling potential errors.
14331455 """
1434- sock = super ()._connect ()
1456+ super ()._connect ()
14351457 try :
1436- return self ._wrap_socket_with_ssl (sock )
1458+ self . _sock = self ._wrap_socket_with_ssl (self . _sock )
14371459 except (OSError , RedisError ):
1438- sock .close ()
1460+ self . _sock .close ()
14391461 raise
14401462
14411463 def _wrap_socket_with_ssl (self , sock ):
@@ -1532,7 +1554,7 @@ def repr_pieces(self):
15321554 pieces .append (("client_name" , self .client_name ))
15331555 return pieces
15341556
1535- def _connect (self ):
1557+ def _connect (self ) -> None :
15361558 "Create a Unix domain socket connection"
15371559 sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
15381560 sock .settimeout (self .socket_connect_timeout )
@@ -1547,7 +1569,7 @@ def _connect(self):
15471569 sock .close ()
15481570 raise
15491571 sock .settimeout (self .socket_timeout )
1550- return sock
1572+ self . _sock = sock
15511573
15521574 def _host_error (self ):
15531575 return self .path
0 commit comments