Skip to content

Commit 3ee1832

Browse files
Add Redis readiness verification (#3555)
1 parent 1c8d77f commit 3ee1832

File tree

11 files changed

+407
-64
lines changed

11 files changed

+407
-64
lines changed

redis/asyncio/client.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def __init__(
229229
encoding: str = "utf-8",
230230
encoding_errors: str = "strict",
231231
decode_responses: bool = False,
232+
check_server_ready: bool = False,
232233
retry_on_timeout: bool = False,
233234
retry: Retry = Retry(
234235
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
@@ -276,6 +277,10 @@ def __init__(
276277
277278
When 'connection_pool' is provided - the retry configuration of the
278279
provided pool will be used.
280+
281+
Args:
282+
check_server_ready: if `True`, an extra handshake is performed by sending a PING command, since
283+
connect and send operations work even when Redis server is not ready.
279284
"""
280285
kwargs: Dict[str, Any]
281286
if event_dispatcher is None:
@@ -310,6 +315,7 @@ def __init__(
310315
"encoding": encoding,
311316
"encoding_errors": encoding_errors,
312317
"decode_responses": decode_responses,
318+
"check_server_ready": check_server_ready,
313319
"retry_on_error": retry_on_error,
314320
"retry": copy.deepcopy(retry),
315321
"max_connections": max_connections,

redis/asyncio/cluster.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ def __init__(
289289
encoding_errors: str = "strict",
290290
decode_responses: bool = False,
291291
# Connection related kwargs
292+
check_server_ready: bool = False,
292293
health_check_interval: float = 0,
293294
socket_connect_timeout: Optional[float] = None,
294295
socket_keepalive: bool = False,
@@ -342,6 +343,7 @@ def __init__(
342343
"encoding_errors": encoding_errors,
343344
"decode_responses": decode_responses,
344345
# Connection related kwargs
346+
"check_server_ready": check_server_ready,
345347
"health_check_interval": health_check_interval,
346348
"socket_connect_timeout": socket_connect_timeout,
347349
"socket_keepalive": socket_keepalive,

redis/asyncio/connection.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def __init__(
148148
encoding_errors: str = "strict",
149149
decode_responses: bool = False,
150150
parser_class: Type[BaseParser] = DefaultParser,
151+
check_server_ready: bool = False,
151152
socket_read_size: int = 65536,
152153
health_check_interval: float = 0,
153154
client_name: Optional[str] = None,
@@ -204,6 +205,7 @@ def __init__(
204205
self.health_check_interval = health_check_interval
205206
self.next_health_check: float = -1
206207
self.encoder = encoder_class(encoding, encoding_errors, decode_responses)
208+
self.check_server_ready = check_server_ready
207209
self.redis_connect_func = redis_connect_func
208210
self._reader: Optional[asyncio.StreamReader] = None
209211
self._writer: Optional[asyncio.StreamWriter] = None
@@ -300,9 +302,11 @@ async def connect_check_health(self, check_health: bool = True):
300302
return
301303
try:
302304
await self.retry.call_with_retry(
303-
lambda: self._connect(), lambda error: self.disconnect()
305+
lambda: self._connect_check_server_ready(),
306+
lambda error: self.disconnect(),
304307
)
305308
except asyncio.CancelledError:
309+
self._close()
306310
raise # in 3.7 and earlier, this is an Exception, not BaseException
307311
except (socket.timeout, asyncio.TimeoutError):
308312
raise TimeoutError("Timeout connecting to server")
@@ -337,6 +341,33 @@ async def connect_check_health(self, check_health: bool = True):
337341
if task and inspect.isawaitable(task):
338342
await task
339343

344+
async def _connect_check_server_ready(self):
345+
await self._connect()
346+
347+
# Doing handshake since connect and send operations work even when Redis is not ready
348+
if self.check_server_ready:
349+
try:
350+
await self.send_command("PING", check_health=False)
351+
352+
if self.socket_timeout is not None:
353+
async with async_timeout(self.socket_timeout):
354+
response = str_if_bytes(await self._reader.read(1024))
355+
else:
356+
response = str_if_bytes(await self._reader.read(1024))
357+
358+
if not (response.startswith("+PONG") or response.startswith("-NOAUTH")):
359+
raise ResponseError(f"Invalid PING response: {response}")
360+
except (
361+
socket.timeout,
362+
asyncio.TimeoutError,
363+
ResponseError,
364+
ConnectionResetError,
365+
) as e:
366+
# `socket_keepalive_options` might contain invalid options
367+
# causing an error. Do not leave the connection open.
368+
self._close()
369+
raise ConnectionError(self._error_message(e))
370+
340371
@abstractmethod
341372
async def _connect(self):
342373
pass
@@ -526,8 +557,7 @@ async def send_packed_command(
526557
self._send_packed_command(command), self.socket_timeout
527558
)
528559
else:
529-
self._writer.writelines(command)
530-
await self._writer.drain()
560+
await self._send_packed_command(command)
531561
except asyncio.TimeoutError:
532562
await self.disconnect(nowait=True)
533563
raise TimeoutError("Timeout writing to socket") from None
@@ -770,7 +800,7 @@ async def _connect(self):
770800
except (OSError, TypeError):
771801
# `socket_keepalive_options` might contain invalid options
772802
# causing an error. Do not leave the connection open.
773-
writer.close()
803+
self._close()
774804
raise
775805

776806
def _host_error(self) -> str:
@@ -931,7 +961,6 @@ async def _connect(self):
931961
reader, writer = await asyncio.open_unix_connection(path=self.path)
932962
self._reader = reader
933963
self._writer = writer
934-
await self.on_connect()
935964

936965
def _host_error(self) -> str:
937966
return self.path

redis/client.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def __init__(
211211
encoding: str = "utf-8",
212212
encoding_errors: str = "strict",
213213
decode_responses: bool = False,
214+
check_server_ready: bool = False,
214215
retry_on_timeout: bool = False,
215216
retry: Retry = Retry(
216217
backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3
@@ -267,10 +268,11 @@ def __init__(
267268
provided pool will be used.
268269
269270
Args:
270-
271-
single_connection_client:
272-
if `True`, connection pool is not used. In that case `Redis`
273-
instance use is not thread safe.
271+
check_server_ready: if `True`, an extra handshake is performed by sending a PING command, since
272+
connect and send operations work even when Redis server is not ready.
273+
single_connection_client:
274+
if `True`, connection pool is not used. In that case `Redis`
275+
instance use is not thread safe.
274276
"""
275277
if event_dispatcher is None:
276278
self._event_dispatcher = EventDispatcher()
@@ -287,6 +289,7 @@ def __init__(
287289
"encoding": encoding,
288290
"encoding_errors": encoding_errors,
289291
"decode_responses": decode_responses,
292+
"check_server_ready": check_server_ready,
290293
"retry_on_error": retry_on_error,
291294
"retry": copy.deepcopy(retry),
292295
"max_connections": max_connections,

redis/connection.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def __init__(
236236
encoding: str = "utf-8",
237237
encoding_errors: str = "strict",
238238
decode_responses: bool = False,
239+
check_server_ready: bool = False,
239240
parser_class=DefaultParser,
240241
socket_read_size: int = 65536,
241242
health_check_interval: int = 0,
@@ -302,6 +303,7 @@ def __init__(
302303
self.redis_connect_func = redis_connect_func
303304
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
304305
self.handshake_metadata = None
306+
self.check_server_ready = check_server_ready
305307
self._sock = None
306308
self._socket_read_size = socket_read_size
307309
self.set_parser(parser_class)
@@ -382,15 +384,15 @@ def connect_check_health(self, check_health: bool = True):
382384
if self._sock:
383385
return
384386
try:
385-
sock = self.retry.call_with_retry(
386-
lambda: self._connect(), lambda error: self.disconnect(error)
387+
self.retry.call_with_retry(
388+
lambda: self._connect_check_server_ready(),
389+
lambda error: self.disconnect(error),
387390
)
388391
except socket.timeout:
389392
raise TimeoutError("Timeout connecting to server")
390393
except OSError as e:
391394
raise ConnectionError(self._error_message(e))
392395

393-
self._sock = sock
394396
try:
395397
if self.redis_connect_func is None:
396398
# Use the default on_connect function
@@ -412,8 +414,27 @@ def connect_check_health(self, check_health: bool = True):
412414
if callback:
413415
callback(self)
414416

417+
def _connect_check_server_ready(self):
418+
self._connect()
419+
420+
# Doing handshake since connect and send operations work even when Redis is not ready
421+
if self.check_server_ready:
422+
try:
423+
self.send_command("PING", check_health=False)
424+
425+
response = str_if_bytes(self._sock.recv(1024))
426+
if not (response.startswith("+PONG") or response.startswith("-NOAUTH")):
427+
raise ResponseError(f"Invalid PING response: {response}")
428+
except (ConnectionResetError, ResponseError) as err:
429+
try:
430+
self._sock.shutdown(socket.SHUT_RDWR) # ensure a clean close
431+
except OSError:
432+
pass
433+
self._sock.close()
434+
raise ConnectionError(self._error_message(err))
435+
415436
@abstractmethod
416-
def _connect(self):
437+
def _connect(self) -> None:
417438
pass
418439

419440
@abstractmethod
@@ -752,7 +773,7 @@ def repr_pieces(self):
752773
pieces.append(("client_name", self.client_name))
753774
return pieces
754775

755-
def _connect(self):
776+
def _connect(self) -> None:
756777
"Create a TCP socket connection"
757778
# we want to mimic what socket.create_connection does to support
758779
# ipv4/ipv6, but we want to set options prior to calling
@@ -782,7 +803,8 @@ def _connect(self):
782803

783804
# set the socket_timeout now that we're connected
784805
sock.settimeout(self.socket_timeout)
785-
return sock
806+
self._sock = sock
807+
return
786808

787809
except OSError as _:
788810
err = _
@@ -1095,15 +1117,15 @@ def __init__(
10951117
self.ssl_ciphers = ssl_ciphers
10961118
super().__init__(**kwargs)
10971119

1098-
def _connect(self):
1120+
def _connect(self) -> None:
10991121
"""
11001122
Wrap the socket with SSL support, handling potential errors.
11011123
"""
1102-
sock = super()._connect()
1124+
super()._connect()
11031125
try:
1104-
return self._wrap_socket_with_ssl(sock)
1126+
self._sock = self._wrap_socket_with_ssl(self._sock)
11051127
except (OSError, RedisError):
1106-
sock.close()
1128+
self._sock.close()
11071129
raise
11081130

11091131
def _wrap_socket_with_ssl(self, sock):
@@ -1200,7 +1222,7 @@ def repr_pieces(self):
12001222
pieces.append(("client_name", self.client_name))
12011223
return pieces
12021224

1203-
def _connect(self):
1225+
def _connect(self) -> None:
12041226
"Create a Unix domain socket connection"
12051227
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
12061228
sock.settimeout(self.socket_connect_timeout)
@@ -1215,7 +1237,7 @@ def _connect(self):
12151237
sock.close()
12161238
raise
12171239
sock.settimeout(self.socket_timeout)
1218-
return sock
1240+
self._sock = sock
12191241

12201242
def _host_error(self):
12211243
return self.path

tests/test_asyncio/test_cluster.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,7 @@ async def test_reading_with_load_balancing_strategies(
729729
Connection,
730730
send_command=mock.DEFAULT,
731731
read_response=mock.DEFAULT,
732-
_connect=mock.DEFAULT,
732+
_connect_check_server_ready=mock.DEFAULT,
733733
can_read_destructive=mock.DEFAULT,
734734
on_connect=mock.DEFAULT,
735735
) as mocks:
@@ -761,7 +761,7 @@ def execute_command_mock_third(self, *args, **options):
761761
execute_command.side_effect = execute_command_mock_first
762762
mocks["send_command"].return_value = True
763763
mocks["read_response"].return_value = "OK"
764-
mocks["_connect"].return_value = True
764+
mocks["_connect_check_server_ready"].return_value = True
765765
mocks["can_read_destructive"].return_value = False
766766
mocks["on_connect"].return_value = True
767767

@@ -3117,13 +3117,19 @@ async def execute_command(self, *args, **kwargs):
31173117

31183118
return _create_client
31193119

3120+
@pytest.mark.parametrize("check_server_ready", [True, False])
31203121
async def test_ssl_connection_without_ssl(
3121-
self, create_client: Callable[..., Awaitable[RedisCluster]]
3122+
self, create_client: Callable[..., Awaitable[RedisCluster]], check_server_ready
31223123
) -> None:
31233124
with pytest.raises(RedisClusterException) as e:
3124-
await create_client(mocked=False, ssl=False)
3125+
await create_client(
3126+
mocked=False, ssl=False, check_server_ready=check_server_ready
3127+
)
31253128
e = e.value.__cause__
3126-
assert "Connection closed by server" in str(e)
3129+
if check_server_ready:
3130+
assert "Invalid PING response" in str(e)
3131+
else:
3132+
assert "Connection closed by server" in str(e)
31273133

31283134
async def test_ssl_with_invalid_cert(
31293135
self, create_client: Callable[..., Awaitable[RedisCluster]]

0 commit comments

Comments
 (0)