@@ -237,6 +237,7 @@ def __init__(
237237 encoding : str = "utf-8" ,
238238 encoding_errors : str = "strict" ,
239239 decode_responses : bool = False ,
240+ check_server_ready : bool = False ,
240241 parser_class = DefaultParser ,
241242 socket_read_size : int = 65536 ,
242243 health_check_interval : int = 0 ,
@@ -303,6 +304,7 @@ def __init__(
303304 self .redis_connect_func = redis_connect_func
304305 self .encoder = Encoder (encoding , encoding_errors , decode_responses )
305306 self .handshake_metadata = None
307+ self .check_server_ready = check_server_ready
306308 self ._sock = None
307309 self ._socket_read_size = socket_read_size
308310 self .set_parser (parser_class )
@@ -386,17 +388,17 @@ def connect_check_health(
386388 return
387389 try :
388390 if retry_socket_connect :
389- sock = self .retry .call_with_retry (
390- lambda : self ._connect (), lambda error : self .disconnect (error )
391+ self .retry .call_with_retry (
392+ lambda : self ._connect_check_server_ready (),
393+ lambda error : self .disconnect (error ),
391394 )
392395 else :
393- sock = self ._connect ()
396+ self ._connect_check_server_ready ()
394397 except socket .timeout :
395398 raise TimeoutError ("Timeout connecting to server" )
396399 except OSError as e :
397400 raise ConnectionError (self ._error_message (e ))
398401
399- self ._sock = sock
400402 try :
401403 if self .redis_connect_func is None :
402404 # Use the default on_connect function
@@ -418,8 +420,27 @@ def connect_check_health(
418420 if callback :
419421 callback (self )
420422
423+ def _connect_check_server_ready (self ):
424+ self ._connect ()
425+
426+ # Doing handshake since connect and send operations work even when Redis is not ready
427+ if self .check_server_ready :
428+ try :
429+ self .send_command ("PING" , check_health = False )
430+
431+ response = str_if_bytes (self ._sock .recv (1024 ))
432+ if not (response .startswith ("+PONG" ) or response .startswith ("-NOAUTH" )):
433+ raise ResponseError (f"Invalid PING response: { response } " )
434+ except (ConnectionResetError , ResponseError ) as err :
435+ try :
436+ self ._sock .shutdown (socket .SHUT_RDWR ) # ensure a clean close
437+ except OSError :
438+ pass
439+ self ._sock .close ()
440+ raise ConnectionError (self ._error_message (err ))
441+
421442 @abstractmethod
422- def _connect (self ):
443+ def _connect (self ) -> None :
423444 pass
424445
425446 @abstractmethod
@@ -758,7 +779,7 @@ def repr_pieces(self):
758779 pieces .append (("client_name" , self .client_name ))
759780 return pieces
760781
761- def _connect (self ):
782+ def _connect (self ) -> None :
762783 "Create a TCP socket connection"
763784 # we want to mimic what socket.create_connection does to support
764785 # ipv4/ipv6, but we want to set options prior to calling
@@ -788,7 +809,8 @@ def _connect(self):
788809
789810 # set the socket_timeout now that we're connected
790811 sock .settimeout (self .socket_timeout )
791- return sock
812+ self ._sock = sock
813+ return
792814
793815 except OSError as _ :
794816 err = _
@@ -1101,15 +1123,15 @@ def __init__(
11011123 self .ssl_ciphers = ssl_ciphers
11021124 super ().__init__ (** kwargs )
11031125
1104- def _connect (self ):
1126+ def _connect (self ) -> None :
11051127 """
11061128 Wrap the socket with SSL support, handling potential errors.
11071129 """
1108- sock = super ()._connect ()
1130+ super ()._connect ()
11091131 try :
1110- return self ._wrap_socket_with_ssl (sock )
1132+ self . _sock = self ._wrap_socket_with_ssl (self . _sock )
11111133 except (OSError , RedisError ):
1112- sock .close ()
1134+ self . _sock .close ()
11131135 raise
11141136
11151137 def _wrap_socket_with_ssl (self , sock ):
@@ -1206,7 +1228,7 @@ def repr_pieces(self):
12061228 pieces .append (("client_name" , self .client_name ))
12071229 return pieces
12081230
1209- def _connect (self ):
1231+ def _connect (self ) -> None :
12101232 "Create a Unix domain socket connection"
12111233 sock = socket .socket (socket .AF_UNIX , socket .SOCK_STREAM )
12121234 sock .settimeout (self .socket_connect_timeout )
@@ -1221,7 +1243,7 @@ def _connect(self):
12211243 sock .close ()
12221244 raise
12231245 sock .settimeout (self .socket_timeout )
1224- return sock
1246+ self . _sock = sock
12251247
12261248 def _host_error (self ):
12271249 return self .path
0 commit comments