From 092e33bdd1caa7f6c1d70640cb84f54f89a0cbfa Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 27 Jun 2025 14:03:39 +0300 Subject: [PATCH 01/16] Handling of topology update push notifications for Standalone Redis client. --- redis/_parsers/base.py | 88 ++++++- redis/_parsers/hiredis.py | 27 +- redis/_parsers/resp3.py | 16 +- redis/client.py | 86 +++++- redis/connection.py | 510 +++++++++++++++++++++++++++++++----- redis/maintenance_events.py | 349 ++++++++++++++++++++++++ 6 files changed, 980 insertions(+), 96 deletions(-) create mode 100644 redis/maintenance_events.py diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 69d7b585dd..a0f6af4ac2 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -3,6 +3,12 @@ from asyncio import IncompleteReadError, StreamReader, TimeoutError from typing import Callable, List, Optional, Protocol, Union +from redis.maintenance_events import ( + NodeMigratedEvent, + NodeMigratingEvent, + NodeMovingEvent, +) + if sys.version_info.major >= 3 and sys.version_info.minor >= 11: from asyncio import timeout as async_timeout else: @@ -158,7 +164,19 @@ async def read_response( raise NotImplementedError() -_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"] +_INVALIDATION_MESSAGE = (b"invalidate", "invalidate") +_MOVING_MESSAGE = (b"MOVING", "MOVING") +_MIGRATING_MESSAGE = (b"MIGRATING", "MIGRATING") +_MIGRATED_MESSAGE = (b"MIGRATED", "MIGRATED") +_FAILING_OVER_MESSAGE = (b"FAILING_OVER", "FAILING_OVER") +_FAILED_OVER_MESSAGE = (b"FAILED_OVER", "FAILED_OVER") + +_MAINTENANCE_MESSAGES = ( + *_MIGRATING_MESSAGE, + *_MIGRATED_MESSAGE, + *_FAILING_OVER_MESSAGE, + *_FAILED_OVER_MESSAGE, +) class PushNotificationsParser(Protocol): @@ -166,16 +184,41 @@ class PushNotificationsParser(Protocol): pubsub_push_handler_func: Callable invalidation_push_handler_func: Optional[Callable] = None + node_moving_push_handler_func: Optional[Callable] = None + maintenance_push_handler_func: Optional[Callable] = None def 
handle_pubsub_push_response(self, response): """Handle pubsub push responses""" raise NotImplementedError() def handle_push_response(self, response, **kwargs): - if response[0] not in _INVALIDATION_MESSAGE: + msg_type = response[0] + if msg_type not in ( + *_INVALIDATION_MESSAGE, + *_MAINTENANCE_MESSAGES, + *_MOVING_MESSAGE, + ): return self.pubsub_push_handler_func(response) - if self.invalidation_push_handler_func: + if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: return self.invalidation_push_handler_func(response) + if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: + if msg_type in _MOVING_MESSAGE: + host, port = response[2].split(":") + ttl = response[1] + notification = NodeMovingEvent(host, port, ttl) + return self.node_moving_push_handler_func(notification) + if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: + if msg_type in _MIGRATING_MESSAGE: + ttl = response[1] + notification = NodeMigratingEvent(ttl) + elif msg_type in _MIGRATED_MESSAGE: + notification = NodeMigratedEvent() + else: + notification = None + if notification is not None: + return self.maintenance_push_handler_func(notification) + else: + return None def set_pubsub_push_handler(self, pubsub_push_handler_func): self.pubsub_push_handler_func = pubsub_push_handler_func @@ -183,12 +226,20 @@ def set_pubsub_push_handler(self, pubsub_push_handler_func): def set_invalidation_push_handler(self, invalidation_push_handler_func): self.invalidation_push_handler_func = invalidation_push_handler_func + def set_node_moving_push_handler(self, node_moving_push_handler_func): + self.node_moving_push_handler_func = node_moving_push_handler_func + + def set_maintenance_push_handler(self, maintenance_push_handler_func): + self.maintenance_push_handler_func = maintenance_push_handler_func + class AsyncPushNotificationsParser(Protocol): """Protocol defining async RESP3-specific parsing functionality""" pubsub_push_handler_func: Callable 
invalidation_push_handler_func: Optional[Callable] = None + node_moving_push_handler_func: Optional[Callable] = None + maintenance_push_handler_func: Optional[Callable] = None async def handle_pubsub_push_response(self, response): """Handle pubsub push responses asynchronously""" @@ -196,10 +247,31 @@ async def handle_pubsub_push_response(self, response): async def handle_push_response(self, response, **kwargs): """Handle push responses asynchronously""" - if response[0] not in _INVALIDATION_MESSAGE: + msg_type = response[0] + if msg_type not in ( + *_INVALIDATION_MESSAGE, + *_MAINTENANCE_MESSAGES, + *_MOVING_MESSAGE, + ): return await self.pubsub_push_handler_func(response) - if self.invalidation_push_handler_func: + if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: return await self.invalidation_push_handler_func(response) + if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: + # push notification from enterprise cluster for node moving + host, port = response[2].split(":") + ttl = response[1] + id = 1 # TODO: get unique id from push notification + notification = NodeMovingEvent(id, host, port, ttl) + return await self.node_moving_push_handler_func(notification) + if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: + if msg_type in _MIGRATING_MESSAGE: + ttl = response[1] + id = 1 # TODO: get unique id from push notification + notification = NodeMigratingEvent(id, ttl) + elif msg_type in _MIGRATED_MESSAGE: + id = 1 # TODO: get unique id from push notification + notification = NodeMigratedEvent(id) + return await self.maintenance_push_handler_func(notification) def set_pubsub_push_handler(self, pubsub_push_handler_func): """Set the pubsub push handler function""" @@ -209,6 +281,12 @@ def set_invalidation_push_handler(self, invalidation_push_handler_func): """Set the invalidation push handler function""" self.invalidation_push_handler_func = invalidation_push_handler_func + def 
set_node_moving_push_handler_func(self, node_moving_push_handler_func): + self.node_moving_push_handler_func = node_moving_push_handler_func + + def set_maintenance_push_handler(self, maintenance_push_handler_func): + self.maintenance_push_handler_func = maintenance_push_handler_func + class _AsyncRESPBase(AsyncBaseParser): """Base class for async resp parsing""" diff --git a/redis/_parsers/hiredis.py b/redis/_parsers/hiredis.py index 521a58b26c..e9df314a8c 100644 --- a/redis/_parsers/hiredis.py +++ b/redis/_parsers/hiredis.py @@ -47,6 +47,8 @@ def __init__(self, socket_read_size): self.socket_read_size = socket_read_size self._buffer = bytearray(socket_read_size) self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.node_moving_push_handler_func = None + self.maintenance_push_handler_func = None self.invalidation_push_handler_func = None self._hiredis_PushNotificationType = None @@ -141,13 +143,15 @@ def read_response(self, disable_decoding=False, push_request=False): response, self._hiredis_PushNotificationType ): response = self.handle_push_response(response) - if not push_request: - return self.read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + + # if this is a push request return the push response + if push_request: return response - return response + + return self.read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) if disable_decoding: response = self._reader.gets(False) @@ -169,12 +173,13 @@ def read_response(self, disable_decoding=False, push_request=False): response, self._hiredis_PushNotificationType ): response = self.handle_push_response(response) - if not push_request: - return self.read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + if push_request: return response + return self.read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) + elif ( isinstance(response, list) and response diff --git 
a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 42c6652e31..72957b464c 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -18,6 +18,8 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser): def __init__(self, socket_read_size): super().__init__(socket_read_size) self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.node_moving_push_handler_func = None + self.maintenance_push_handler_func = None self.invalidation_push_handler_func = None def handle_pubsub_push_response(self, response): @@ -117,17 +119,21 @@ def _read_response(self, disable_decoding=False, push_request=False): for _ in range(int(response)) ] response = self.handle_push_response(response) - if not push_request: - return self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + + # if this is a push request return the push response + if push_request: return response + + return self._read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) else: raise InvalidResponse(f"Protocol Error: {raw!r}") if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) + return response diff --git a/redis/client.py b/redis/client.py index 0e05b6f542..0ec36c52d9 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1,4 +1,5 @@ import copy +import logging import re import threading import time @@ -56,6 +57,10 @@ WatchError, ) from redis.lock import Lock +from redis.maintenance_events import ( + MaintenanceEventPoolHandler, + MaintenanceEventsConfig, +) from redis.retry import Retry from redis.utils import ( _set_info_logger, @@ -244,6 +249,7 @@ def __init__( cache: Optional[CacheInterface] = None, cache_config: Optional[CacheConfig] = None, event_dispatcher: Optional[EventDispatcher] = None, + maintenance_events_config: Optional[MaintenanceEventsConfig] = None, ) -> None: """ Initialize a new Redis client. 
@@ -368,6 +374,23 @@ def __init__( ]: raise RedisError("Client caching is only supported with RESP version 3") + if maintenance_events_config and self.connection_pool.get_protocol() not in [ + 3, + "3", + ]: + raise RedisError( + "Push handlers on connection are only supported with RESP version 3" + ) + if maintenance_events_config and maintenance_events_config.enabled: + self.maintenance_events_pool_handler = MaintenanceEventPoolHandler( + self.connection_pool, maintenance_events_config + ) + self.connection_pool.set_maintenance_events_pool_handler( + self.maintenance_events_pool_handler + ) + else: + self.maintenance_events_pool_handler = None + self.single_connection_lock = threading.RLock() self.connection = None self._single_connection_client = single_connection_client @@ -565,8 +588,15 @@ def monitor(self): return Monitor(self.connection_pool) def client(self): + maintenance_events_config = ( + None + if self.maintenance_events_pool_handler is None + else self.maintenance_events_pool_handler.config + ) return self.__class__( - connection_pool=self.connection_pool, single_connection_client=True + connection_pool=self.connection_pool, + single_connection_client=True, + maintenance_events_config=maintenance_events_config, ) def __enter__(self): @@ -635,7 +665,14 @@ def _execute_command(self, *args, **options): ), lambda _: self._close_connection(conn), ) + finally: + if conn and conn.should_reconnect(): + logging.debug( + f"***** Redis reconnect before exit _execute_command --> notification for {conn._sock.getpeername()}" + ) + self._close_connection(conn) + conn.connect() if self._single_connection_client: self.single_connection_lock.release() if not self.connection: @@ -686,11 +723,7 @@ def __init__(self, connection_pool): self.connection = self.connection_pool.get_connection() def __enter__(self): - self.connection.send_command("MONITOR") - # check that monitor returns 'OK', but don't return it to user - response = self.connection.read_response() - if not 
bool_ok(response): - raise RedisError(f"MONITOR failed: {response}") + self._start_monitor() return self def __exit__(self, *args): @@ -700,8 +733,13 @@ def __exit__(self, *args): def next_command(self): """Parse the response from a monitor command""" response = self.connection.read_response() + + if response is None: + return None + if isinstance(response, bytes): response = self.connection.encoder.decode(response, force=True) + command_time, command_data = response.split(" ", 1) m = self.monitor_re.match(command_data) db_id, client_info, command = m.groups() @@ -737,6 +775,14 @@ def listen(self): while True: yield self.next_command() + def _start_monitor(self): + self.connection.send_command("MONITOR") + # check that monitor returns 'OK', but don't return it to user + response = self.connection.read_response() + + if not bool_ok(response): + raise RedisError(f"MONITOR failed: {response}") + class PubSub: """ @@ -881,7 +927,7 @@ def clean_health_check_responses(self) -> None: """ ttl = 10 conn = self.connection - while self.health_check_response_counter > 0 and ttl > 0: + while conn and self.health_check_response_counter > 0 and ttl > 0: if self._execute(conn, conn.can_read, timeout=conn.socket_timeout): response = self._execute(conn, conn.read_response) if self.is_health_check_response(response): @@ -911,10 +957,18 @@ def _execute(self, conn, command, *args, **kwargs): called by the # connection to resubscribe us to any channels and patterns we were previously listening to """ - return conn.retry.call_with_retry( + + response = conn.retry.call_with_retry( lambda: command(*args, **kwargs), lambda _: self._reconnect(conn), ) + if conn.should_reconnect(): + logging.debug( + f"***** PubSub --> Reconnect on notification for {conn._sock.getpeername()}" + ) + self._reconnect(conn) + + return response def parse_response(self, block=True, timeout=0): """Parse the response from a publish/subscribe command""" @@ -1148,6 +1202,7 @@ def handle_message(self, response, 
ignore_subscribe_messages=False): return None if isinstance(response, bytes): response = [b"pong", response] if response != b"PONG" else [b"pong", b""] + message_type = str_if_bytes(response[0]) if message_type == "pmessage": message = { @@ -1351,6 +1406,7 @@ def reset(self) -> None: # clean up the other instance attributes self.watching = False self.explicit_transaction = False + # we can safely return the connection to the pool here since we're # sure we're no longer WATCHing anything if self.connection: @@ -1510,6 +1566,7 @@ def _execute_transaction( if command_name in self.response_callbacks: r = self.response_callbacks[command_name](r, **options) data.append(r) + return data def _execute_pipeline(self, connection, commands, raise_on_error): @@ -1517,16 +1574,17 @@ def _execute_pipeline(self, connection, commands, raise_on_error): all_cmds = connection.pack_commands([args for args, _ in commands]) connection.send_packed_command(all_cmds) - response = [] + responses = [] for args, options in commands: try: - response.append(self.parse_response(connection, args[0], **options)) + responses.append(self.parse_response(connection, args[0], **options)) except ResponseError as e: - response.append(e) + responses.append(e) if raise_on_error: - self.raise_first_error(commands, response) - return response + self.raise_first_error(commands, responses) + + return responses def raise_first_error(self, commands, response): for i, r in enumerate(response): @@ -1611,6 +1669,8 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: lambda error: self._disconnect_raise_on_watching(conn, error), ) finally: + # in reset() the connection is diconnected before returned to the pool if + # it is marked for reconnect. 
self.reset() def discard(self): diff --git a/redis/connection.py b/redis/connection.py index 47cb589569..f55e1b455c 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,4 +1,5 @@ import copy +import logging import os import socket import sys @@ -19,10 +20,11 @@ CacheInterface, CacheKey, ) +from redis.typing import Number from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .auth.token import TokenInterface -from .backoff import NoBackoff +from .backoff import ExponentialWithJitterBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider from .event import AfterConnectionReleasedEvent, EventDispatcher from .exceptions import ( @@ -36,6 +38,11 @@ ResponseError, TimeoutError, ) +from .maintenance_events import ( + MaintenanceEventConnectionHandler, + MaintenanceEventPoolHandler, + MaintenanceEventsConfig, +) from .retry import Retry from .utils import ( CRYPTOGRAPHY_AVAILABLE, @@ -159,6 +166,10 @@ def deregister_connect_callback(self, callback): def set_parser(self, parser_class): pass + @abstractmethod + def set_maintenance_event_pool_handler(self, maintenance_event_pool_handler): + pass + @abstractmethod def get_protocol(self): pass @@ -222,6 +233,26 @@ def set_re_auth_token(self, token: TokenInterface): def re_auth(self): pass + @abstractmethod + def mark_for_reconnect(self): + pass + + @abstractmethod + def should_reconnect(self): + pass + + @abstractmethod + def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): + pass + + @abstractmethod + def update_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + pass + class AbstractConnection(ConnectionInterface): "Manages communication to and from a Redis server" @@ -250,6 +281,10 @@ def __init__( protocol: Optional[int] = 2, command_packer: Optional[Callable[[], None]] = None, event_dispatcher: Optional[EventDispatcher] = None, + maintenance_events_pool_handler: 
Optional[MaintenanceEventPoolHandler] = None, + maintenance_events_config: Optional[MaintenanceEventsConfig] = None, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = -1, ): """ Initialize a new Connection. @@ -288,16 +323,15 @@ def __init__( # Add TimeoutError to the errors list to retry on retry_on_error.append(TimeoutError) self.retry_on_error = retry_on_error - if retry or retry_on_error: - if retry is None: - self.retry = Retry(NoBackoff(), 1) - else: - # deep-copy the Retry object as it is mutable - self.retry = copy.deepcopy(retry) - # Update the retry's supported errors with the specified errors - self.retry.update_supported_errors(retry_on_error) + if retry is None: + self.retry = Retry( + backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 + ) else: - self.retry = Retry(NoBackoff(), 0) + # deep-copy the Retry object as it is mutable + self.retry = copy.deepcopy(retry) + if retry_on_error: + self.retry.update_supported_errors(retry_on_error) self.health_check_interval = health_check_interval self.next_health_check = 0 self.redis_connect_func = redis_connect_func @@ -305,7 +339,6 @@ def __init__( self.handshake_metadata = None self._sock = None self._socket_read_size = socket_read_size - self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 self._re_auth_token: Optional[TokenInterface] = None @@ -320,7 +353,26 @@ def __init__( raise ConnectionError("protocol must be either 2 or 3") # p = DEFAULT_RESP_VERSION self.protocol = p + if self.protocol == 3 and parser_class == DefaultParser: + parser_class = _RESP3Parser + self.set_parser(parser_class) + + if maintenance_events_config and maintenance_events_config.enabled: + if maintenance_events_pool_handler: + self._parser.set_node_moving_push_handler( + maintenance_events_pool_handler.handle_event + ) + self._maintenance_event_connection_handler = ( + MaintenanceEventConnectionHandler(self, maintenance_events_config) + ) + 
self._parser.set_maintenance_push_handler( + self._maintenance_event_connection_handler.handle_event + ) + self._command_packer = self._construct_command_packer(command_packer) + self._should_reconnect = False + self.tmp_host_address = tmp_host_address + self.tmp_relax_timeout = tmp_relax_timeout def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) @@ -375,6 +427,11 @@ def set_parser(self, parser_class): """ self._parser = parser_class(socket_read_size=self._socket_read_size) + def set_maintenance_event_pool_handler( + self, maintenance_event_pool_handler: MaintenanceEventPoolHandler + ): + self._parser.set_node_moving_push_handler(maintenance_event_pool_handler) + def connect(self): "Connects to the Redis server if not already connected" self.connect_check_health(check_health=True) @@ -549,6 +606,8 @@ def disconnect(self, *args): conn_sock = self._sock self._sock = None + # reset the reconnect flag + self._should_reconnect = False if conn_sock is None: return @@ -626,6 +685,7 @@ def can_read(self, timeout=0): try: return self._parser.can_read(timeout) + except OSError as e: self.disconnect() raise ConnectionError(f"Error while reading from {host_error}: {e.args}") @@ -732,6 +792,35 @@ def re_auth(self): self.read_response() self._re_auth_token = None + def mark_for_reconnect(self): + self._should_reconnect = True + + def should_reconnect(self): + return self._should_reconnect + + def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): + if self._sock: + timeout = relax_timeout if relax_timeout != -1 else self.socket_timeout + logging.debug( + f"***** Connection --> Updating timeout for {self._sock.getpeername()}" + f" to timeout {timeout}; relax_timeout: {relax_timeout}" + ) + self._sock.settimeout(timeout) + self._parser._buffer.socket_timeout = timeout + + def update_tmp_settings( + self, + tmp_host_address: Optional[str | object] = SENTINEL, + tmp_relax_timeout: Optional[float | object] = SENTINEL, + ): + 
""" + The value of SENTINEL is used to indicate that the property should not be updated. + """ + if tmp_host_address is not SENTINEL: + self.tmp_host_address = tmp_host_address + if tmp_relax_timeout is not SENTINEL: + self.tmp_relax_timeout = tmp_relax_timeout + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -764,8 +853,14 @@ def _connect(self): # ipv4/ipv6, but we want to set options prior to calling # socket.connect() err = None + if self.tmp_host_address is not None: + logging.debug( + f"***** Connection --> Using tmp_host_address: {self.tmp_host_address}" + ) + host = self.tmp_host_address or self.host + for res in socket.getaddrinfo( - self.host, self.port, self.socket_type, socket.SOCK_STREAM + host, self.port, self.socket_type, socket.SOCK_STREAM ): family, socktype, proto, canonname, socket_address = res sock = None @@ -781,13 +876,32 @@ def _connect(self): sock.setsockopt(socket.IPPROTO_TCP, k, v) # set the socket_connect_timeout before we connect - sock.settimeout(self.socket_connect_timeout) + if self.tmp_relax_timeout != -1: + logging.debug( + f"***** Connection connect --> Using relax_timeout: {self.tmp_relax_timeout}" + ) + sock.settimeout(self.tmp_relax_timeout) + else: + logging.debug( + f"***** Connection connect --> Using default socket_connect_timeout: {self.socket_connect_timeout}" + ) + sock.settimeout(self.socket_connect_timeout) # connect sock.connect(socket_address) # set the socket_timeout now that we're connected - sock.settimeout(self.socket_timeout) + if self.tmp_relax_timeout != -1: + logging.debug( + f"***** Connection --> Using relax_timeout: {self.tmp_relax_timeout}" + ) + sock.settimeout(self.tmp_relax_timeout) + else: + logging.debug( + f"***** Connection --> Using default socket_timeout: {self.socket_timeout}" + ) + sock.settimeout(self.socket_timeout) + logging.debug(f"Connected to {sock.getpeername()}") return sock except OSError as _: @@ -1415,6 +1529,14 @@ def __init__( 
connection_kwargs.pop("cache", None) connection_kwargs.pop("cache_config", None) + if connection_kwargs.get( + "maintenance_events_pool_handler" + ) or connection_kwargs.get("maintenance_events_config"): + if connection_kwargs.get("protocol") not in [3, "3"]: + raise RedisError( + "Push handlers on connection are only supported with RESP version 3" + ) + self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) if self._event_dispatcher is None: self._event_dispatcher = EventDispatcher() @@ -1449,6 +1571,46 @@ def get_protocol(self): """ return self.connection_kwargs.get("protocol", None) + def maintenance_events_pool_handler_enabled(self): + """ + Returns: + True if the maintenance events pool handler is enabled, False otherwise. + """ + maintenance_events_config = self.connection_kwargs.get( + "maintenance_events_config", False + ) + + return maintenance_events_config and maintenance_events_config.enabled + + def set_maintenance_events_pool_handler( + self, maintenance_events_pool_handler: MaintenanceEventPoolHandler + ): + self.connection_kwargs.update( + { + "maintenance_events_pool_handler": maintenance_events_pool_handler, + "maintenance_events_config": maintenance_events_pool_handler.config, + } + ) + + self._update_maintenance_events_configs_for_connections( + maintenance_events_pool_handler + ) + + def _update_maintenance_events_configs_for_connections( + self, maintenance_events_pool_handler + ): + with self._lock: + for conn in self._available_connections: + conn.set_maintenance_events_pool_handler( + maintenance_events_pool_handler + ) + conn.maintenance_events_config = maintenance_events_pool_handler.config + for conn in self._in_use_connections: + conn.set_maintenance_events_pool_handler( + maintenance_events_pool_handler + ) + conn.maintenance_events_config = maintenance_events_pool_handler.config + def reset(self) -> None: self._created_connections = 0 self._available_connections = [] @@ -1536,7 +1698,11 @@ def 
get_connection(self, command_name=None, *keys, **options) -> "Connection": # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. try: - if connection.can_read() and self.cache is None: + if ( + connection.can_read() + and self.cache is None + and not self.maintenance_events_pool_handler_enabled() + ): raise ConnectionError("Connection has data") except (ConnectionError, TimeoutError, OSError): connection.disconnect() @@ -1548,7 +1714,6 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": # leak it self.release(connection) raise - return connection def get_encoder(self) -> Encoder: @@ -1570,7 +1735,6 @@ def make_connection(self) -> "ConnectionInterface": return CacheProxyConnection( self.connection_class(**self.connection_kwargs), self.cache, self._lock ) - return self.connection_class(**self.connection_kwargs) def release(self, connection: "Connection") -> None: @@ -1585,6 +1749,11 @@ def release(self, connection: "Connection") -> None: return if self.owns_connection(connection): + if connection.should_reconnect(): + logging.debug( + f"***** Pool--> disconnecting in release {connection._sock.getpeername()}" + ) + connection.disconnect() self._available_connections.append(connection) self._event_dispatcher.dispatch( AfterConnectionReleasedEvent(connection) @@ -1646,6 +1815,154 @@ def re_auth_callback(self, token: TokenInterface): for conn in self._in_use_connections: conn.set_re_auth_token(token) + def update_connection_kwargs_with_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Update the connection kwargs with the temporary host address and the + relax timeout(if enabled). + This is used when a cluster node is rebind to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. 
+ This new address will be used to create new connections until the old node is decomissioned. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + If -1 is provided - the relax timeout is disabled, so the tmp property is not set + """ + self.connection_kwargs.update({"tmp_host_address": tmp_host_address}) + self.connection_kwargs.update({"tmp_relax_timeout": tmp_relax_timeout}) + + def update_connections_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Update the tmp settings for all connections in the pool. + This is used when a cluster node is rebind to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + """ + with self._lock: + for conn in self._available_connections: + self._update_connection_tmp_settings( + conn, tmp_host_address, tmp_relax_timeout + ) + for conn in self._in_use_connections: + self._update_connection_tmp_settings( + conn, tmp_host_address, tmp_relax_timeout + ) + + def update_active_connections_for_reconnect( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Mark all active connections for reconnect. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. 
+ """ + for conn in self._in_use_connections: + self._update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def disconnect_and_reconfigure_free_connections( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Disconnect all free/available connections. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + """ + + for conn in self._available_connections: + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def update_connections_current_timeout( + self, + relax_timeout: Optional[float], + include_available_connections: bool = False, + ): + """ + Update the timeout either for all connections in the pool or just for the ones in use. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param relax_timeout: The relax timeout to use for the connection. + If -1 is provided - the relax timeout is disabled. + :param include_available_connections: Whether to include available connections in the update. + """ + logging.debug(f"***** Pool --> Updating timeouts. 
New value: {relax_timeout}") + start_time = time.time() + + for conn in self._in_use_connections: + self._update_connection_timeout(conn, relax_timeout) + + if include_available_connections: + for conn in self._available_connections: + self._update_connection_timeout(conn, relax_timeout) + + execution_time_us = (time.time() - start_time) * 1000000 + logging.error( + f"###### TIMEOUTS execution time: {execution_time_us:.0f} microseconds" + ) + + def _update_connection_for_reconnect( + self, + connection: "Connection", + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + connection.mark_for_reconnect() + self._update_connection_tmp_settings( + connection, tmp_host_address, tmp_relax_timeout + ) + + def _disconnect_and_update_connection_for_reconnect( + self, + connection: "Connection", + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + connection.disconnect() + self._update_connection_tmp_settings( + connection, tmp_host_address, tmp_relax_timeout + ) + + def _update_connection_tmp_settings( + self, + connection: "Connection", + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + connection.tmp_host_address = tmp_host_address + connection.tmp_relax_timeout = tmp_relax_timeout + + def _update_connection_timeout( + self, connection: "Connection", relax_timeout: Optional[Number] + ): + connection.update_current_socket_timeout(relax_timeout) + async def _mock(self, error: RedisError): """ Dummy functions, needs to be passed as error callback to retry object. @@ -1707,16 +2024,17 @@ def __init__( def reset(self): # Create and fill up a thread safe queue with ``None`` values. 
- self.pool = self.queue_class(self.max_connections) - while True: - try: - self.pool.put_nowait(None) - except Full: - break + with self._lock: + self.pool = self.queue_class(self.max_connections) + while True: + try: + self.pool.put_nowait(None) + except Full: + break - # Keep a list of actual connection instances so that we can - # disconnect them later. - self._connections = [] + # Keep a list of actual connection instances so that we can + # disconnect them later. + self._connections = [] # this must be the last operation in this method. while reset() is # called when holding _fork_lock, other threads in this process @@ -1731,14 +2049,18 @@ def reset(self): def make_connection(self): "Make a fresh connection." - if self.cache is not None: - connection = CacheProxyConnection( - self.connection_class(**self.connection_kwargs), self.cache, self._lock - ) - else: - connection = self.connection_class(**self.connection_kwargs) - self._connections.append(connection) - return connection + with self._lock: + if self.cache is not None: + connection = CacheProxyConnection( + self.connection_class(**self.connection_kwargs), + self.cache, + self._lock, + ) + else: + connection = self.connection_class(**self.connection_kwargs) + + self._connections.append(connection) + return connection @deprecated_args( args_to_warn=["*"], @@ -1763,17 +2085,18 @@ def get_connection(self, command_name=None, *keys, **options): # Try and get a connection from the pool. If one isn't available within # self.timeout then raise a ``ConnectionError``. connection = None - try: - connection = self.pool.get(block=True, timeout=self.timeout) - except Empty: - # Note that this is not caught by the redis client and will be - # raised unless handled by application code. If you want never to - raise ConnectionError("No connection available.") - - # If the ``connection`` is actually ``None`` then that's a cue to make - # a new connection to add to the pool. 
- if connection is None: - connection = self.make_connection() + with self._lock: + try: + connection = self.pool.get(block=True, timeout=self.timeout) + except Empty: + # Note that this is not caught by the redis client and will be + # raised unless handled by application code. If you want never to + raise ConnectionError("No connection available.") + + # If the ``connection`` is actually ``None`` then that's a cue to make + # a new connection to add to the pool. + if connection is None: + connection = self.make_connection() try: # ensure this connection is connected to Redis @@ -1801,25 +2124,88 @@ def release(self, connection): "Releases the connection back to the pool." # Make sure we haven't changed process. self._checkpid() - if not self.owns_connection(connection): - # pool doesn't own this connection. do not add it back - # to the pool. instead add a None value which is a placeholder - # that will cause the pool to recreate the connection if - # its needed. - connection.disconnect() - self.pool.put_nowait(None) - return - # Put the connection back into the pool. - try: - self.pool.put_nowait(connection) - except Full: - # perhaps the pool has been reset() after a fork? regardless, - # we don't want this connection - pass + with self._lock: + if not self.owns_connection(connection): + # pool doesn't own this connection. do not add it back + # to the pool. instead add a None value which is a placeholder + # that will cause the pool to recreate the connection if + # its needed. + connection.disconnect() + self.pool.put_nowait(None) + return + if connection.should_reconnect(): + logging.debug( + f"***** Blocking Pool--> disconnecting in release {connection._sock.getpeername()}" + ) + connection.disconnect() + # Put the connection back into the pool. + try: + self.pool.put_nowait(connection) + except Full: + # perhaps the pool has been reset() after a fork? 
regardless, + # we don't want this connection + pass def disconnect(self): "Disconnects all connections in the pool." self._checkpid() - for connection in self._connections: - connection.disconnect() + with self._lock: + for connection in self._connections: + connection.disconnect() + + def update_active_connections_for_reconnect( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + with self._lock: + connections_in_queue = {conn for conn in self.pool.queue if conn} + for conn in self._connections: + if conn not in connections_in_queue: + if tmp_relax_timeout != -1: + conn.update_socket_timeout(tmp_relax_timeout) + self._update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def disconnect_and_reconfigure_free_connections( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[Number] = None, + ): + with self._lock: + existing_connections = self.pool.queue + + for conn in existing_connections: + if conn: + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def update_connections_current_timeout(self, relax_timeout: Optional[float] = None): + logging.debug( + f"***** Blocking Pool --> Updating timeouts. 
relax_timeout: {relax_timeout}" + ) + + with self._lock: + for conn in tuple(self._connections): + self._update_connection_timeout(conn, relax_timeout) + + def update_connections_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + with self._lock: + for conn in tuple(self._connections): + self._update_connection_tmp_settings( + conn, tmp_host_address, tmp_relax_timeout + ) + + def _update_maintenance_events_config_for_connections( + self, maintenance_events_config + ): + with self._lock: + for conn in tuple(self._connections): + conn.maintenance_events_config = maintenance_events_config diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py new file mode 100644 index 0000000000..bbc519d0cc --- /dev/null +++ b/redis/maintenance_events.py @@ -0,0 +1,349 @@ +import logging +import threading +import time +from typing import TYPE_CHECKING + +from redis.typing import Number + +if TYPE_CHECKING: + from redis.connection import ConnectionInterface, ConnectionPool + + +class MaintenanceEvent: + """ + Base class for maintenance events sent through push messages by Redis server. + + This class provides common TTL (Time-To-Live) functionality for all + maintenance events. + + Attributes: + ttl (int): Time-to-live in seconds for this notification + creation_time (float): Timestamp when the notification was created/read + """ + + def __init__(self, ttl: int): + """ + Initialize a new MaintenanceEvent with TTL functionality. + + Args: + ttl (int): Time-to-live in seconds for this notification + """ + self.ttl = ttl + self.creation_time = int(time.time()) + self.expire_at = self.creation_time + self.ttl + + def is_expired(self) -> bool: + """ + Check if this event has expired based on its TTL + and creation time. 
+ + Returns: + bool: True if the event has expired, False otherwise + """ + return int(time.time()) > (self.creation_time + self.ttl) + + +class NodeMovingEvent(MaintenanceEvent): + """ + This event is received when a node is replaced with a new node + during cluster rebalancing or maintenance operations. + """ + + def __init__(self, new_node_host: str, new_node_port: int, ttl: int): + """ + Initialize a new NodeMovingEvent. + + Args: + new_node_host (str): Hostname or IP address of the new replacement node + new_node_port (int): Port number of the new replacement node + ttl (int): Time-to-live in seconds for this notification + """ + super().__init__(ttl) + self.new_node_host = new_node_host + self.new_node_port = new_node_port + + def __repr__(self) -> str: + expiry_time = self.expire_at + remaining = max(0, expiry_time - time.time()) + + return ( + f"{self.__class__.__name__}(" + f"new_node_host='{self.new_node_host}', " + f"new_node_port={self.new_node_port}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeMovingEvent events are considered equal if they have the same + new_node_host and new_node_port. + """ + if not isinstance(other, NodeMovingEvent): + return False + return ( + self.new_node_host == other.new_node_host + and self.new_node_port == other.new_node_port + ) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on new_node_host and new_node_port + """ + return hash((self.__class__, self.new_node_host, self.new_node_port)) + + +class NodeMigratingEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node is in the process of migrating slots. 
+ + This event is received when a node starts migrating its slots to another node + during cluster rebalancing or maintenance operations. + + Args: + ttl (int): Time-to-live in seconds for this notification + """ + + def __init__(self, ttl: int): + super().__init__(ttl) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.time()) + return ( + f"{self.__class__.__name__}(" + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + +class NodeMigratedEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node has completed migrating slots. + + This event is received when a node has finished migrating all its slots + to other nodes during cluster rebalancing or maintenance operations. + + Args: + ttl (int): Time-to-live in seconds for this notification + """ + + DEFAULT_TTL = 5 + + def __init__(self): + super().__init__(NodeMigratedEvent.DEFAULT_TTL) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.time()) + return ( + f"{self.__class__.__name__}(" + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + +class MaintenanceEventsConfig: + """ + Configuration class for maintenance events handling behaviour. Events are received through + push notifications. + + This class defines how the Redis client should react to different push notifications + such as node moving, migrations, etc. in a Redis cluster. + + """ + + def __init__( + self, + enabled: bool = False, + proactive_reconnect: bool = True, + relax_timeout: Number = 20, + ): + """ + Initialize a new MaintenanceEventsConfig. + + Args: + enabled (bool): Whether to enable maintenance events handling. + Defaults to False. 
+ proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced. + Defaults to True. + relax_timeout (Number): The relax timeout to use for the connection during maintenance. + If -1 is provided - the relax timeout is disabled. Defaults to 20. + + """ + self.enabled = enabled + self.relax_timeout = relax_timeout + self.proactive_reconnect = proactive_reconnect + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"enabled={self.enabled}, " + f"proactive_reconnect={self.proactive_reconnect}, " + f"relax_timeout={self.relax_timeout}, " + f")" + ) + + def is_relax_timeouts_enabled(self) -> bool: + """ + Check if the relax_timeout is enabled. The '-1' value is used to disable the relax_timeout. + If relax_timeout is set to None, it will make the operation blocking + and waiting until any response is received. + + Returns: + True if the relax_timeout is enabled, False otherwise. + """ + return self.relax_timeout != -1 + + +class MaintenanceEventPoolHandler: + def __init__(self, pool: "ConnectionPool", config: MaintenanceEventsConfig) -> None: + self.pool = pool + self.config = config + self._processed_events = set() + self._lock = threading.RLock() + + def remove_expired_notifications(self): + with self._lock: + for notification in tuple(self._processed_events): + if notification.is_expired(): + self._processed_events.remove(notification) + + def handle_event(self, notification: MaintenanceEvent): + self.remove_expired_notifications() + + if isinstance(notification, NodeMovingEvent): + return self.handle_node_moving_event(notification) + else: + logging.error(f"Unhandled notification type: {notification}") + + def handle_node_moved_event(self): + with self._lock: + self.pool.update_connection_kwargs_with_tmp_settings( + tmp_host_address=None, + tmp_relax_timeout=-1, + ) + with self.pool._lock: + if self.config.is_relax_timeouts_enabled(): + # reset the timeout for existing connections + 
self.pool.update_connections_current_timeout( + relax_timeout=-1, include_available_connections=True + ) + logging.debug("***** MOVING END--> TIMEOUTS RESET") + + self.pool.update_connections_tmp_settings( + tmp_host_address=None, tmp_relax_timeout=-1 + ) + logging.debug("***** MOVING END--> TMP SETTINGS ADDRESS RESET") + + def handle_node_moving_event(self, event: NodeMovingEvent): + if ( + not self.config.proactive_reconnect + and not self.config.is_relax_timeouts_enabled() + ): + return + with self._lock: + if event in self._processed_events: + # nothing to do in the connection pool handling + # the event has already been handled or is expired + # just return + logging.debug("***** MOVING --> SKIPPED DONE") + return + + logging.info(f"***** MOVING --> {event}") + logging.info(f"***** MOVING --> set: {self._processed_events}") + start_time = time.time() + + with self.pool._lock: + if ( + self.config.proactive_reconnect + or self.config.is_relax_timeouts_enabled() + ): + # edit the config for new connections until the notification expires + self.pool.update_connection_kwargs_with_tmp_settings( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + ) + if self.config.is_relax_timeouts_enabled(): + # extend the timeout for all connections that are currently in use + self.pool.update_connections_current_timeout( + self.config.relax_timeout + ) + if self.config.proactive_reconnect: + # take care for the active connections in the pool + # mark them for reconnect after they complete the current command + self.pool.update_active_connections_for_reconnect( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + ) + + # take care for the inactive connections in the pool + # delete them and create new ones + start_time_2 = time.time() + self.pool.disconnect_and_reconfigure_free_connections( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + ) + execution_time_us = 
(time.time() - start_time_2) * 1000000 + logging.error( + f"###### MOVING disconnects execution time: {execution_time_us:.0f} microseconds" + ) + + threading.Timer(event.ttl, self.handle_node_moved_event).start() + + self._processed_events.add(event) + execution_time_us = (time.time() - start_time) * 1000000 + logging.error( + f"###### MOVING total execution time: {execution_time_us:.0f} microseconds" + ) + + +class MaintenanceEventConnectionHandler: + def __init__( + self, connection: "ConnectionInterface", config: MaintenanceEventsConfig + ) -> None: + self.connection = connection + self.config = config + + def handle_event(self, event: MaintenanceEvent): + if isinstance(event, NodeMigratingEvent): + return self.handle_migrating_event(event) + elif isinstance(event, NodeMigratedEvent): + return self.handle_migration_completed_event(event) + else: + logging.error(f"Unhandled event type: {event}") + + def handle_migrating_event(self, notification: NodeMigratingEvent): + if not self.config.is_relax_timeouts_enabled(): + return + + logging.info(f"***** MIGRATING --> {notification}") + # extend the timeout for all created connections + self.connection.update_current_socket_timeout(self.config.relax_timeout) + self.connection.update_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) + + def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): + if not self.config.is_relax_timeouts_enabled(): + return + + logging.info(f"***** MIGRATED --> {notification}") + # Node migration completed - reset the connection + # timeouts by providing -1 as the relax timeout + self.connection.update_current_socket_timeout(-1) + self.connection.update_tmp_settings(tmp_relax_timeout=-1) From 41a199e8232f622fa6f6ebb898d4043cb8293bea Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 11 Jul 2025 11:07:50 +0300 Subject: [PATCH 02/16] Adding sequence id to the maintenance push notifications. 
Adding unit tests for maintenance_events.py file --- redis/_parsers/base.py | 15 +- redis/maintenance_events.py | 136 ++++++-- tests/test_maintenance_events.py | 543 +++++++++++++++++++++++++++++++ 3 files changed, 665 insertions(+), 29 deletions(-) create mode 100644 tests/test_maintenance_events.py diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index a0f6af4ac2..aa5a6b0f12 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -205,14 +205,17 @@ def handle_push_response(self, response, **kwargs): if msg_type in _MOVING_MESSAGE: host, port = response[2].split(":") ttl = response[1] - notification = NodeMovingEvent(host, port, ttl) + id = 1 # Hardcoded value for sync parser + notification = NodeMovingEvent(id, host, port, ttl) return self.node_moving_push_handler_func(notification) if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: if msg_type in _MIGRATING_MESSAGE: ttl = response[1] - notification = NodeMigratingEvent(ttl) + id = 2 # Hardcoded value for sync parser + notification = NodeMigratingEvent(id, ttl) elif msg_type in _MIGRATED_MESSAGE: - notification = NodeMigratedEvent() + id = 3 # Hardcoded value for sync parser + notification = NodeMigratedEvent(id) else: notification = None if notification is not None: @@ -260,16 +263,16 @@ async def handle_push_response(self, response, **kwargs): # push notification from enterprise cluster for node moving host, port = response[2].split(":") ttl = response[1] - id = 1 # TODO: get unique id from push notification + id = 1 # Hardcoded value for async parser notification = NodeMovingEvent(id, host, port, ttl) return await self.node_moving_push_handler_func(notification) if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: if msg_type in _MIGRATING_MESSAGE: ttl = response[1] - id = 1 # TODO: get unique id from push notification + id = 2 # Hardcoded value for async parser notification = NodeMigratingEvent(id, ttl) elif msg_type in 
_MIGRATED_MESSAGE: - id = 1 # TODO: get unique id from push notification + id = 3 # Hardcoded value for async parser notification = NodeMigratedEvent(id) return await self.maintenance_push_handler_func(notification) diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index bbc519d0cc..d818a846b8 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -1,7 +1,8 @@ import logging import threading import time -from typing import TYPE_CHECKING +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Optional from redis.typing import Number @@ -9,27 +10,30 @@ from redis.connection import ConnectionInterface, ConnectionPool -class MaintenanceEvent: +class MaintenanceEvent(ABC): """ Base class for maintenance events sent through push messages by Redis server. - This class provides common TTL (Time-To-Live) functionality for all - maintenance events. + This class provides common functionality for all maintenance events including + unique identification and TTL (Time-To-Live) functionality. Attributes: + id (int): Unique identifier for this event ttl (int): Time-to-live in seconds for this notification creation_time (float): Timestamp when the notification was created/read """ - def __init__(self, ttl: int): + def __init__(self, id: int, ttl: int): """ - Initialize a new MaintenanceEvent with TTL functionality. + Initialize a new MaintenanceEvent with unique ID and TTL functionality. 
Args: + id (int): Unique identifier for this event ttl (int): Time-to-live in seconds for this notification """ + self.id = id self.ttl = ttl - self.creation_time = int(time.time()) + self.creation_time = time.monotonic() self.expire_at = self.creation_time + self.ttl def is_expired(self) -> bool: @@ -40,7 +44,49 @@ def is_expired(self) -> bool: Returns: bool: True if the event has expired, False otherwise """ - return int(time.time()) > (self.creation_time + self.ttl) + return time.monotonic() > (self.creation_time + self.ttl) + + @abstractmethod + def __repr__(self) -> str: + """ + Return a string representation of the maintenance event. + + This method must be implemented by all concrete subclasses. + + Returns: + str: String representation of the event + """ + pass + + @abstractmethod + def __eq__(self, other) -> bool: + """ + Compare two maintenance events for equality. + + This method must be implemented by all concrete subclasses. + Events are typically considered equal if they have the same id + and are of the same type. + + Args: + other: The other object to compare with + + Returns: + bool: True if the events are equal, False otherwise + """ + pass + + @abstractmethod + def __hash__(self) -> int: + """ + Return a hash value for the maintenance event. + + This method must be implemented by all concrete subclasses to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value for the event + """ + pass class NodeMovingEvent(MaintenanceEvent): @@ -49,25 +95,27 @@ class NodeMovingEvent(MaintenanceEvent): during cluster rebalancing or maintenance operations. """ - def __init__(self, new_node_host: str, new_node_port: int, ttl: int): + def __init__(self, id: int, new_node_host: str, new_node_port: int, ttl: int): """ Initialize a new NodeMovingEvent. 
Args: + id (int): Unique identifier for this event new_node_host (str): Hostname or IP address of the new replacement node new_node_port (int): Port number of the new replacement node ttl (int): Time-to-live in seconds for this notification """ - super().__init__(ttl) + super().__init__(id, ttl) self.new_node_host = new_node_host self.new_node_port = new_node_port def __repr__(self) -> str: expiry_time = self.expire_at - remaining = max(0, expiry_time - time.time()) + remaining = max(0, expiry_time - time.monotonic()) return ( f"{self.__class__.__name__}(" + f"id={self.id}, " f"new_node_host='{self.new_node_host}', " f"new_node_port={self.new_node_port}, " f"ttl={self.ttl}, " @@ -81,12 +129,13 @@ def __repr__(self) -> str: def __eq__(self, other) -> bool: """ Two NodeMovingEvent events are considered equal if they have the same - new_node_host and new_node_port. + id, new_node_host, and new_node_port. """ if not isinstance(other, NodeMovingEvent): return False return ( - self.new_node_host == other.new_node_host + self.id == other.id + and self.new_node_host == other.new_node_host and self.new_node_port == other.new_node_port ) @@ -96,9 +145,9 @@ def __hash__(self) -> int: instances to be used in sets and as dictionary keys. Returns: - int: Hash value based on new_node_host and new_node_port + int: Hash value based on event type, id, new_node_host, and new_node_port """ - return hash((self.__class__, self.new_node_host, self.new_node_port)) + return hash((self.__class__, self.id, self.new_node_host, self.new_node_port)) class NodeMigratingEvent(MaintenanceEvent): @@ -109,17 +158,19 @@ class NodeMigratingEvent(MaintenanceEvent): during cluster rebalancing or maintenance operations. 
Args: + id (int): Unique identifier for this event ttl (int): Time-to-live in seconds for this notification """ - def __init__(self, ttl: int): - super().__init__(ttl) + def __init__(self, id: int, ttl: int): + super().__init__(id, ttl) def __repr__(self) -> str: expiry_time = self.creation_time + self.ttl - remaining = max(0, expiry_time - time.time()) + remaining = max(0, expiry_time - time.monotonic()) return ( f"{self.__class__.__name__}(" + f"id={self.id}, " f"ttl={self.ttl}, " f"creation_time={self.creation_time}, " f"expires_at={expiry_time}, " @@ -128,6 +179,25 @@ def __repr__(self) -> str: f")" ) + def __eq__(self, other) -> bool: + """ + Two NodeMigratingEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeMigratingEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + class NodeMigratedEvent(MaintenanceEvent): """ @@ -137,19 +207,20 @@ class NodeMigratedEvent(MaintenanceEvent): to other nodes during cluster rebalancing or maintenance operations. 
Args: - ttl (int): Time-to-live in seconds for this notification + id (int): Unique identifier for this event """ DEFAULT_TTL = 5 - def __init__(self): - super().__init__(NodeMigratedEvent.DEFAULT_TTL) + def __init__(self, id: int): + super().__init__(id, NodeMigratedEvent.DEFAULT_TTL) def __repr__(self) -> str: expiry_time = self.creation_time + self.ttl - remaining = max(0, expiry_time - time.time()) + remaining = max(0, expiry_time - time.monotonic()) return ( f"{self.__class__.__name__}(" + f"id={self.id}, " f"ttl={self.ttl}, " f"creation_time={self.creation_time}, " f"expires_at={expiry_time}, " @@ -158,6 +229,25 @@ def __repr__(self) -> str: f")" ) + def __eq__(self, other) -> bool: + """ + Two NodeMigratedEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeMigratedEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + class MaintenanceEventsConfig: """ @@ -173,7 +263,7 @@ def __init__( self, enabled: bool = False, proactive_reconnect: bool = True, - relax_timeout: Number = 20, + relax_timeout: Optional[Number] = 20, ): """ Initialize a new MaintenanceEventsConfig. 
diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py new file mode 100644 index 0000000000..69a6014fe1 --- /dev/null +++ b/tests/test_maintenance_events.py @@ -0,0 +1,543 @@ +import threading +from unittest.mock import Mock, patch + +from redis.maintenance_events import ( + MaintenanceEvent, + NodeMovingEvent, + NodeMigratingEvent, + NodeMigratedEvent, + MaintenanceEventsConfig, + MaintenanceEventPoolHandler, + MaintenanceEventConnectionHandler, +) + + +class TestMaintenanceEvent: + """Test the base MaintenanceEvent class functionality through concrete subclasses.""" + + def test_abstract_class_cannot_be_instantiated(self): + """Test that MaintenanceEvent cannot be instantiated directly.""" + import pytest + + with patch("time.monotonic", return_value=1000): + with pytest.raises(TypeError): + MaintenanceEvent(id=1, ttl=10) # type: ignore + + def test_init_through_subclass(self): + """Test MaintenanceEvent initialization through concrete subclass.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event.id == 1 + assert event.ttl == 10 + assert event.creation_time == 1000 + assert event.expire_at == 1010 + + def test_is_expired_false(self): + """Test is_expired returns False for non-expired event.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1005): # 5 seconds later + assert not event.is_expired() + + def test_is_expired_true(self): + """Test is_expired returns True for expired event.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1015): # 15 seconds later + assert event.is_expired() + + def test_is_expired_exact_boundary(self): + """Test 
is_expired at exact expiration boundary.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1010): # Exactly at expiration + assert not event.is_expired() + + with patch("time.monotonic", return_value=1011): # 1 second past expiration + assert event.is_expired() + + +class TestNodeMovingEvent: + """Test the NodeMovingEvent class.""" + + def test_init(self): + """Test NodeMovingEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event.id == 1 + assert event.new_node_host == "localhost" + assert event.new_node_port == 6379 + assert event.ttl == 10 + assert event.creation_time == 1000 + + def test_repr(self): + """Test NodeMovingEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1005): # 5 seconds later + repr_str = repr(event) + assert "NodeMovingEvent" in repr_str + assert "id=1" in repr_str + assert "new_node_host='localhost'" in repr_str + assert "new_node_port=6379" in repr_str + assert "ttl=10" in repr_str + assert "remaining=5.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_same_id_host_port(self): + """Test equality for events with same id, host, and port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Different TTL + assert event1 == event2 + + def test_equality_same_id_different_host(self): + """Test inequality for events with same id but different host.""" + event1 = NodeMovingEvent( + id=1, new_node_host="host1", new_node_port=6379, ttl=10 + ) + event2 = 
NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6379, ttl=10 + ) + assert event1 != event2 + + def test_equality_same_id_different_port(self): + """Test inequality for events with same id but different port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6380, ttl=10 + ) + assert event1 != event2 + + def test_equality_different_id(self): + """Test inequality for events with different id.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event1 != event2 + + def test_equality_different_type(self): + """Test inequality for events of different types.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMigratingEvent(id=1, ttl=10) + assert event1 != event2 + + def test_hash_same_id_host_port(self): + """Test hash consistency for events with same id, host, and port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Different TTL + assert hash(event1) == hash(event2) + + def test_hash_different_host(self): + """Test hash difference for events with different host.""" + event1 = NodeMovingEvent( + id=1, new_node_host="host1", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6379, ttl=10 + ) + assert hash(event1) != hash(event2) + + def test_hash_different_port(self): + """Test hash difference for events with different port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6380, ttl=10 + ) + assert hash(event1) 
!= hash(event2) + + def test_hash_different_id(self): + """Test hash difference for events with different id.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert hash(event1) != hash(event2) + + def test_set_functionality(self): + """Test that events can be used in sets correctly.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Same id, host, port - should be considered the same + event3 = NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6380, ttl=10 + ) # Same id but different host/port - should be different + event4 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) # Different id - should be different + + event_set = {event1, event2, event3, event4} + assert len(event_set) == 3 # event1 and event2 should be considered the same + + +class TestNodeMigratingEvent: + """Test the NodeMigratingEvent class.""" + + def test_init(self): + """Test NodeMigratingEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratingEvent(id=1, ttl=5) + assert event.id == 1 + assert event.ttl == 5 + assert event.creation_time == 1000 + + def test_repr(self): + """Test NodeMigratingEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratingEvent(id=1, ttl=5) + + with patch("time.monotonic", return_value=1002): # 2 seconds later + repr_str = repr(event) + assert "NodeMigratingEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=3.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeMigratingEvent.""" + event1 = NodeMigratingEvent(id=1, ttl=5) + event2 = 
NodeMigratingEvent(id=1, ttl=10) # Same id, different ttl + event3 = NodeMigratingEvent(id=2, ttl=5) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + +class TestNodeMigratedEvent: + """Test the NodeMigratedEvent class.""" + + def test_init(self): + """Test NodeMigratedEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratedEvent(id=1) + assert event.id == 1 + assert event.ttl == NodeMigratedEvent.DEFAULT_TTL + assert event.creation_time == 1000 + + def test_default_ttl(self): + """Test that DEFAULT_TTL is used correctly.""" + assert NodeMigratedEvent.DEFAULT_TTL == 5 + event = NodeMigratedEvent(id=1) + assert event.ttl == 5 + + def test_repr(self): + """Test NodeMigratedEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratedEvent(id=1) + + with patch("time.monotonic", return_value=1001): # 1 second later + repr_str = repr(event) + assert "NodeMigratedEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=4.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeMigratedEvent.""" + event1 = NodeMigratedEvent(id=1) + event2 = NodeMigratedEvent(id=1) # Same id + event3 = NodeMigratedEvent(id=2) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + +class TestMaintenanceEventsConfig: + """Test the MaintenanceEventsConfig class.""" + + def test_init_defaults(self): + """Test MaintenanceEventsConfig initialization with defaults.""" + config = MaintenanceEventsConfig() + assert config.enabled is False + assert config.proactive_reconnect is True + assert config.relax_timeout == 20 + + def test_init_custom_values(self): + """Test MaintenanceEventsConfig initialization with custom 
values.""" + config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=False, relax_timeout=30 + ) + assert config.enabled is True + assert config.proactive_reconnect is False + assert config.relax_timeout == 30 + + def test_repr(self): + """Test MaintenanceEventsConfig string representation.""" + config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=False, relax_timeout=30 + ) + repr_str = repr(config) + assert "MaintenanceEventsConfig" in repr_str + assert "enabled=True" in repr_str + assert "proactive_reconnect=False" in repr_str + assert "relax_timeout=30" in repr_str + + def test_is_relax_timeouts_enabled_true(self): + """Test is_relax_timeouts_enabled returns True for positive timeout.""" + config = MaintenanceEventsConfig(relax_timeout=20) + assert config.is_relax_timeouts_enabled() is True + + def test_is_relax_timeouts_enabled_false(self): + """Test is_relax_timeouts_enabled returns False for -1 timeout.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + assert config.is_relax_timeouts_enabled() is False + + def test_is_relax_timeouts_enabled_zero(self): + """Test is_relax_timeouts_enabled returns True for zero timeout.""" + config = MaintenanceEventsConfig(relax_timeout=0) + assert config.is_relax_timeouts_enabled() is True + + def test_is_relax_timeouts_enabled_none(self): + """Test is_relax_timeouts_enabled returns True for None timeout.""" + config = MaintenanceEventsConfig(relax_timeout=None) + assert config.is_relax_timeouts_enabled() is True + + def test_relax_timeout_none_is_saved_as_none(self): + """Test that None value for relax_timeout is saved as None.""" + config = MaintenanceEventsConfig(relax_timeout=None) + assert config.relax_timeout is None + + +class TestMaintenanceEventPoolHandler: + """Test the MaintenanceEventPoolHandler class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_pool = Mock() + self.config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, 
relax_timeout=20 + ) + self.handler = MaintenanceEventPoolHandler(self.mock_pool, self.config) + + def test_init(self): + """Test MaintenanceEventPoolHandler initialization.""" + assert self.handler.pool == self.mock_pool + assert self.handler.config == self.config + assert isinstance(self.handler._processed_events, set) + assert isinstance(self.handler._lock, type(threading.RLock())) + + def test_remove_expired_notifications(self): + """Test removal of expired notifications.""" + with patch("time.monotonic", return_value=1000): + event1 = NodeMovingEvent( + id=1, new_node_host="host1", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="host2", new_node_port=6380, ttl=5 + ) + self.handler._processed_events.add(event1) + self.handler._processed_events.add(event2) + + # Move time forward but not enough to expire event2 (expires at 1005) + with patch("time.monotonic", return_value=1003): + self.handler.remove_expired_notifications() + assert event1 in self.handler._processed_events + assert event2 in self.handler._processed_events # Not expired yet + + # Move time forward to expire event2 but not event1 + with patch("time.monotonic", return_value=1006): + self.handler.remove_expired_notifications() + assert event1 in self.handler._processed_events + assert event2 not in self.handler._processed_events # Now expired + + def test_handle_event_node_moving(self): + """Test handling of NodeMovingEvent.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch.object(self.handler, "handle_node_moving_event") as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_unknown_type(self): + """Test handling of unknown event type.""" + event = NodeMigratingEvent(id=1, ttl=5) # Not handled by pool handler + + result = self.handler.handle_event(event) + assert result is None + + def test_handle_node_moving_event_disabled_config(self): + 
"""Test node moving event handling when both features are disabled.""" + config = MaintenanceEventsConfig(proactive_reconnect=False, relax_timeout=-1) + handler = MaintenanceEventPoolHandler(self.mock_pool, config) + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + result = handler.handle_node_moving_event(event) + assert result is None + assert event not in handler._processed_events + + def test_handle_node_moving_event_already_processed(self): + """Test node moving event handling when event already processed.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + self.handler._processed_events.add(event) + + result = self.handler.handle_node_moving_event(event) + assert result is None + + def test_handle_node_moving_event_success(self): + """Test successful node moving event handling.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with ( + patch("threading.Timer") as mock_timer, + patch("time.monotonic", return_value=1000), + ): + self.handler.handle_node_moving_event(event) + + # Verify timer was started + mock_timer.assert_called_once_with( + event.ttl, self.handler.handle_node_moved_event + ) + mock_timer.return_value.start.assert_called_once() + + # Verify event was added to processed set + assert event in self.handler._processed_events + + # Verify pool methods were called + self.mock_pool.update_connection_kwargs_with_tmp_settings.assert_called_once() + + def test_handle_node_moved_event(self): + """Test handling of node moved event (cleanup).""" + self.handler.handle_node_moved_event() + + # Verify cleanup methods were called + self.mock_pool.update_connection_kwargs_with_tmp_settings.assert_called_once_with( + tmp_host_address=None, + tmp_relax_timeout=-1, + ) + + +class TestMaintenanceEventConnectionHandler: + """Test the MaintenanceEventConnectionHandler class.""" + + def setup_method(self): + """Set up test 
fixtures.""" + self.mock_connection = Mock() + self.config = MaintenanceEventsConfig(enabled=True, relax_timeout=20) + self.handler = MaintenanceEventConnectionHandler( + self.mock_connection, self.config + ) + + def test_init(self): + """Test MaintenanceEventConnectionHandler initialization.""" + assert self.handler.connection == self.mock_connection + assert self.handler.config == self.config + + def test_handle_event_migrating(self): + """Test handling of NodeMigratingEvent.""" + event = NodeMigratingEvent(id=1, ttl=5) + + with patch.object(self.handler, "handle_migrating_event") as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_migrated(self): + """Test handling of NodeMigratedEvent.""" + event = NodeMigratedEvent(id=1) + + with patch.object( + self.handler, "handle_migration_completed_event" + ) as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_unknown_type(self): + """Test handling of unknown event type.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + result = self.handler.handle_event(event) + assert result is None + + def test_handle_migrating_event_disabled(self): + """Test migrating event handling when relax timeouts are disabled.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + handler = MaintenanceEventConnectionHandler(self.mock_connection, config) + event = NodeMigratingEvent(id=1, ttl=5) + + result = handler.handle_migrating_event(event) + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_migrating_event_success(self): + """Test successful migrating event handling.""" + event = NodeMigratingEvent(id=1, ttl=5) + + self.handler.handle_migrating_event(event) + + self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) + 
self.mock_connection.update_tmp_settings.assert_called_once_with( + tmp_relax_timeout=20 + ) + + def test_handle_migration_completed_event_disabled(self): + """Test migration completed event handling when relax timeouts are disabled.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + handler = MaintenanceEventConnectionHandler(self.mock_connection, config) + event = NodeMigratedEvent(id=1) + + result = handler.handle_migration_completed_event(event) + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_migration_completed_event_success(self): + """Test successful migration completed event handling.""" + event = NodeMigratedEvent(id=1) + + self.handler.handle_migration_completed_event(event) + + self.mock_connection.update_current_socket_timeout.assert_called_once_with(-1) + self.mock_connection.update_tmp_settings.assert_called_once_with( + tmp_relax_timeout=-1 + ) From 63d0c45772a0ef28714f9e919b70ab9dfa8bedfe Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 11 Jul 2025 18:47:39 +0300 Subject: [PATCH 03/16] Adding integration-like tests for migrating/migrated events handling --- redis/connection.py | 14 + tests/test_maintenance_events_handling.py | 696 ++++++++++++++++++++++ 2 files changed, 710 insertions(+) create mode 100644 tests/test_maintenance_events_handling.py diff --git a/redis/connection.py b/redis/connection.py index f55e1b455c..7755472085 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -2209,3 +2209,17 @@ def _update_maintenance_events_config_for_connections( with self._lock: for conn in tuple(self._connections): conn.maintenance_events_config = maintenance_events_config + + def _update_maintenance_events_configs_for_connections( + self, maintenance_events_pool_handler + ): + """Override base class method to work with BlockingConnectionPool's structure.""" + with self._lock: + for conn in tuple(self._connections): + if conn: # conn can be None in 
BlockingConnectionPool + conn.set_maintenance_event_pool_handler( + maintenance_events_pool_handler + ) + conn.maintenance_events_config = ( + maintenance_events_pool_handler.config + ) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py new file mode 100644 index 0000000000..6413e24b6e --- /dev/null +++ b/tests/test_maintenance_events_handling.py @@ -0,0 +1,696 @@ +import socket +import threading +import select +from unittest.mock import Mock, patch +import pytest + +from redis import Redis +from redis.connection import ConnectionPool, BlockingConnectionPool +from redis.maintenance_events import ( + MaintenanceEventsConfig, + NodeMigratingEvent, + NodeMigratedEvent, + MaintenanceEventConnectionHandler, + MaintenanceEventPoolHandler, +) + + +class MockSocket: + """Mock socket that simulates Redis protocol responses.""" + + def __init__(self): + self.connected = False + self.sent_data = [] + self.response_queue = [] + self.closed = False + self.command_count = 0 + self.pending_responses = [] + self.current_response_index = 0 + # Track socket timeout changes for maintenance events validation + self.timeout = None + self.thread_timeouts = {} # Track last applied timeout per thread + + def connect(self, address): + """Simulate socket connection.""" + self.connected = True + + def send(self, data): + """Simulate sending data to Redis.""" + if self.closed: + raise ConnectionError("Socket is closed") + self.sent_data.append(data) + + # Analyze the command and prepare appropriate response + if b"HELLO" in data: + response = b"%7\r\n$6\r\nserver\r\n$5\r\nredis\r\n$7\r\nversion\r\n$5\r\n7.0.0\r\n$5\r\nproto\r\n:3\r\n$2\r\nid\r\n:1\r\n$4\r\nmode\r\n$10\r\nstandalone\r\n$4\r\nrole\r\n$6\r\nmaster\r\n$7\r\nmodules\r\n*0\r\n" + self.pending_responses.append(response) + elif b"SET" in data: + response = b"+OK\r\n" + + # Check if this is a key that should trigger a push message + if b"key_receive_migrating_" in data: + # MIGRATING 
push message before SET key_receive_migrating_X response + # Format: >2\r\n$9\r\nMIGRATING\r\n:10\r\n (2 elements: MIGRATING, ttl) + migrating_push = ">2\r\n$9\r\nMIGRATING\r\n:10\r\n" + response = migrating_push.encode() + response + elif b"key_receive_migrated_" in data: + # MIGRATED push message before SET key_receive_migrated_X response + # Format: >1\r\n$8\r\nMIGRATED\r\n (1 element: MIGRATED) + migrated_push = ">1\r\n$8\r\nMIGRATED\r\n" + response = migrated_push.encode() + response + + self.pending_responses.append(response) + elif b"GET" in data: + # Extract key and provide appropriate response + if b"hello" in data: + response = b"$5\r\nworld\r\n" + self.pending_responses.append(response) + # Handle thread-specific keys for integration test first (more specific) + elif b"key1_0" in data: + self.pending_responses.append(b"$8\r\nvalue1_0\r\n") + elif b"key_receive_migrating_0" in data: + self.pending_responses.append(b"$8\r\nvalue2_0\r\n") + elif b"key1_1" in data: + self.pending_responses.append(b"$8\r\nvalue1_1\r\n") + elif b"key_receive_migrating_1" in data: + self.pending_responses.append(b"$8\r\nvalue2_1\r\n") + elif b"key1_2" in data: + self.pending_responses.append(b"$8\r\nvalue1_2\r\n") + elif b"key_receive_migrating_2" in data: + self.pending_responses.append(b"$8\r\nvalue2_2\r\n") + # Generic keys (less specific, should come after thread-specific) + elif b"key0" in data: + self.pending_responses.append(b"$6\r\nvalue0\r\n") + elif b"key1" in data: + self.pending_responses.append(b"$6\r\nvalue1\r\n") + elif b"key2" in data: + self.pending_responses.append(b"$6\r\nvalue2\r\n") + else: + self.pending_responses.append(b"$-1\r\n") # NULL response + else: + self.pending_responses.append(b"+OK\r\n") # Default response + + self.command_count += 1 + return len(data) + + def sendall(self, data): + """Simulate sending all data to Redis.""" + return self.send(data) + + def recv(self, bufsize): + """Simulate receiving data from Redis.""" + if self.closed: + 
raise ConnectionError("Socket is closed") + if self.response_queue: + response = self.response_queue.pop(0) + return response[:bufsize] # Respect buffer size + + # Use pending responses that were prepared when commands were sent + if self.pending_responses: + response = self.pending_responses.pop(0) + return response[:bufsize] # Respect buffer size + else: + # No data available - this should block or raise an exception + # For can_read checks, we should indicate no data is available + import errno + + raise BlockingIOError(errno.EAGAIN, "Resource temporarily unavailable") + + def fileno(self): + """Return a fake file descriptor for select/poll operations.""" + return 1 # Fake file descriptor + + def close(self): + """Simulate closing the socket.""" + self.closed = True + self.connected = False + + def settimeout(self, timeout): + """Simulate setting socket timeout and track changes per thread.""" + self.timeout = timeout + + # Track last applied timeout per thread + thread_id = threading.current_thread().ident + self.thread_timeouts[thread_id] = timeout + + def setsockopt(self, level, optname, value): + """Simulate setting socket options.""" + pass + + def getpeername(self): + """Simulate getting peer name.""" + return ("127.0.0.1", 6379) + + def getsockname(self): + """Simulate getting socket name.""" + return ("127.0.0.1", 12345) + + def shutdown(self, how): + """Simulate socket shutdown.""" + pass + + +class TestMaintenanceEventsHandling: + """Integration tests for maintenance events handling with real connection pool.""" + + def setup_method(self): + """Set up test fixtures with mocked sockets.""" + self.mock_sockets = [] + self.original_socket = socket.socket + + # Mock socket creation to return our mock sockets + def mock_socket_factory(*args, **kwargs): + mock_sock = MockSocket() + self.mock_sockets.append(mock_sock) + return mock_sock + + self.socket_patcher = patch("socket.socket", side_effect=mock_socket_factory) + self.socket_patcher.start() + + # Mock 
select.select to simulate data availability for reading + def mock_select(rlist, wlist, xlist, timeout=0): + # Check if any of the sockets in rlist have data available + ready_sockets = [] + for sock in rlist: + if hasattr(sock, "connected") and sock.connected and not sock.closed: + # Only return socket as ready if it actually has data to read + if ( + hasattr(sock, "pending_responses") and sock.pending_responses + ) or (hasattr(sock, "response_queue") and sock.response_queue): + ready_sockets.append(sock) + # Don't return socket as ready just because it received commands + # Only when there are actual responses available + return (ready_sockets, [], []) + + self.select_patcher = patch("select.select", side_effect=mock_select) + self.select_patcher.start() + + # Create maintenance events config + self.config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, relax_timeout=30 + ) + + # Create connection pool with maintenance events (requires RESP3) + self.pool = ConnectionPool( + host="localhost", + port=6379, + max_connections=10, # Increased for multi-threaded tests + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + + # Create Redis client + self.redis_client = Redis(connection_pool=self.pool) + + def teardown_method(self): + """Clean up test fixtures.""" + self.socket_patcher.stop() + self.select_patcher.stop() + if hasattr(self.pool, "disconnect"): + self.pool.disconnect() + + def _validate_current_timeout_for_thread(self, thread_id, expected_timeout): + """Helper method to validate the current timeout for the calling thread.""" + current_thread_id = threading.current_thread().ident + actual_timeout = None + for sock in self.mock_sockets: + if current_thread_id in sock.thread_timeouts: + actual_timeout = sock.thread_timeouts[current_thread_id] + break + + assert actual_timeout == expected_timeout, ( + f"Thread {thread_id}: Expected timeout ({expected_timeout}), " + f"but found timeout: 
{actual_timeout} for thread {current_thread_id}. " + f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}" + ) + + def test_connection_pool_creation_with_maintenance_events(self): + """Test that connection pool is created with maintenance events configuration.""" + assert ( + self.pool.connection_kwargs.get("maintenance_events_config") == self.config + ) + # Pool should have maintenance events enabled + assert self.pool.maintenance_events_pool_handler_enabled() is True + + # Create and set a pool handler + pool_handler = MaintenanceEventPoolHandler(self.pool, self.config) + self.pool.set_maintenance_events_pool_handler(pool_handler) + + # Validate that the handler is properly set on the pool + assert ( + self.pool.connection_kwargs.get("maintenance_events_pool_handler") + == pool_handler + ) + assert ( + self.pool.connection_kwargs.get("maintenance_events_config") + == pool_handler.config + ) + + # Verify that the pool handler has the correct configuration + assert pool_handler.pool == self.pool + assert pool_handler.config == self.config + + def test_blocking_connection_pool_creation_with_maintenance_events(self): + """Test that BlockingConnectionPool is created with maintenance events configuration.""" + # Create blocking connection pool with maintenance events (requires RESP3) + blocking_pool = BlockingConnectionPool( + host="localhost", + port=6379, + max_connections=3, + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + + try: + assert ( + blocking_pool.connection_kwargs.get("maintenance_events_config") + == self.config + ) + # Pool should have maintenance events enabled + assert blocking_pool.maintenance_events_pool_handler_enabled() is True + + # Create and set a pool handler + pool_handler = MaintenanceEventPoolHandler(blocking_pool, self.config) + blocking_pool.set_maintenance_events_pool_handler(pool_handler) + + # Validate that the handler is properly set on the blocking pool + assert ( 
+ blocking_pool.connection_kwargs.get("maintenance_events_pool_handler") + == pool_handler + ) + assert ( + blocking_pool.connection_kwargs.get("maintenance_events_config") + == pool_handler.config + ) + + # Verify that the pool handler has the correct configuration + assert pool_handler.pool == blocking_pool + assert pool_handler.config == self.config + + finally: + if hasattr(blocking_pool, "disconnect"): + blocking_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_redis_operations_with_mock_sockets(self, pool_class): + """ + Test basic Redis operations work with mocked sockets and proper response parsing. + Basically with test - the mocked socket is validated. + """ + # Create a pool of the specified type with maintenance events + test_pool = pool_class( + host="localhost", + port=6379, + max_connections=5, + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + + try: + # Create Redis client with the test pool + test_redis_client = Redis(connection_pool=test_pool) + + # Perform Redis operations that should work with our improved mock responses + result_set = test_redis_client.set("hello", "world") + result_get = test_redis_client.get("hello") + + # Verify operations completed successfully + assert result_set is True + assert result_get == b"world" + + # Verify socket interactions + assert len(self.mock_sockets) >= 1 + assert self.mock_sockets[0].connected + assert len(self.mock_sockets[0].sent_data) >= 2 # HELLO, SET, GET commands + + # Verify that the connection has maintenance event handler + connection = test_pool.get_connection() + assert hasattr(connection, "_maintenance_event_connection_handler") + test_pool.release(connection) + + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_multiple_connections_in_pool(self, pool_class): + """Test 
that multiple connections can be created and used for Redis operations in multiple threads.""" + # Create a pool of the specified type with maintenance events + test_pool = pool_class( + host="localhost", + port=6379, + max_connections=5, + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + + try: + # Create Redis client with the test pool + test_redis_client = Redis(connection_pool=test_pool) + + # Results storage for thread operations + results = [] + errors = [] + + def redis_operation(key_suffix): + """Perform Redis operations in a thread.""" + try: + # SET operation + set_result = test_redis_client.set( + f"key{key_suffix}", f"value{key_suffix}" + ) + # GET operation + get_result = test_redis_client.get(f"key{key_suffix}") + results.append((set_result, get_result)) + except Exception as e: + errors.append(e) + + # Run operations in multiple threads to force multiple connections + threads = [] + for i in range(3): + thread = threading.Thread(target=redis_operation, args=(i,)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify no errors occurred + assert len(errors) == 0, f"Errors occurred: {errors}" + + # Verify all operations completed successfully + assert len(results) == 3 + for set_result, get_result in results: + assert set_result is True + assert get_result in [b"value0", b"value1", b"value2"] + + # Verify that multiple connections were created with mock sockets + # With threading, both pool types should create multiple sockets for concurrent access + assert len(self.mock_sockets) >= 2, ( + f"Expected multiple sockets due to threading, got {len(self.mock_sockets)}" + ) + + # Verify each connection has maintenance event handler + connection = test_pool.get_connection() + assert hasattr(connection, "_maintenance_event_connection_handler") + test_pool.release(connection) + + finally: + if hasattr(test_pool, "disconnect"): + 
test_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_migration_related_events_handling_integration(self, pool_class): + """ + Test full integration of migration-related events (MIGRATING/MIGRATED) handling with multiple threads and commands. + + This test validates the complete migration lifecycle: + 1. Creates 3 concurrent threads, each executing 5 Redis commands + 2. Injects MIGRATING push message before command 2 (SET key_receive_migrating_X) + 3. Validates socket timeout is updated to relaxed value (30s) after MIGRATING + 4. Executes commands 3-4 while timeout remains relaxed + 5. Injects MIGRATED push message before command 5 (SET key_receive_migrated_X) + 6. Validates socket timeout is restored after MIGRATED + 7. Tests both ConnectionPool and BlockingConnectionPool implementations + 8. Uses proper RESP3 push message format for realistic protocol simulation + """ + # Create a pool of the specified type with maintenance events + test_pool = pool_class( + host="localhost", + port=6379, + max_connections=10, # Increased for multi-threaded tests + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + + try: + # Create Redis client with the test pool + test_redis_client = Redis(connection_pool=test_pool) + + # Results storage for thread operations + results = [] + errors = [] + + def redis_operations_with_maintenance_events(thread_id): + """Perform Redis operations with maintenance events in a thread.""" + try: + # Command 1: Initial command + result1 = test_redis_client.set( + f"key1_{thread_id}", f"value1_{thread_id}" + ) + + # Validate Command 1 result + assert result1 is True, ( + f"Thread {thread_id}: Command 1 (SET key1) failed" + ) + + # Command 2: This SET command will receive MIGRATING push message before response + result2 = test_redis_client.set( + f"key_receive_migrating_{thread_id}", f"value2_{thread_id}" + ) + + # Validate Command 2 result + 
assert result2 is True, ( + f"Thread {thread_id}: Command 2 (SET key2) failed" + ) + + # Step 4: Validate timeout was updated to relaxed value after MIGRATING + self._validate_current_timeout_for_thread(thread_id, 30) + + # Command 3: Another command while timeout is still relaxed + result3 = test_redis_client.get(f"key1_{thread_id}") + + # Validate Command 3 result + expected_value3 = f"value1_{thread_id}".encode() + assert result3 == expected_value3, ( + f"Thread {thread_id}: Command 3 (GET key1) failed. " + f"Expected {expected_value3}, got {result3}" + ) + + # Command 4: Execute command (step 5) + result4 = test_redis_client.get( + f"key_receive_migrating_{thread_id}" + ) + + # Validate Command 4 result + expected_value4 = f"value2_{thread_id}".encode() + assert result4 == expected_value4, ( + f"Thread {thread_id}: Command 4 (GET key_receive_migrating) failed. " + f"Expected {expected_value4}, got {result4}" + ) + + # Step 6: Validate socket timeout is still relaxed during commands 3-4 + self._validate_current_timeout_for_thread(thread_id, 30) + + # Command 5: This SET command will receive + # MIGRATED push message before actual response + result5 = test_redis_client.set( + f"key_receive_migrated_{thread_id}", f"value3_{thread_id}" + ) + + # Validate Command 5 result + assert result5 is True, ( + f"Thread {thread_id}: Command 5 (SET key_receive_migrated) failed" + ) + + # Step 8: Validate socket timeout is reversed back to original after MIGRATED + self._validate_current_timeout_for_thread(thread_id, None) + + results.append( + { + "thread_id": thread_id, + "success": True, + } + ) + + except Exception as e: + errors.append(f"Thread {thread_id}: {e}") + + # Run operations in multiple threads (step 1) + threads = [] + for i in range(3): + thread = threading.Thread( + target=redis_operations_with_maintenance_events, + args=(i,), + name=str(i), + ) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + 
thread.join() + + # Verify all threads completed successfully + successful_threads = len(results) + assert successful_threads == 3, ( + f"Expected 3 successful threads, got {successful_threads}. " + f"Errors: {errors}" + ) + + # Verify maintenance events were processed correctly across all threads + # Note: Different pool types may create different numbers of sockets + # The key is that we have at least 1 socket and all threads succeeded + assert len(self.mock_sockets) >= 1, ( + f"Expected at least 1 socket for operations, got {len(self.mock_sockets)}" + ) + + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() + + def test_migrating_event_with_disabled_relax_timeout(self): + # TODO Not yet reviewed and validated - just vipecoded + """Test migrating event handling when relax timeout is disabled.""" + # Create config with disabled relax timeout + disabled_config = MaintenanceEventsConfig( + enabled=True, + relax_timeout=-1, # Disabled + ) + + # Create new pool with disabled config + disabled_pool = ConnectionPool( + host="localhost", + port=6379, + protocol=3, # Required for maintenance events + maintenance_events_config=disabled_config, + ) + + try: + # Get a connection + connection = disabled_pool.get_connection() + + # Mock the connection's timeout update methods + connection.update_current_socket_timeout = Mock() + connection.update_tmp_settings = Mock() + + # Create and handle migrating event + migrating_event = NodeMigratingEvent(id=1, ttl=10) + result = connection._maintenance_event_connection_handler.handle_event( + migrating_event + ) + + # Verify that no timeout updates were made (relax is disabled) + assert result is None + connection.update_current_socket_timeout.assert_not_called() + connection.update_tmp_settings.assert_not_called() + + finally: + if hasattr(disabled_pool, "disconnect"): + disabled_pool.disconnect() + + def test_pool_handler_with_migrating_event(self): + # TODO Not yet reviewed and validated - just vipecoded + 
"""Test that pool handler correctly handles migrating events.""" + # Create and set a pool handler + pool_handler = MaintenanceEventPoolHandler(self.pool, self.config) + + # Create a migrating event (not handled by pool handler) + migrating_event = NodeMigratingEvent(id=1, ttl=5) + + # Pool handler should return None for migrating events (not its responsibility) + result = pool_handler.handle_event(migrating_event) + assert result is None + + def test_connection_timeout_restoration_after_event(self): + # TODO Not yet reviewed and validated - just vipecoded + """Test that connection timeout is properly restored after maintenance event.""" + # Establish connection + self.redis_client.set("test", "value") + + connection = self.pool.get_connection() + + # Mock timeout methods + connection.update_current_socket_timeout = Mock() + connection.update_tmp_settings = Mock() + + # Simulate migrating event + migrating_event = NodeMigratingEvent(id=1, ttl=5) + connection._maintenance_event_connection_handler.handle_migrating_event( + migrating_event + ) + + # Verify relax timeout was applied + connection.update_current_socket_timeout.assert_called_with(30) + connection.update_tmp_settings.assert_called_with(tmp_relax_timeout=30) + + # Reset mocks + connection.update_current_socket_timeout.reset_mock() + connection.update_tmp_settings.reset_mock() + + # Simulate migration completed event + from redis.maintenance_events import NodeMigratedEvent + + migrated_event = NodeMigratedEvent(id=1) + connection._maintenance_event_connection_handler.handle_migration_completed_event( + migrated_event + ) + + # Verify timeout was restored + connection.update_current_socket_timeout.assert_called_with( + -1 + ) # Restore original + connection.update_tmp_settings.assert_called_with(tmp_relax_timeout=-1) + + self.pool.release(connection) + + def test_socket_error_handling_during_operations(self): + # TODO Not yet reviewed and validated - just vipecoded + """Test that socket errors are properly 
handled during Redis operations.""" + # Create a connection first to ensure we have a mock socket + connection = self.pool.get_connection() + + # Set up a socket that will fail + if self.mock_sockets: + self.mock_sockets[0].closed = True + + # Attempt Redis operation that should fail due to closed socket + with pytest.raises( + (ConnectionError, OSError, Exception) + ): # Should raise connection-related exception + # Try to use the connection with a closed socket + connection.send_command("PING") + + # Release the connection + self.pool.release(connection) + + def test_maintenance_events_with_concurrent_operations(self): + # TODO Not yet reviewed and validated - just vipecoded + """Test maintenance events handling with concurrent Redis operations.""" + + # Perform concurrent operations + def redis_operation(key_suffix): + try: + return self.redis_client.set( + f"concurrent_key_{key_suffix}", f"value_{key_suffix}" + ) + except Exception: + return False + + # Simulate concurrent operations + threads = [] + results = [] + + for i in range(3): + thread = threading.Thread( + target=lambda i=i: results.append(redis_operation(i)) + ) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # During concurrent operations, simulate a maintenance event + if self.pool.connection_kwargs.get("maintenance_events_config"): + migrating_event = NodeMigratingEvent(id=1, ttl=5) + # Create a pool handler to test event handling + pool_handler = MaintenanceEventPoolHandler(self.pool, self.config) + result = pool_handler.handle_event(migrating_event) + assert result is None # Pool handler doesn't handle migrating events + + # Verify that some operations completed successfully + # (Some might fail due to mock socket limitations, but that's expected) + assert len(results) == 3 From 5c7173373e68d29430ed92b208fa97ebe5d4fcd7 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 11 Jul 2025 18:49:22 +0300 Subject: [PATCH 
04/16] Removed unused imports --- tests/test_maintenance_events_handling.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 6413e24b6e..80590840a0 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -1,6 +1,5 @@ import socket import threading -import select from unittest.mock import Mock, patch import pytest @@ -9,8 +8,6 @@ from redis.maintenance_events import ( MaintenanceEventsConfig, NodeMigratingEvent, - NodeMigratedEvent, - MaintenanceEventConnectionHandler, MaintenanceEventPoolHandler, ) From 96c6e5d442b69eb6739c0501094e27efab68a179 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 11 Jul 2025 19:06:37 +0300 Subject: [PATCH 05/16] Revert changing of the default retry object initialization for connection pool - this should be a separate PR --- redis/connection.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 7755472085..9a434848ca 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -24,7 +24,7 @@ from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .auth.token import TokenInterface -from .backoff import ExponentialWithJitterBackoff +from .backoff import NoBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider from .event import AfterConnectionReleasedEvent, EventDispatcher from .exceptions import ( @@ -323,15 +323,16 @@ def __init__( # Add TimeoutError to the errors list to retry on retry_on_error.append(TimeoutError) self.retry_on_error = retry_on_error - if retry is None: - self.retry = Retry( - backoff=ExponentialWithJitterBackoff(base=1, cap=10), retries=3 - ) - else: - # deep-copy the Retry object as it is mutable - self.retry = copy.deepcopy(retry) - if retry_on_error: + if retry or retry_on_error: + if retry is None: + self.retry = 
Retry(NoBackoff(), 1) + else: + # deep-copy the Retry object as it is mutable + self.retry = copy.deepcopy(retry) + # Update the retry's supported errors with the specified errors self.retry.update_supported_errors(retry_on_error) + else: + self.retry = Retry(NoBackoff(), 0) self.health_check_interval = health_check_interval self.next_health_check = 0 self.redis_connect_func = redis_connect_func From 8691475d5f7b31e2c5836099f1d29013329d03da Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Mon, 14 Jul 2025 15:41:47 +0300 Subject: [PATCH 06/16] Complete migrating/migrated integration-like tests --- tests/test_maintenance_events_handling.py | 444 ++++++++++------------ 1 file changed, 195 insertions(+), 249 deletions(-) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 80590840a0..6a687da1b0 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -17,6 +17,7 @@ class MockSocket: def __init__(self): self.connected = False + self.address = None self.sent_data = [] self.response_queue = [] self.closed = False @@ -30,6 +31,7 @@ def __init__(self): def connect(self, address): """Simulate socket connection.""" self.connected = True + self.address = address def send(self, data): """Simulate sending data to Redis.""" @@ -187,24 +189,40 @@ def mock_select(rlist, wlist, xlist, timeout=0): enabled=True, proactive_reconnect=True, relax_timeout=30 ) - # Create connection pool with maintenance events (requires RESP3) - self.pool = ConnectionPool( - host="localhost", - port=6379, - max_connections=10, # Increased for multi-threaded tests - protocol=3, # Required for maintenance events - maintenance_events_config=self.config, - ) - - # Create Redis client - self.redis_client = Redis(connection_pool=self.pool) - def teardown_method(self): """Clean up test fixtures.""" self.socket_patcher.stop() self.select_patcher.stop() - if hasattr(self.pool, "disconnect"): - 
self.pool.disconnect() + + def _get_client( + self, pool_class, max_connections=10, maintenance_events_config=None + ): + """Helper method to create a pool and Redis client with maintenance events configuration. + + Args: + pool_class: The connection pool class (ConnectionPool or BlockingConnectionPool) + max_connections: Maximum number of connections in the pool (default: 10) + maintenance_events_config: Optional MaintenanceEventsConfig to use. If not provided, + uses self.config from setup_method (default: None) + + Returns: + tuple: (test_pool, test_redis_client) + """ + config = ( + maintenance_events_config + if maintenance_events_config is not None + else self.config + ) + + test_pool = pool_class( + host="localhost", + port=6379, + max_connections=max_connections, + protocol=3, # Required for maintenance events + maintenance_events_config=config, + ) + test_redis_client = Redis(connection_pool=test_pool) + return test_pool, test_redis_client def _validate_current_timeout_for_thread(self, thread_id, expected_timeout): """Helper method to validate the current timeout for the calling thread.""" @@ -221,72 +239,42 @@ def _validate_current_timeout_for_thread(self, thread_id, expected_timeout): f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}" ) - def test_connection_pool_creation_with_maintenance_events(self): - """Test that connection pool is created with maintenance events configuration.""" - assert ( - self.pool.connection_kwargs.get("maintenance_events_config") == self.config - ) - # Pool should have maintenance events enabled - assert self.pool.maintenance_events_pool_handler_enabled() is True - - # Create and set a pool handler - pool_handler = MaintenanceEventPoolHandler(self.pool, self.config) - self.pool.set_maintenance_events_pool_handler(pool_handler) - - # Validate that the handler is properly set on the pool - assert ( - self.pool.connection_kwargs.get("maintenance_events_pool_handler") - == pool_handler - ) - assert ( - 
self.pool.connection_kwargs.get("maintenance_events_config") - == pool_handler.config - ) - - # Verify that the pool handler has the correct configuration - assert pool_handler.pool == self.pool - assert pool_handler.config == self.config - - def test_blocking_connection_pool_creation_with_maintenance_events(self): - """Test that BlockingConnectionPool is created with maintenance events configuration.""" - # Create blocking connection pool with maintenance events (requires RESP3) - blocking_pool = BlockingConnectionPool( - host="localhost", - port=6379, - max_connections=3, - protocol=3, # Required for maintenance events - maintenance_events_config=self.config, - ) + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_connection_pool_creation_with_maintenance_events(self, pool_class): + """Test that connection pools are created with maintenance events configuration.""" + # Create a pool and Redis client with maintenance events + max_connections = 3 if pool_class == BlockingConnectionPool else 10 + test_pool, _ = self._get_client(pool_class, max_connections=max_connections) try: assert ( - blocking_pool.connection_kwargs.get("maintenance_events_config") + test_pool.connection_kwargs.get("maintenance_events_config") == self.config ) # Pool should have maintenance events enabled - assert blocking_pool.maintenance_events_pool_handler_enabled() is True + assert test_pool.maintenance_events_pool_handler_enabled() is True # Create and set a pool handler - pool_handler = MaintenanceEventPoolHandler(blocking_pool, self.config) - blocking_pool.set_maintenance_events_pool_handler(pool_handler) + pool_handler = MaintenanceEventPoolHandler(test_pool, self.config) + test_pool.set_maintenance_events_pool_handler(pool_handler) - # Validate that the handler is properly set on the blocking pool + # Validate that the handler is properly set on the pool assert ( - blocking_pool.connection_kwargs.get("maintenance_events_pool_handler") + 
test_pool.connection_kwargs.get("maintenance_events_pool_handler") == pool_handler ) assert ( - blocking_pool.connection_kwargs.get("maintenance_events_config") + test_pool.connection_kwargs.get("maintenance_events_config") == pool_handler.config ) # Verify that the pool handler has the correct configuration - assert pool_handler.pool == blocking_pool + assert pool_handler.pool == test_pool assert pool_handler.config == self.config finally: - if hasattr(blocking_pool, "disconnect"): - blocking_pool.disconnect() + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_redis_operations_with_mock_sockets(self, pool_class): @@ -294,19 +282,10 @@ def test_redis_operations_with_mock_sockets(self, pool_class): Test basic Redis operations work with mocked sockets and proper response parsing. Basically with test - the mocked socket is validated. """ - # Create a pool of the specified type with maintenance events - test_pool = pool_class( - host="localhost", - port=6379, - max_connections=5, - protocol=3, # Required for maintenance events - maintenance_events_config=self.config, - ) + # Create a pool and Redis client with maintenance events + test_pool, test_redis_client = self._get_client(pool_class, max_connections=5) try: - # Create Redis client with the test pool - test_redis_client = Redis(connection_pool=test_pool) - # Perform Redis operations that should work with our improved mock responses result_set = test_redis_client.set("hello", "world") result_get = test_redis_client.get("hello") @@ -332,19 +311,10 @@ def test_redis_operations_with_mock_sockets(self, pool_class): @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_multiple_connections_in_pool(self, pool_class): """Test that multiple connections can be created and used for Redis operations in multiple threads.""" - # Create a pool of the specified type with maintenance events - 
test_pool = pool_class( - host="localhost", - port=6379, - max_connections=5, - protocol=3, # Required for maintenance events - maintenance_events_config=self.config, - ) + # Create a pool and Redis client with maintenance events + test_pool, test_redis_client = self._get_client(pool_class, max_connections=5) try: - # Create Redis client with the test pool - test_redis_client = Redis(connection_pool=test_pool) - # Results storage for thread operations results = [] errors = [] @@ -397,6 +367,44 @@ def redis_operation(key_suffix): if hasattr(test_pool, "disconnect"): test_pool.disconnect() + def test_pool_handler_with_migrating_event(self): + """Test that pool handler correctly handles migrating events.""" + # Create a pool and Redis client with maintenance events + test_pool, _ = self._get_client(ConnectionPool) + + try: + # Create and set a pool handler + pool_handler = MaintenanceEventPoolHandler(test_pool, self.config) + + # Create a migrating event (not handled by pool handler) + migrating_event = NodeMigratingEvent(id=1, ttl=5) + + # Mock the required functions + with ( + patch.object( + pool_handler, "remove_expired_notifications" + ) as mock_remove_expired, + patch.object( + pool_handler, "handle_node_moving_event" + ) as mock_handle_moving, + patch("redis.maintenance_events.logging.error") as mock_logging_error, + ): + # Pool handler should return None for migrating events (not its responsibility) + pool_handler.handle_event(migrating_event) + + # Validate that remove_expired_notifications has been called once + mock_remove_expired.assert_called_once() + + # Validate that handle_node_moving_event hasn't been called + mock_handle_moving.assert_not_called() + + # Validate that logging.error has been called once + mock_logging_error.assert_called_once() + + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def 
test_migration_related_events_handling_integration(self, pool_class): """ @@ -412,19 +420,10 @@ def test_migration_related_events_handling_integration(self, pool_class): 7. Tests both ConnectionPool and BlockingConnectionPool implementations 8. Uses proper RESP3 push message format for realistic protocol simulation """ - # Create a pool of the specified type with maintenance events - test_pool = pool_class( - host="localhost", - port=6379, - max_connections=10, # Increased for multi-threaded tests - protocol=3, # Required for maintenance events - maintenance_events_config=self.config, - ) + # Create a pool and Redis client with maintenance events + test_pool, test_redis_client = self._get_client(pool_class, max_connections=10) try: - # Create Redis client with the test pool - test_redis_client = Redis(connection_pool=test_pool) - # Results storage for thread operations results = [] errors = [] @@ -433,63 +432,60 @@ def redis_operations_with_maintenance_events(thread_id): """Perform Redis operations with maintenance events in a thread.""" try: # Command 1: Initial command - result1 = test_redis_client.set( - f"key1_{thread_id}", f"value1_{thread_id}" - ) + key1 = f"key1_{thread_id}" + value1 = f"value1_{thread_id}" + result1 = test_redis_client.set(key1, value1) # Validate Command 1 result - assert result1 is True, ( - f"Thread {thread_id}: Command 1 (SET key1) failed" - ) + erros_msg = f"Thread {thread_id}: Command 1 (SET key1) failed" + assert result1 is True, erros_msg # Command 2: This SET command will receive MIGRATING push message before response - result2 = test_redis_client.set( - f"key_receive_migrating_{thread_id}", f"value2_{thread_id}" - ) + key_migrating = f"key_receive_migrating_{thread_id}" + value_migrating = f"value2_{thread_id}" + result2 = test_redis_client.set(key_migrating, value_migrating) # Validate Command 2 result - assert result2 is True, ( - f"Thread {thread_id}: Command 2 (SET key2) failed" - ) + erros_msg = f"Thread {thread_id}: Command 
2 (SET key_receive_migrating) failed" + assert result2 is True, erros_msg # Step 4: Validate timeout was updated to relaxed value after MIGRATING self._validate_current_timeout_for_thread(thread_id, 30) # Command 3: Another command while timeout is still relaxed - result3 = test_redis_client.get(f"key1_{thread_id}") + result3 = test_redis_client.get(key1) # Validate Command 3 result - expected_value3 = f"value1_{thread_id}".encode() - assert result3 == expected_value3, ( + expected_value3 = value1.encode() + errors_msg = ( f"Thread {thread_id}: Command 3 (GET key1) failed. " f"Expected {expected_value3}, got {result3}" ) + assert result3 == expected_value3, errors_msg # Command 4: Execute command (step 5) - result4 = test_redis_client.get( - f"key_receive_migrating_{thread_id}" - ) + result4 = test_redis_client.get(key_migrating) # Validate Command 4 result - expected_value4 = f"value2_{thread_id}".encode() - assert result4 == expected_value4, ( + expected_value4 = value_migrating.encode() + errors_msg = ( f"Thread {thread_id}: Command 4 (GET key_receive_migrating) failed. 
" f"Expected {expected_value4}, got {result4}" ) + assert result4 == expected_value4, errors_msg # Step 6: Validate socket timeout is still relaxed during commands 3-4 self._validate_current_timeout_for_thread(thread_id, 30) # Command 5: This SET command will receive # MIGRATED push message before actual response - result5 = test_redis_client.set( - f"key_receive_migrated_{thread_id}", f"value3_{thread_id}" - ) + key_migrated = f"key_receive_migrated_{thread_id}" + value_migrated = f"value3_{thread_id}" + result5 = test_redis_client.set(key_migrated, value_migrated) # Validate Command 5 result - assert result5 is True, ( - f"Thread {thread_id}: Command 5 (SET key_receive_migrated) failed" - ) + errors_msg = f"Thread {thread_id}: Command 5 (SET key_receive_migrated) failed" + assert result5 is True, errors_msg # Step 8: Validate socket timeout is reversed back to original after MIGRATED self._validate_current_timeout_for_thread(thread_id, None) @@ -537,157 +533,107 @@ def redis_operations_with_maintenance_events(thread_id): if hasattr(test_pool, "disconnect"): test_pool.disconnect() - def test_migrating_event_with_disabled_relax_timeout(self): - # TODO Not yet reviewed and validated - just vipecoded - """Test migrating event handling when relax timeout is disabled.""" + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_migrating_event_with_disabled_relax_timeout(self, pool_class): + """ + Test migrating event handling when relax timeout is disabled. + + This test validates that when relax_timeout is disabled (-1): + 1. MIGRATING events are received and processed + 2. No timeout updates are applied to connections + 3. Socket timeouts remain unchanged during migration events + 4. 
Tests both ConnectionPool and BlockingConnectionPool implementations + """ # Create config with disabled relax timeout disabled_config = MaintenanceEventsConfig( enabled=True, - relax_timeout=-1, # Disabled + relax_timeout=-1, # This means the relax timeout is Disabled ) - # Create new pool with disabled config - disabled_pool = ConnectionPool( - host="localhost", - port=6379, - protocol=3, # Required for maintenance events - maintenance_events_config=disabled_config, + # Create a pool and Redis client with disabled relax timeout config + test_pool, test_redis_client = self._get_client( + pool_class, max_connections=5, maintenance_events_config=disabled_config ) try: - # Get a connection - connection = disabled_pool.get_connection() - - # Mock the connection's timeout update methods - connection.update_current_socket_timeout = Mock() - connection.update_tmp_settings = Mock() - - # Create and handle migrating event - migrating_event = NodeMigratingEvent(id=1, ttl=10) - result = connection._maintenance_event_connection_handler.handle_event( - migrating_event - ) - - # Verify that no timeout updates were made (relax is disabled) - assert result is None - connection.update_current_socket_timeout.assert_not_called() - connection.update_tmp_settings.assert_not_called() - - finally: - if hasattr(disabled_pool, "disconnect"): - disabled_pool.disconnect() - - def test_pool_handler_with_migrating_event(self): - # TODO Not yet reviewed and validated - just vipecoded - """Test that pool handler correctly handles migrating events.""" - # Create and set a pool handler - pool_handler = MaintenanceEventPoolHandler(self.pool, self.config) - - # Create a migrating event (not handled by pool handler) - migrating_event = NodeMigratingEvent(id=1, ttl=5) + # Results storage for thread operations + results = [] + errors = [] - # Pool handler should return None for migrating events (not its responsibility) - result = pool_handler.handle_event(migrating_event) - assert result is None + def 
redis_operations_with_disabled_relax(thread_id): + """Perform Redis operations with disabled relax timeout in a thread.""" + try: + # Command 1: Initial command + key1 = f"key1_{thread_id}" + value1 = f"value1_{thread_id}" + result1 = test_redis_client.set(key1, value1) - def test_connection_timeout_restoration_after_event(self): - # TODO Not yet reviewed and validated - just vipecoded - """Test that connection timeout is properly restored after maintenance event.""" - # Establish connection - self.redis_client.set("test", "value") + # Validate Command 1 result + errors_msg = f"Thread {thread_id}: Command 1 (SET key1) failed" + assert result1 is True, errors_msg - connection = self.pool.get_connection() + # Command 2: This SET command will receive MIGRATING push message before response + key_migrating = f"key_receive_migrating_{thread_id}" + value_migrating = f"value2_{thread_id}" + result2 = test_redis_client.set(key_migrating, value_migrating) - # Mock timeout methods - connection.update_current_socket_timeout = Mock() - connection.update_tmp_settings = Mock() + # Validate Command 2 result + errors_msg = f"Thread {thread_id}: Command 2 (SET key_receive_migrating) failed" + assert result2 is True, errors_msg - # Simulate migrating event - migrating_event = NodeMigratingEvent(id=1, ttl=5) - connection._maintenance_event_connection_handler.handle_migrating_event( - migrating_event - ) + # Validate timeout was NOT updated (relax is disabled) + # Should remain at default timeout (None), not relaxed to 30s + self._validate_current_timeout_for_thread(thread_id, None) - # Verify relax timeout was applied - connection.update_current_socket_timeout.assert_called_with(30) - connection.update_tmp_settings.assert_called_with(tmp_relax_timeout=30) + # Command 3: Another command to verify timeout remains unchanged + result3 = test_redis_client.get(key1) - # Reset mocks - connection.update_current_socket_timeout.reset_mock() - connection.update_tmp_settings.reset_mock() + # 
Validate Command 3 result + expected_value3 = value1.encode() + errors_msg = ( + f"Thread {thread_id}: Command 3 (GET key1) failed. " + f"Expected: {expected_value3}, Got: {result3}" + ) + assert result3 == expected_value3, errors_msg - # Simulate migration completed event - from redis.maintenance_events import NodeMigratedEvent + results.append( + { + "thread_id": thread_id, + "success": True, + } + ) - migrated_event = NodeMigratedEvent(id=1) - connection._maintenance_event_connection_handler.handle_migration_completed_event( - migrated_event - ) + except Exception as e: + errors.append(f"Thread {thread_id}: {str(e)}") - # Verify timeout was restored - connection.update_current_socket_timeout.assert_called_with( - -1 - ) # Restore original - connection.update_tmp_settings.assert_called_with(tmp_relax_timeout=-1) - - self.pool.release(connection) - - def test_socket_error_handling_during_operations(self): - # TODO Not yet reviewed and validated - just vipecoded - """Test that socket errors are properly handled during Redis operations.""" - # Create a connection first to ensure we have a mock socket - connection = self.pool.get_connection() - - # Set up a socket that will fail - if self.mock_sockets: - self.mock_sockets[0].closed = True - - # Attempt Redis operation that should fail due to closed socket - with pytest.raises( - (ConnectionError, OSError, Exception) - ): # Should raise connection-related exception - # Try to use the connection with a closed socket - connection.send_command("PING") - - # Release the connection - self.pool.release(connection) - - def test_maintenance_events_with_concurrent_operations(self): - # TODO Not yet reviewed and validated - just vipecoded - """Test maintenance events handling with concurrent Redis operations.""" - - # Perform concurrent operations - def redis_operation(key_suffix): - try: - return self.redis_client.set( - f"concurrent_key_{key_suffix}", f"value_{key_suffix}" + # Run operations in multiple threads to test 
concurrent behavior + threads = [] + for i in range(3): + thread = threading.Thread( + target=redis_operations_with_disabled_relax, args=(i,) ) - except Exception: - return False + threads.append(thread) + thread.start() - # Simulate concurrent operations - threads = [] - results = [] + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify no errors occurred + assert len(errors) == 0, f"Errors occurred: {errors}" - for i in range(3): - thread = threading.Thread( - target=lambda i=i: results.append(redis_operation(i)) + # Verify all operations completed successfully + assert len(results) == 3, ( + f"Expected 3 successful threads, got {len(results)}" ) - threads.append(thread) - thread.start() - # Wait for all threads to complete - for thread in threads: - thread.join() + # Verify maintenance events were processed correctly across all threads + # Note: Different pool types may create different numbers of sockets + # The key is that we have at least 1 socket and all threads succeeded + assert len(self.mock_sockets) >= 1, ( + f"Expected at least 1 socket for operations, got {len(self.mock_sockets)}" + ) - # During concurrent operations, simulate a maintenance event - if self.pool.connection_kwargs.get("maintenance_events_config"): - migrating_event = NodeMigratingEvent(id=1, ttl=5) - # Create a pool handler to test event handling - pool_handler = MaintenanceEventPoolHandler(self.pool, self.config) - result = pool_handler.handle_event(migrating_event) - assert result is None # Pool handler doesn't handle migrating events - - # Verify that some operations completed successfully - # (Some might fail due to mock socket limitations, but that's expected) - assert len(results) == 3 + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() From 7b57a2211242867cd5a9161bda36e34b4be342aa Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Tue, 15 Jul 2025 17:15:02 +0300 Subject: [PATCH 07/16] Adding moving integration-like tests 
--- redis/_parsers/base.py | 9 +- redis/connection.py | 26 +- redis/maintenance_events.py | 39 +- tests/test_maintenance_events_handling.py | 604 ++++++++++++++++++---- 4 files changed, 557 insertions(+), 121 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index aa5a6b0f12..f2670e43b0 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -129,9 +129,10 @@ def __del__(self): def on_connect(self, connection): "Called when the socket connects" self._sock = connection._sock - self._buffer = SocketBuffer( - self._sock, self.socket_read_size, connection.socket_timeout - ) + timeout = connection.socket_timeout + if connection.tmp_relax_timeout != -1: + timeout = connection.tmp_relax_timeout + self._buffer = SocketBuffer(self._sock, self.socket_read_size, timeout) self.encoder = connection.encoder def on_disconnect(self): @@ -203,7 +204,7 @@ def handle_push_response(self, response, **kwargs): return self.invalidation_push_handler_func(response) if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: if msg_type in _MOVING_MESSAGE: - host, port = response[2].split(":") + host, port = response[2].decode().split(":") ttl = response[1] id = 1 # Hardcoded value for sync parser notification = NodeMovingEvent(id, host, port, ttl) diff --git a/redis/connection.py b/redis/connection.py index 9a434848ca..81a80d0903 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -807,6 +807,10 @@ def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): f" to timeout {timeout}; relax_timeout: {relax_timeout}" ) self._sock.settimeout(timeout) + self.update_parser_buffer_timeout(timeout) + + def update_parser_buffer_timeout(self, timeout: Optional[float] = None): + if self._parser and self._parser._buffer: self._parser._buffer.socket_timeout = timeout def update_tmp_settings( @@ -1901,7 +1905,7 @@ def disconnect_and_reconfigure_free_connections( def update_connections_current_timeout( self, relax_timeout: 
Optional[float], - include_available_connections: bool = False, + include_free_connections: bool = False, ): """ Update the timeout either for all connections in the pool or just for the ones in use. @@ -1919,7 +1923,7 @@ def update_connections_current_timeout( for conn in self._in_use_connections: self._update_connection_timeout(conn, relax_timeout) - if include_available_connections: + if include_free_connections: for conn in self._available_connections: self._update_connection_timeout(conn, relax_timeout) @@ -2164,8 +2168,6 @@ def update_active_connections_for_reconnect( connections_in_queue = {conn for conn in self.pool.queue if conn} for conn in self._connections: if conn not in connections_in_queue: - if tmp_relax_timeout != -1: - conn.update_socket_timeout(tmp_relax_timeout) self._update_connection_for_reconnect( conn, tmp_host_address, tmp_relax_timeout ) @@ -2184,14 +2186,24 @@ def disconnect_and_reconfigure_free_connections( conn, tmp_host_address, tmp_relax_timeout ) - def update_connections_current_timeout(self, relax_timeout: Optional[float] = None): + def update_connections_current_timeout( + self, + relax_timeout: Optional[float] = None, + include_free_connections: bool = False, + ): logging.debug( f"***** Blocking Pool --> Updating timeouts. 
relax_timeout: {relax_timeout}" ) with self._lock: - for conn in tuple(self._connections): - self._update_connection_timeout(conn, relax_timeout) + if include_free_connections: + for conn in tuple(self._connections): + self._update_connection_timeout(conn, relax_timeout) + else: + connections_in_queue = {conn for conn in self.pool.queue if conn} + for conn in self._connections: + if conn not in connections_in_queue: + self._update_connection_timeout(conn, relax_timeout) def update_connections_tmp_settings( self, diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index d818a846b8..5b2091472d 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -323,25 +323,6 @@ def handle_event(self, notification: MaintenanceEvent): else: logging.error(f"Unhandled notification type: {notification}") - def handle_node_moved_event(self): - with self._lock: - self.pool.update_connection_kwargs_with_tmp_settings( - tmp_host_address=None, - tmp_relax_timeout=-1, - ) - with self.pool._lock: - if self.config.is_relax_timeouts_enabled(): - # reset the timeout for existing connections - self.pool.update_connections_current_timeout( - relax_timeout=-1, include_available_connections=True - ) - logging.debug("***** MOVING END--> TIMEOUTS RESET") - - self.pool.update_connections_tmp_settings( - tmp_host_address=None, tmp_relax_timeout=-1 - ) - logging.debug("***** MOVING END--> TMP SETTINGS ADDRESS RESET") - def handle_node_moving_event(self, event: NodeMovingEvent): if ( not self.config.proactive_reconnect @@ -403,6 +384,26 @@ def handle_node_moving_event(self, event: NodeMovingEvent): f"###### MOVING total execution time: {execution_time_us:.0f} microseconds" ) + def handle_node_moved_event(self): + logging.debug("***** MOVING END--> Starting to revert the changes.") + with self._lock: + self.pool.update_connection_kwargs_with_tmp_settings( + tmp_host_address=None, + tmp_relax_timeout=-1, + ) + with self.pool._lock: + if 
self.config.is_relax_timeouts_enabled(): + # reset the timeout for existing connections + self.pool.update_connections_current_timeout( + relax_timeout=-1, include_free_connections=True + ) + logging.debug("***** MOVING END--> TIMEOUTS RESET") + + self.pool.update_connections_tmp_settings( + tmp_host_address=None, tmp_relax_timeout=-1 + ) + logging.debug("***** MOVING END--> TMP SETTINGS ADDRESS RESET") + class MaintenanceEventConnectionHandler: def __init__( diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 6a687da1b0..c04c6c066e 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -1,10 +1,12 @@ import socket import threading -from unittest.mock import Mock, patch +from typing import List +from unittest.mock import patch import pytest +from time import sleep from redis import Redis -from redis.connection import ConnectionPool, BlockingConnectionPool +from redis.connection import AbstractConnection, ConnectionPool, BlockingConnectionPool from redis.maintenance_events import ( MaintenanceEventsConfig, NodeMigratingEvent, @@ -15,18 +17,21 @@ class MockSocket: """Mock socket that simulates Redis protocol responses.""" + AFTER_MOVING_ADDRESS = "1.2.3.4:6379" + DEFAULT_ADDRESS = "12.45.34.56:6379" + MOVING_TIMEOUT = 1 + def __init__(self): self.connected = False self.address = None self.sent_data = [] - self.response_queue = [] self.closed = False self.command_count = 0 self.pending_responses = [] - self.current_response_index = 0 # Track socket timeout changes for maintenance events validation self.timeout = None self.thread_timeouts = {} # Track last applied timeout per thread + self.moving_sent = False def connect(self, address): """Simulate socket connection.""" @@ -57,6 +62,12 @@ def send(self, data): # Format: >1\r\n$8\r\nMIGRATED\r\n (1 element: MIGRATED) migrated_push = ">1\r\n$8\r\nMIGRATED\r\n" response = migrated_push.encode() + response + elif 
b"key_receive_moving_" in data: + # MOVING push message before SET key_receive_moving_X response + # Format: >3\r\n$6\r\nMOVING\r\n:15\r\n+localhost:6379\r\n (3 elements: MOVING, ttl, host:port) + # Note: Using + instead of $ to send as simple string instead of bulk string + moving_push = f">3\r\n$6\r\nMOVING\r\n:{MockSocket.MOVING_TIMEOUT}\r\n+{MockSocket.AFTER_MOVING_ADDRESS}\r\n" + response = moving_push.encode() + response self.pending_responses.append(response) elif b"GET" in data: @@ -69,14 +80,20 @@ def send(self, data): self.pending_responses.append(b"$8\r\nvalue1_0\r\n") elif b"key_receive_migrating_0" in data: self.pending_responses.append(b"$8\r\nvalue2_0\r\n") + elif b"key_receive_moving_0" in data: + self.pending_responses.append(b"$8\r\nvalue3_0\r\n") elif b"key1_1" in data: self.pending_responses.append(b"$8\r\nvalue1_1\r\n") elif b"key_receive_migrating_1" in data: self.pending_responses.append(b"$8\r\nvalue2_1\r\n") + elif b"key_receive_moving_1" in data: + self.pending_responses.append(b"$8\r\nvalue3_1\r\n") elif b"key1_2" in data: self.pending_responses.append(b"$8\r\nvalue1_2\r\n") elif b"key_receive_migrating_2" in data: self.pending_responses.append(b"$8\r\nvalue2_2\r\n") + elif b"key_receive_moving_2" in data: + self.pending_responses.append(b"$8\r\nvalue3_2\r\n") # Generic keys (less specific, should come after thread-specific) elif b"key0" in data: self.pending_responses.append(b"$6\r\nvalue0\r\n") @@ -100,13 +117,12 @@ def recv(self, bufsize): """Simulate receiving data from Redis.""" if self.closed: raise ConnectionError("Socket is closed") - if self.response_queue: - response = self.response_queue.pop(0) - return response[:bufsize] # Respect buffer size # Use pending responses that were prepared when commands were sent if self.pending_responses: response = self.pending_responses.pop(0) + if b"MOVING" in response: + self.moving_sent = True return response[:bufsize] # Respect buffer size else: # No data available - this should block or 
raise an exception @@ -123,26 +139,33 @@ def close(self): """Simulate closing the socket.""" self.closed = True self.connected = False + self.address = None + self.timeout = None + self.thread_timeouts = {} def settimeout(self, timeout): """Simulate setting socket timeout and track changes per thread.""" self.timeout = timeout - # Track last applied timeout per thread + # Track last applied timeout with thread_id information added thread_id = threading.current_thread().ident self.thread_timeouts[thread_id] = timeout + def gettimeout(self): + """Simulate getting socket timeout.""" + return self.timeout + def setsockopt(self, level, optname, value): """Simulate setting socket options.""" pass def getpeername(self): """Simulate getting peer name.""" - return ("127.0.0.1", 6379) + return self.address def getsockname(self): """Simulate getting socket name.""" - return ("127.0.0.1", 12345) + return (self.address.split(":")[0], 12345) def shutdown(self, how): """Simulate socket shutdown.""" @@ -173,9 +196,7 @@ def mock_select(rlist, wlist, xlist, timeout=0): for sock in rlist: if hasattr(sock, "connected") and sock.connected and not sock.closed: # Only return socket as ready if it actually has data to read - if ( - hasattr(sock, "pending_responses") and sock.pending_responses - ) or (hasattr(sock, "response_queue") and sock.response_queue): + if hasattr(sock, "pending_responses") and sock.pending_responses: ready_sockets.append(sock) # Don't return socket as ready just because it received commands # Only when there are actual responses available @@ -195,7 +216,11 @@ def teardown_method(self): self.select_patcher.stop() def _get_client( - self, pool_class, max_connections=10, maintenance_events_config=None + self, + pool_class, + max_connections=10, + maintenance_events_config=None, + setup_pool_handler=False, ): """Helper method to create a pool and Redis client with maintenance events configuration. 
@@ -204,6 +229,7 @@ def _get_client( max_connections: Maximum number of connections in the pool (default: 10) maintenance_events_config: Optional MaintenanceEventsConfig to use. If not provided, uses self.config from setup_method (default: None) + setup_pool_handler: Whether to set up pool handler for moving events (default: False) Returns: tuple: (test_pool, test_redis_client) @@ -215,19 +241,30 @@ def _get_client( ) test_pool = pool_class( - host="localhost", - port=6379, + host=MockSocket.DEFAULT_ADDRESS.split(":")[0], + port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), max_connections=max_connections, protocol=3, # Required for maintenance events maintenance_events_config=config, ) test_redis_client = Redis(connection_pool=test_pool) - return test_pool, test_redis_client + + # Set up pool handler for moving events if requested + if setup_pool_handler: + pool_handler = MaintenanceEventPoolHandler( + test_redis_client.connection_pool, config + ) + test_redis_client.connection_pool.set_maintenance_events_pool_handler( + pool_handler + ) + + return test_redis_client def _validate_current_timeout_for_thread(self, thread_id, expected_timeout): """Helper method to validate the current timeout for the calling thread.""" - current_thread_id = threading.current_thread().ident actual_timeout = None + # Get the actual thread ID from the current thread + current_thread_id = threading.current_thread().ident for sock in self.mock_sockets: if current_thread_id in sock.thread_timeouts: actual_timeout = sock.thread_timeouts[current_thread_id] @@ -235,16 +272,121 @@ def _validate_current_timeout_for_thread(self, thread_id, expected_timeout): assert actual_timeout == expected_timeout, ( f"Thread {thread_id}: Expected timeout ({expected_timeout}), " - f"but found timeout: {actual_timeout} for thread {current_thread_id}. " + f"but found timeout: {actual_timeout} for thread {thread_id}. 
" f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}" ) + def _validate_disconnected(self, expected_count): + """Helper method to validate all socket timeouts""" + disconnected_sockets_count = 0 + for sock in self.mock_sockets: + if sock.closed: + disconnected_sockets_count += 1 + assert disconnected_sockets_count == expected_count + + def _validate_connected(self, expected_count): + """Helper method to validate all socket timeouts""" + connected_sockets_count = 0 + for sock in self.mock_sockets: + if sock.connected: + connected_sockets_count += 1 + assert connected_sockets_count == expected_count + + def _validate_in_use_connections_state( + self, in_use_connections: List[AbstractConnection] + ): + """Helper method to validate state of in-use connections.""" + # validate in use connections are still working with set flag for reconnect + # and timeout is updated + for connection in in_use_connections: + assert connection._should_reconnect is True + assert ( + connection.tmp_host_address + == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + ) + assert connection.tmp_relax_timeout == self.config.relax_timeout + assert connection._sock.gettimeout() == self.config.relax_timeout + assert connection._sock.connected is True + assert ( + connection._sock.getpeername()[0] + == MockSocket.DEFAULT_ADDRESS.split(":")[0] + ) + + def _validate_free_connections_state( + self, + pool, + tmp_host_address, + relax_timeout, + should_be_connected_count, + connected_to_tmp_addres=False, + ): + """Helper method to validate state of free/available connections.""" + if isinstance(pool, BlockingConnectionPool): + # BlockingConnectionPool uses _connections list where created connections are stored + # but we need to get the ones in the queue - these are the free ones + # the uninitialized connections are filtered out + free_connections = [conn for conn in pool.pool.queue if conn is not None] + elif isinstance(pool, ConnectionPool): + # Regular ConnectionPool uses 
_available_connections for free connections + free_connections = pool._available_connections + else: + raise ValueError(f"Unsupported pool type: {type(pool)}") + + connected_count = 0 + # Validate fields that are validated in the validation of the active connections + for connection in free_connections: + # Validate the same fields as in _validate_in_use_connections_state + assert connection._should_reconnect is False + assert connection.tmp_host_address == tmp_host_address + assert connection.tmp_relax_timeout == relax_timeout + if connection._sock is not None: + connected_count += 1 + + if connected_to_tmp_addres: + assert ( + connection._sock.getpeername()[0] + == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + ) + else: + assert ( + connection._sock.getpeername()[0] + == MockSocket.DEFAULT_ADDRESS.split(":")[0] + ) + assert connected_count == should_be_connected_count + + def _validate_all_timeouts(self, expected_timeout): + """Helper method to validate state of in-use connections.""" + # validate in use connections are still working with set flag for reconnect + # and timeout is updated + for mock_socket in self.mock_sockets: + if expected_timeout is None: + assert mock_socket.gettimeout() is None + else: + assert mock_socket.gettimeout() == expected_timeout + + def _validate_conn_kwargs( + self, + pool, + expected_host_address, + expected_port, + expected_tmp_host_address, + expected_tmp_relax_timeout, + ): + """Helper method to validate connection kwargs.""" + assert pool.connection_kwargs["host"] == expected_host_address + assert pool.connection_kwargs["port"] == expected_port + assert pool.connection_kwargs["tmp_host_address"] == expected_tmp_host_address + assert pool.connection_kwargs["tmp_relax_timeout"] == expected_tmp_relax_timeout + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_connection_pool_creation_with_maintenance_events(self, pool_class): """Test that connection pools are created with maintenance 
events configuration.""" # Create a pool and Redis client with maintenance events max_connections = 3 if pool_class == BlockingConnectionPool else 10 - test_pool, _ = self._get_client(pool_class, max_connections=max_connections) + test_redis_client = self._get_client( + pool_class, max_connections=max_connections + ) + test_pool = test_redis_client.connection_pool try: assert ( @@ -283,7 +425,7 @@ def test_redis_operations_with_mock_sockets(self, pool_class): Basically with test - the mocked socket is validated. """ # Create a pool and Redis client with maintenance events - test_pool, test_redis_client = self._get_client(pool_class, max_connections=5) + test_redis_client = self._get_client(pool_class, max_connections=5) try: # Perform Redis operations that should work with our improved mock responses @@ -300,77 +442,19 @@ def test_redis_operations_with_mock_sockets(self, pool_class): assert len(self.mock_sockets[0].sent_data) >= 2 # HELLO, SET, GET commands # Verify that the connection has maintenance event handler - connection = test_pool.get_connection() - assert hasattr(connection, "_maintenance_event_connection_handler") - test_pool.release(connection) - - finally: - if hasattr(test_pool, "disconnect"): - test_pool.disconnect() - - @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) - def test_multiple_connections_in_pool(self, pool_class): - """Test that multiple connections can be created and used for Redis operations in multiple threads.""" - # Create a pool and Redis client with maintenance events - test_pool, test_redis_client = self._get_client(pool_class, max_connections=5) - - try: - # Results storage for thread operations - results = [] - errors = [] - - def redis_operation(key_suffix): - """Perform Redis operations in a thread.""" - try: - # SET operation - set_result = test_redis_client.set( - f"key{key_suffix}", f"value{key_suffix}" - ) - # GET operation - get_result = test_redis_client.get(f"key{key_suffix}") - 
results.append((set_result, get_result)) - except Exception as e: - errors.append(e) - - # Run operations in multiple threads to force multiple connections - threads = [] - for i in range(3): - thread = threading.Thread(target=redis_operation, args=(i,)) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # Verify no errors occurred - assert len(errors) == 0, f"Errors occurred: {errors}" - - # Verify all operations completed successfully - assert len(results) == 3 - for set_result, get_result in results: - assert set_result is True - assert get_result in [b"value0", b"value1", b"value2"] - - # Verify that multiple connections were created with mock sockets - # With threading, both pool types should create multiple sockets for concurrent access - assert len(self.mock_sockets) >= 2, ( - f"Expected multiple sockets due to threading, got {len(self.mock_sockets)}" - ) - - # Verify each connection has maintenance event handler - connection = test_pool.get_connection() + connection = test_redis_client.connection_pool.get_connection() assert hasattr(connection, "_maintenance_event_connection_handler") - test_pool.release(connection) + test_redis_client.connection_pool.release(connection) finally: - if hasattr(test_pool, "disconnect"): - test_pool.disconnect() + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() def test_pool_handler_with_migrating_event(self): """Test that pool handler correctly handles migrating events.""" # Create a pool and Redis client with maintenance events - test_pool, _ = self._get_client(ConnectionPool) + test_redis_client = self._get_client(ConnectionPool) + test_pool = test_redis_client.connection_pool try: # Create and set a pool handler @@ -421,7 +505,7 @@ def test_migration_related_events_handling_integration(self, pool_class): 8. 
Uses proper RESP3 push message format for realistic protocol simulation """ # Create a pool and Redis client with maintenance events - test_pool, test_redis_client = self._get_client(pool_class, max_connections=10) + test_redis_client = self._get_client(pool_class, max_connections=10) try: # Results storage for thread operations @@ -530,8 +614,8 @@ def redis_operations_with_maintenance_events(thread_id): ) finally: - if hasattr(test_pool, "disconnect"): - test_pool.disconnect() + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_migrating_event_with_disabled_relax_timeout(self, pool_class): @@ -551,7 +635,7 @@ def test_migrating_event_with_disabled_relax_timeout(self, pool_class): ) # Create a pool and Redis client with disabled relax timeout config - test_pool, test_redis_client = self._get_client( + test_redis_client = self._get_client( pool_class, max_connections=5, maintenance_events_config=disabled_config ) @@ -635,5 +719,343 @@ def redis_operations_with_disabled_relax(thread_id): ) finally: - if hasattr(test_pool, "disconnect"): - test_pool.disconnect() + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_moving_related_events_handling_integration(self, pool_class): + """ + Test full integration of moving-related events (MOVING) handling with Redis commands. 
+ """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(10): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 5 connections to be "in use" + in_use_connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior MOVING event + self._validate_disconnected(0) + + # Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + # the connection used for the command is expected to be reconnected to the new address + # before it is returned to the pool + result2 = test_redis_client.set(key_moving, value_moving) + + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_moving) failed" + + # Validate pool and connections settings were updated according to MOVING event + # handling expectations + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + ) + # 5 disconnects has happened, 1 of them is with reconnect + self._validate_disconnected(5) + # 5 in use connected + 1 after reconnect + self._validate_connected(6) + self._validate_in_use_connections_state(in_use_connections) + # Validate there is 1 free connection that is connected + # the one that has handled the MOVING should reconnect after parsing the response + self._validate_free_connections_state( + test_redis_client.connection_pool, + 
MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + should_be_connected_count=1, + connected_to_tmp_addres=True, + ) + + # Wait for MOVING timeout to expire and the moving completed handler to run + print("Waiting for MOVING timeout to expire...") + sleep(MockSocket.MOVING_TIMEOUT + 0.5) + + self._validate_all_timeouts(None) + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + None, + -1, + ) + self._validate_free_connections_state( + test_redis_client.connection_pool, + None, + -1, + should_be_connected_count=1, + connected_to_tmp_addres=True, + ) + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_create_new_conn_while_moving_not_expired(self, pool_class): + """ + Test creating new connections while MOVING event is active (not expired). + + This test validates that: + 1. After MOVING event is processed, new connections are created with temporary address + 2. New connections inherit the relaxed timeout settings + 3. 
Pool configuration is properly applied to newly created connections + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 3 connections to be "in use" + in_use_connections = [] + for _ in range(3): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior MOVING event + self._validate_disconnected(0) + + # Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + result = test_redis_client.set(key_moving, value_moving) + + # Validate command result + assert result is True, "SET key_receive_moving command failed" + + # Validate pool and connections settings were updated according to MOVING event + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + ) + + # Now get several more connections to force creation of new ones + # This should create new connections with the temporary address + old_connections = [] + for _ in range(2): + connection = test_redis_client.connection_pool.get_connection() + old_connections.append(connection) + + new_connection = test_redis_client.connection_pool.get_connection() + + # Validate that new connections are created with temporary address and relax timeout + # and when connecting those configs are used + # get_connection() returns a connection that is already 
connected + assert ( + new_connection.tmp_host_address + == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + ) + assert new_connection.tmp_relax_timeout == self.config.relax_timeout + # New connections should be connected to the temporary address + assert new_connection._sock is not None + assert new_connection._sock.connected is True + assert ( + new_connection._sock.getpeername()[0] + == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + ) + assert new_connection._sock.gettimeout() == self.config.relax_timeout + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_create_new_conn_after_moving_expires(self, pool_class): + """ + Test creating new connections after MOVING event expires. + + This test validates that: + 1. After MOVING timeout expires, new connections use original address + 2. Pool configuration is reset to original values + 3. 
New connections don't inherit temporary settings + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 3 connections to be "in use" + in_use_connections = [] + for _ in range(3): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + result = test_redis_client.set(key_moving, value_moving) + + # Validate command result + assert result is True, "SET key_receive_moving command failed" + + # Wait for MOVING timeout to expire + print("Waiting for MOVING timeout to expire...") + sleep(MockSocket.MOVING_TIMEOUT + 0.5) + + # Now get several new connections after expiration + old_connections = [] + for _ in range(2): + connection = test_redis_client.connection_pool.get_connection() + old_connections.append(connection) + + new_connection = test_redis_client.connection_pool.get_connection() + + # Validate that new connections are created with original address (no temporary settings) + assert new_connection.tmp_host_address is None + assert new_connection.tmp_relax_timeout == -1 + # New connections should be connected to the original address + assert new_connection._sock is not None + assert new_connection._sock.connected is True + # Socket timeout should be None (original timeout) + assert new_connection._sock.gettimeout() is None + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + 
@pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_receive_migrated_after_moving(self, pool_class): + # TODO Refactor: when migrated comes after moving and + # moving hasn't yet expired - it should not decrease timeouts + """ + Test receiving MIGRATED event after MOVING event. + + This test validates the complete MOVING -> MIGRATED lifecycle: + 1. MOVING event is processed and temporary settings are applied + 2. MIGRATED event is received during command execution + 3. Temporary settings are cleared after MIGRATED + 4. Pool configuration is restored to original values + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 3 connections to be "in use" + in_use_connections = [] + for _ in range(3): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior MOVING event + self._validate_disconnected(0) + + # Step 1: Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + result_moving = test_redis_client.set(key_moving, value_moving) + + # Validate MOVING command result + assert result_moving is True, "SET key_receive_moving command failed" + + # Validate pool and connections settings were updated according to MOVING event + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + 
self.config.relax_timeout, + ) + + # Step 2: Run command that will receive and handle MIGRATED event + # This should clear the temporary settings + key_migrated = "key_receive_migrated_0" + value_migrated = "migrated_value" + result_migrated = test_redis_client.set(key_migrated, value_migrated) + + # Validate MIGRATED command result + assert result_migrated is True, "SET key_receive_migrated command failed" + + # Step 3: Validate that MIGRATED event was processed but MOVING settings remain + # (MIGRATED doesn't automatically clear MOVING settings - they are separate events) + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[ + 0 + ], # MOVING settings still active + self.config.relax_timeout, # MOVING timeout still active + ) + + # Step 4: Create new connections after MIGRATED to verify they still use MOVING settings + # (since MOVING settings are still active) + new_connections = [] + for _ in range(2): + connection = test_redis_client.connection_pool.get_connection() + new_connections.append(connection) + + # Validate that new connections are created with MOVING settings (still active) + for connection in new_connections: + assert ( + connection.tmp_host_address + == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + ) + # Note: New connections may not inherit the exact relax timeout value + # but they should have the temporary host address + # New connections should be connected + if connection._sock is not None: + assert connection._sock.connected is True + + # Release the new connections + for connection in new_connections: + test_redis_client.connection_pool.release(connection) + + # Validate free connections state with MOVING settings still active + # Note: We'll validate with the pool's current settings rather than individual connection settings + # since new connections may have different timeout values but 
still use the temporary address + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() From bed2e40c21b3f8c98af0d09d52522e94ea798b69 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 17 Jul 2025 15:28:20 +0300 Subject: [PATCH 08/16] Fixed BlockingConnectionPool locking strategy. Removed debug logging. Refactored the maintenance events tests not to be multithreaded - we don't need it for those tests. --- redis/asyncio/connection.py | 2 + redis/client.py | 6 - redis/connection.py | 125 ++++--- redis/maintenance_events.py | 36 +- tests/test_connection_pool.py | 3 + tests/test_maintenance_events_handling.py | 419 ++++++++++++---------- 6 files changed, 329 insertions(+), 262 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 4efd868f6f..fe86e4c36e 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1308,6 +1308,8 @@ def __init__( ) self._condition = asyncio.Condition() self.timeout = timeout + self._in_maintenance = False + self._locked = False @deprecated_args( args_to_warn=["*"], diff --git a/redis/client.py b/redis/client.py index 0ec36c52d9..473b1e00f2 100755 --- a/redis/client.py +++ b/redis/client.py @@ -668,9 +668,6 @@ def _execute_command(self, *args, **options): finally: if conn and conn.should_reconnect(): - logging.debug( - f"***** Redis reconnect before exit _execute_command --> notification for {conn._sock.getpeername()}" - ) self._close_connection(conn) conn.connect() if self._single_connection_client: @@ -963,9 +960,6 @@ def _execute(self, conn, command, *args, **kwargs): lambda _: self._reconnect(conn), ) if conn.should_reconnect(): - logging.debug( - f"***** PubSub --> Reconnect on notification for {conn._sock.getpeername()}" - ) self._reconnect(conn) return response diff --git a/redis/connection.py b/redis/connection.py index 81a80d0903..a096b045b2 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ 
-431,7 +431,20 @@ def set_parser(self, parser_class): def set_maintenance_event_pool_handler( self, maintenance_event_pool_handler: MaintenanceEventPoolHandler ): - self._parser.set_node_moving_push_handler(maintenance_event_pool_handler) + self._parser.set_node_moving_push_handler( + maintenance_event_pool_handler.handle_event + ) + + # Initialize maintenance event connection handler if it doesn't exist + if not hasattr(self, "_maintenance_event_connection_handler"): + self._maintenance_event_connection_handler = ( + MaintenanceEventConnectionHandler( + self, maintenance_event_pool_handler.config + ) + ) + self._parser.set_maintenance_push_handler( + self._maintenance_event_connection_handler.handle_event + ) def connect(self): "Connects to the Redis server if not already connected" @@ -802,10 +815,6 @@ def should_reconnect(self): def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): if self._sock: timeout = relax_timeout if relax_timeout != -1 else self.socket_timeout - logging.debug( - f"***** Connection --> Updating timeout for {self._sock.getpeername()}" - f" to timeout {timeout}; relax_timeout: {relax_timeout}" - ) self._sock.settimeout(timeout) self.update_parser_buffer_timeout(timeout) @@ -858,10 +867,6 @@ def _connect(self): # ipv4/ipv6, but we want to set options prior to calling # socket.connect() err = None - if self.tmp_host_address is not None: - logging.debug( - f"***** Connection --> Using tmp_host_address: {self.tmp_host_address}" - ) host = self.tmp_host_address or self.host for res in socket.getaddrinfo( @@ -882,14 +887,8 @@ def _connect(self): # set the socket_connect_timeout before we connect if self.tmp_relax_timeout != -1: - logging.debug( - f"***** Connection connect --> Using relax_timeout: {self.tmp_relax_timeout}" - ) sock.settimeout(self.tmp_relax_timeout) else: - logging.debug( - f"***** Connection connect --> Using default socket_connect_timeout: {self.socket_connect_timeout}" - ) 
sock.settimeout(self.socket_connect_timeout) # connect @@ -897,16 +896,9 @@ def _connect(self): # set the socket_timeout now that we're connected if self.tmp_relax_timeout != -1: - logging.debug( - f"***** Connection --> Using relax_timeout: {self.tmp_relax_timeout}" - ) sock.settimeout(self.tmp_relax_timeout) else: - logging.debug( - f"***** Connection --> Using default socket_timeout: {self.socket_timeout}" - ) sock.settimeout(self.socket_timeout) - logging.debug(f"Connected to {sock.getpeername()}") return sock except OSError as _: @@ -1606,14 +1598,10 @@ def _update_maintenance_events_configs_for_connections( ): with self._lock: for conn in self._available_connections: - conn.set_maintenance_events_pool_handler( - maintenance_events_pool_handler - ) + conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) conn.maintenance_events_config = maintenance_events_pool_handler.config for conn in self._in_use_connections: - conn.set_maintenance_events_pool_handler( - maintenance_events_pool_handler - ) + conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) conn.maintenance_events_config = maintenance_events_pool_handler.config def reset(self) -> None: @@ -1755,9 +1743,6 @@ def release(self, connection: "Connection") -> None: if self.owns_connection(connection): if connection.should_reconnect(): - logging.debug( - f"***** Pool--> disconnecting in release {connection._sock.getpeername()}" - ) connection.disconnect() self._available_connections.append(connection) self._event_dispatcher.dispatch( @@ -1917,9 +1902,6 @@ def update_connections_current_timeout( If -1 is provided - the relax timeout is disabled. :param include_available_connections: Whether to include available connections in the update. """ - logging.debug(f"***** Pool --> Updating timeouts. 
New value: {relax_timeout}") - start_time = time.time() - for conn in self._in_use_connections: self._update_connection_timeout(conn, relax_timeout) @@ -1927,11 +1909,6 @@ def update_connections_current_timeout( for conn in self._available_connections: self._update_connection_timeout(conn, relax_timeout) - execution_time_us = (time.time() - start_time) * 1000000 - logging.error( - f"###### TIMEOUTS execution time: {execution_time_us:.0f} microseconds" - ) - def _update_connection_for_reconnect( self, connection: "Connection", @@ -2021,6 +1998,8 @@ def __init__( ): self.queue_class = queue_class self.timeout = timeout + self._in_maintenance = False + self._locked = False super().__init__( connection_class=connection_class, max_connections=max_connections, @@ -2029,7 +2008,10 @@ def __init__( def reset(self): # Create and fill up a thread safe queue with ``None`` values. - with self._lock: + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True self.pool = self.queue_class(self.max_connections) while True: try: @@ -2040,6 +2022,13 @@ def reset(self): # Keep a list of actual connection instances so that we can # disconnect them later. self._connections = [] + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False # this must be the last operation in this method. while reset() is # called when holding _fork_lock, other threads in this process @@ -2054,7 +2043,10 @@ def reset(self): def make_connection(self): "Make a fresh connection." 
- with self._lock: + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True if self.cache is not None: connection = CacheProxyConnection( self.connection_class(**self.connection_kwargs), @@ -2066,6 +2058,13 @@ def make_connection(self): self._connections.append(connection) return connection + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False @deprecated_args( args_to_warn=["*"], @@ -2090,7 +2089,10 @@ def get_connection(self, command_name=None, *keys, **options): # Try and get a connection from the pool. If one isn't available within # self.timeout then raise a ``ConnectionError``. connection = None - with self._lock: + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True try: connection = self.pool.get(block=True, timeout=self.timeout) except Empty: @@ -2102,6 +2104,13 @@ def get_connection(self, command_name=None, *keys, **options): # a new connection to add to the pool. if connection is None: connection = self.make_connection() + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False try: # ensure this connection is connected to Redis @@ -2130,7 +2139,10 @@ def release(self, connection): # Make sure we haven't changed process. self._checkpid() - with self._lock: + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True if not self.owns_connection(connection): # pool doesn't own this connection. do not add it back # to the pool. instead add a None value which is a placeholder @@ -2140,24 +2152,39 @@ def release(self, connection): self.pool.put_nowait(None) return if connection.should_reconnect(): - logging.debug( - f"***** Blocking Pool--> disconnecting in release {connection._sock.getpeername()}" - ) connection.disconnect() # Put the connection back into the pool. 
try: + print("Releasing connection - in the pool") self.pool.put_nowait(connection) except Full: # perhaps the pool has been reset() after a fork? regardless, # we don't want this connection pass + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False def disconnect(self): "Disconnects all connections in the pool." self._checkpid() - with self._lock: + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True for connection in self._connections: connection.disconnect() + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False def update_active_connections_for_reconnect( self, @@ -2236,3 +2263,7 @@ def _update_maintenance_events_configs_for_connections( conn.maintenance_events_config = ( maintenance_events_pool_handler.config ) + + def set_in_maintenance(self, in_maintenance: bool): + """Set the maintenance mode for the connection pool.""" + self._in_maintenance = in_maintenance diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index 5b2091472d..bf0cd6bda8 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -2,12 +2,16 @@ import threading import time from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union from redis.typing import Number if TYPE_CHECKING: - from redis.connection import ConnectionInterface, ConnectionPool + from redis.connection import ( + BlockingConnectionPool, + ConnectionInterface, + ConnectionPool, + ) class MaintenanceEvent(ABC): @@ -303,7 +307,11 @@ def is_relax_timeouts_enabled(self) -> bool: class MaintenanceEventPoolHandler: - def __init__(self, pool: "ConnectionPool", config: MaintenanceEventsConfig) -> None: + def __init__( + self, + pool: Union["ConnectionPool", "BlockingConnectionPool"], + config: MaintenanceEventsConfig, + ) -> None: self.pool = pool self.config = config self._processed_events 
= set() @@ -334,18 +342,15 @@ def handle_node_moving_event(self, event: NodeMovingEvent): # nothing to do in the connection pool handling # the event has already been handled or is expired # just return - logging.debug("***** MOVING --> SKIPPED DONE") return - logging.info(f"***** MOVING --> {event}") - logging.info(f"***** MOVING --> set: {self._processed_events}") - start_time = time.time() - with self.pool._lock: if ( self.config.proactive_reconnect or self.config.is_relax_timeouts_enabled() ): + if getattr(self.pool, "set_in_maintenance", False): + self.pool.set_in_maintenance(True) # edit the config for new connections until the notification expires self.pool.update_connection_kwargs_with_tmp_settings( tmp_host_address=event.new_node_host, @@ -371,21 +376,14 @@ def handle_node_moving_event(self, event: NodeMovingEvent): tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, ) - execution_time_us = (time.time() - start_time_2) * 1000000 - logging.error( - f"###### MOVING disconnects execution time: {execution_time_us:.0f} microseconds" - ) + if getattr(self.pool, "set_in_maintenance", False): + self.pool.set_in_maintenance(False) threading.Timer(event.ttl, self.handle_node_moved_event).start() self._processed_events.add(event) - execution_time_us = (time.time() - start_time) * 1000000 - logging.error( - f"###### MOVING total execution time: {execution_time_us:.0f} microseconds" - ) def handle_node_moved_event(self): - logging.debug("***** MOVING END--> Starting to revert the changes.") with self._lock: self.pool.update_connection_kwargs_with_tmp_settings( tmp_host_address=None, @@ -397,12 +395,10 @@ def handle_node_moved_event(self): self.pool.update_connections_current_timeout( relax_timeout=-1, include_free_connections=True ) - logging.debug("***** MOVING END--> TIMEOUTS RESET") self.pool.update_connections_tmp_settings( tmp_host_address=None, tmp_relax_timeout=-1 ) - logging.debug("***** MOVING END--> TMP SETTINGS ADDRESS RESET") 
class MaintenanceEventConnectionHandler: @@ -424,7 +420,6 @@ def handle_migrating_event(self, notification: NodeMigratingEvent): if not self.config.is_relax_timeouts_enabled(): return - logging.info(f"***** MIGRATING --> {notification}") # extend the timeout for all created connections self.connection.update_current_socket_timeout(self.config.relax_timeout) self.connection.update_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) @@ -433,7 +428,6 @@ def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): if not self.config.is_relax_timeouts_enabled(): return - logging.info(f"***** MIGRATED --> {notification}") # Node migration completed - reset the connection # timeouts by providing -1 as the relax timeout self.connection.update_current_socket_timeout(-1) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 3a4896f2a3..4518cd7290 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -33,6 +33,9 @@ def connect(self): def can_read(self): return False + def should_reconnect(self): + return False + class TestConnectionPool: def get_pool( diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index c04c6c066e..1620471ea7 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -52,12 +52,12 @@ def send(self, data): response = b"+OK\r\n" # Check if this is a key that should trigger a push message - if b"key_receive_migrating_" in data: + if b"key_receive_migrating_" in data or b"key_receive_migrating" in data: # MIGRATING push message before SET key_receive_migrating_X response # Format: >2\r\n$9\r\nMIGRATING\r\n:10\r\n (2 elements: MIGRATING, ttl) migrating_push = ">2\r\n$9\r\nMIGRATING\r\n:10\r\n" response = migrating_push.encode() + response - elif b"key_receive_migrated_" in data: + elif b"key_receive_migrated_" in data or b"key_receive_migrated" in data: # MIGRATED push message before SET 
key_receive_migrated_X response # Format: >1\r\n$8\r\nMIGRATED\r\n (1 element: MIGRATED) migrated_push = ">1\r\n$8\r\nMIGRATED\r\n" @@ -75,32 +75,17 @@ def send(self, data): if b"hello" in data: response = b"$5\r\nworld\r\n" self.pending_responses.append(response) - # Handle thread-specific keys for integration test first (more specific) - elif b"key1_0" in data: - self.pending_responses.append(b"$8\r\nvalue1_0\r\n") - elif b"key_receive_migrating_0" in data: - self.pending_responses.append(b"$8\r\nvalue2_0\r\n") + # Handle specific keys used in tests elif b"key_receive_moving_0" in data: self.pending_responses.append(b"$8\r\nvalue3_0\r\n") - elif b"key1_1" in data: - self.pending_responses.append(b"$8\r\nvalue1_1\r\n") - elif b"key_receive_migrating_1" in data: - self.pending_responses.append(b"$8\r\nvalue2_1\r\n") - elif b"key_receive_moving_1" in data: - self.pending_responses.append(b"$8\r\nvalue3_1\r\n") - elif b"key1_2" in data: - self.pending_responses.append(b"$8\r\nvalue1_2\r\n") - elif b"key_receive_migrating_2" in data: - self.pending_responses.append(b"$8\r\nvalue2_2\r\n") - elif b"key_receive_moving_2" in data: - self.pending_responses.append(b"$8\r\nvalue3_2\r\n") - # Generic keys (less specific, should come after thread-specific) - elif b"key0" in data: - self.pending_responses.append(b"$6\r\nvalue0\r\n") + elif b"key_receive_migrated_0" in data: + self.pending_responses.append(b"$13\r\nmigrated_value\r\n") + elif b"key_receive_migrating" in data: + self.pending_responses.append(b"$6\r\nvalue2\r\n") + elif b"key_receive_migrated" in data: + self.pending_responses.append(b"$6\r\nvalue3\r\n") elif b"key1" in data: self.pending_responses.append(b"$6\r\nvalue1\r\n") - elif b"key2" in data: - self.pending_responses.append(b"$6\r\nvalue2\r\n") else: self.pending_responses.append(b"$-1\r\n") # NULL response else: @@ -260,7 +245,37 @@ def _get_client( return test_redis_client - def _validate_current_timeout_for_thread(self, thread_id, expected_timeout): + 
def _validate_connection_handlers(self, conn, pool_handler, config): + """Helper method to validate connection handlers are properly set.""" + # Test that the node moving handler function is correctly set + parser_handler = conn._parser.node_moving_push_handler_func + assert parser_handler is not None + assert hasattr(parser_handler, "__self__") + assert hasattr(parser_handler, "__func__") + assert parser_handler.__self__ is pool_handler + assert parser_handler.__func__ is pool_handler.handle_event.__func__ + + # Test that the maintenance handler function is correctly set + maintenance_handler = conn._parser.maintenance_push_handler_func + assert maintenance_handler is not None + assert hasattr(maintenance_handler, "__self__") + assert hasattr(maintenance_handler, "__func__") + # The maintenance handler should be bound to the connection's + # maintenance event connection handler + assert ( + maintenance_handler.__self__ is conn._maintenance_event_connection_handler + ) + assert ( + maintenance_handler.__func__ + is conn._maintenance_event_connection_handler.handle_event.__func__ + ) + + # Validate that the connection's maintenance handler has the same config object + assert conn._maintenance_event_connection_handler.config is config + + def _validate_current_timeout_for_thread( + self, thread_id, expected_timeout, error_msg=None + ): """Helper method to validate the current timeout for the calling thread.""" actual_timeout = None # Get the actual thread ID from the current thread @@ -271,9 +286,27 @@ def _validate_current_timeout_for_thread(self, thread_id, expected_timeout): break assert actual_timeout == expected_timeout, ( + error_msg, f"Thread {thread_id}: Expected timeout ({expected_timeout}), " f"but found timeout: {actual_timeout} for thread {thread_id}. 
" - f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}" + f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}", + ) + + def _validate_current_timeout(self, expected_timeout, error_msg=None): + """Helper method to validate the current timeout for the calling thread.""" + actual_timeout = None + # Get the actual thread ID from the current thread + current_thread_id = threading.current_thread().ident + for sock in self.mock_sockets: + if current_thread_id in sock.thread_timeouts: + actual_timeout = sock.thread_timeouts[current_thread_id] + break + + assert actual_timeout == expected_timeout, ( + f"{error_msg or ''}" + f"Expected timeout ({expected_timeout}), " + f"but found timeout: {actual_timeout}. " + f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}", ) def _validate_disconnected(self, expected_count): @@ -378,6 +411,95 @@ def _validate_conn_kwargs( assert pool.connection_kwargs["tmp_host_address"] == expected_tmp_host_address assert pool.connection_kwargs["tmp_relax_timeout"] == expected_tmp_relax_timeout + def test_client_initialization(self): + """Test that Redis client is created with maintenance events configuration.""" + # Create a pool and Redis client with maintenance events + + test_redis_client = Redis( + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + + pool_handler = test_redis_client.connection_pool.connection_kwargs.get( + "maintenance_events_pool_handler" + ) + assert pool_handler is not None + assert pool_handler.config == self.config + + conn = test_redis_client.connection_pool.get_connection() + assert conn._should_reconnect is False + assert conn.tmp_host_address is None + assert conn.tmp_relax_timeout == -1 + + # Test that the node moving handler function is correctly set by + # comparing the underlying function and instance + parser_handler = conn._parser.node_moving_push_handler_func + assert parser_handler is not None 
+ assert hasattr(parser_handler, "__self__") + assert hasattr(parser_handler, "__func__") + assert parser_handler.__self__ is pool_handler + assert parser_handler.__func__ is pool_handler.handle_event.__func__ + + # Test that the maintenance handler function is correctly set + maintenance_handler = conn._parser.maintenance_push_handler_func + assert maintenance_handler is not None + assert hasattr(maintenance_handler, "__self__") + assert hasattr(maintenance_handler, "__func__") + # The maintenance handler should be bound to the connection's + # maintenance event connection handler + assert ( + maintenance_handler.__self__ is conn._maintenance_event_connection_handler + ) + assert ( + maintenance_handler.__func__ + is conn._maintenance_event_connection_handler.handle_event.__func__ + ) + + # Validate that the connection's maintenance handler has the same config object + assert conn._maintenance_event_connection_handler.config is self.config + + def test_maint_handler_init_for_existing_connections(self): + """Test that maintenance event handlers are properly set on existing and new connections + when configuration is enabled after client creation.""" + + # Create a Redis client with disabled maintenance events configuration + disabled_config = MaintenanceEventsConfig(enabled=False) + test_redis_client = Redis( + protocol=3, # Required for maintenance events + maintenance_events_config=disabled_config, + ) + + # Extract an existing connection before enabling maintenance events + existing_conn = test_redis_client.connection_pool.get_connection() + + # Verify that maintenance events are initially disabled + assert existing_conn._parser.node_moving_push_handler_func is None + assert not hasattr(existing_conn, "_maintenance_event_connection_handler") + assert existing_conn._parser.maintenance_push_handler_func is None + + # Create a new enabled configuration and set up pool handler + enabled_config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, 
relax_timeout=30 + ) + pool_handler = MaintenanceEventPoolHandler( + test_redis_client.connection_pool, enabled_config + ) + test_redis_client.connection_pool.set_maintenance_events_pool_handler( + pool_handler + ) + + # Validate the existing connection after enabling maintenance events + # Both existing and new connections should now have full handler setup + self._validate_connection_handlers(existing_conn, pool_handler, enabled_config) + + # Create a new connection and validate it has full handlers + new_conn = test_redis_client.connection_pool.get_connection() + self._validate_connection_handlers(new_conn, pool_handler, enabled_config) + + # Clean up connections + test_redis_client.connection_pool.release(existing_conn) + test_redis_client.connection_pool.release(new_conn) + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_connection_pool_creation_with_maintenance_events(self, pool_class): """Test that connection pools are created with maintenance events configuration.""" @@ -492,14 +614,14 @@ def test_pool_handler_with_migrating_event(self): @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_migration_related_events_handling_integration(self, pool_class): """ - Test full integration of migration-related events (MIGRATING/MIGRATED) handling with multiple threads and commands. + Test full integration of migration-related events (MIGRATING/MIGRATED) handling. This test validates the complete migration lifecycle: - 1. Creates 3 concurrent threads, each executing 5 Redis commands - 2. Injects MIGRATING push message before command 2 (SET key_receive_migrating_X) + 1. Executes 5 Redis commands sequentially + 2. Injects MIGRATING push message before command 2 (SET key_receive_migrating) 3. Validates socket timeout is updated to relaxed value (30s) after MIGRATING 4. Executes commands 3-4 while timeout remains relaxed - 5. 
Injects MIGRATED push message before command 5 (SET key_receive_migrated_X) + 5. Injects MIGRATED push message before command 5 (SET key_receive_migrated) 6. Validates socket timeout is restored after MIGRATED 7. Tests both ConnectionPool and BlockingConnectionPool implementations 8. Uses proper RESP3 push message format for realistic protocol simulation @@ -508,107 +630,63 @@ def test_migration_related_events_handling_integration(self, pool_class): test_redis_client = self._get_client(pool_class, max_connections=10) try: - # Results storage for thread operations - results = [] - errors = [] - - def redis_operations_with_maintenance_events(thread_id): - """Perform Redis operations with maintenance events in a thread.""" - try: - # Command 1: Initial command - key1 = f"key1_{thread_id}" - value1 = f"value1_{thread_id}" - result1 = test_redis_client.set(key1, value1) - - # Validate Command 1 result - erros_msg = f"Thread {thread_id}: Command 1 (SET key1) failed" - assert result1 is True, erros_msg - - # Command 2: This SET command will receive MIGRATING push message before response - key_migrating = f"key_receive_migrating_{thread_id}" - value_migrating = f"value2_{thread_id}" - result2 = test_redis_client.set(key_migrating, value_migrating) - - # Validate Command 2 result - erros_msg = f"Thread {thread_id}: Command 2 (SET key_receive_migrating) failed" - assert result2 is True, erros_msg - - # Step 4: Validate timeout was updated to relaxed value after MIGRATING - self._validate_current_timeout_for_thread(thread_id, 30) - - # Command 3: Another command while timeout is still relaxed - result3 = test_redis_client.get(key1) - - # Validate Command 3 result - expected_value3 = value1.encode() - errors_msg = ( - f"Thread {thread_id}: Command 3 (GET key1) failed. 
" - f"Expected {expected_value3}, got {result3}" - ) - assert result3 == expected_value3, errors_msg + # Command 1: Initial command + key1 = "key1" + value1 = "value1" + result1 = test_redis_client.set(key1, value1) - # Command 4: Execute command (step 5) - result4 = test_redis_client.get(key_migrating) + # Validate Command 1 result + assert result1 is True, "Command 1 (SET key1) failed" - # Validate Command 4 result - expected_value4 = value_migrating.encode() - errors_msg = ( - f"Thread {thread_id}: Command 4 (GET key_receive_migrating) failed. " - f"Expected {expected_value4}, got {result4}" - ) - assert result4 == expected_value4, errors_msg + # Command 2: This SET command will receive MIGRATING push message before response + key_migrating = "key_receive_migrating" + value_migrating = "value2" + result2 = test_redis_client.set(key_migrating, value_migrating) - # Step 6: Validate socket timeout is still relaxed during commands 3-4 - self._validate_current_timeout_for_thread(thread_id, 30) + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_migrating) failed" - # Command 5: This SET command will receive - # MIGRATED push message before actual response - key_migrated = f"key_receive_migrated_{thread_id}" - value_migrated = f"value3_{thread_id}" - result5 = test_redis_client.set(key_migrated, value_migrated) + # Step 4: Validate timeout was updated to relaxed value after MIGRATING + self._validate_current_timeout(30, "Right after MIGRATING is received. 
") - # Validate Command 5 result - errors_msg = f"Thread {thread_id}: Command 5 (SET key_receive_migrated) failed" - assert result5 is True, errors_msg + # Command 3: Another command while timeout is still relaxed + result3 = test_redis_client.get(key1) - # Step 8: Validate socket timeout is reversed back to original after MIGRATED - self._validate_current_timeout_for_thread(thread_id, None) + # Validate Command 3 result + expected_value3 = value1.encode() + assert result3 == expected_value3, ( + f"Command 3 (GET key1) failed. Expected {expected_value3}, got {result3}" + ) - results.append( - { - "thread_id": thread_id, - "success": True, - } - ) + # Command 4: Execute command (step 5) + result4 = test_redis_client.get(key_migrating) - except Exception as e: - errors.append(f"Thread {thread_id}: {e}") + # Validate Command 4 result + expected_value4 = value_migrating.encode() + assert result4 == expected_value4, ( + f"Command 4 (GET key_receive_migrating) failed. Expected {expected_value4}, got {result4}" + ) - # Run operations in multiple threads (step 1) - threads = [] - for i in range(3): - thread = threading.Thread( - target=redis_operations_with_maintenance_events, - args=(i,), - name=str(i), - ) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - # Verify all threads completed successfully - successful_threads = len(results) - assert successful_threads == 3, ( - f"Expected 3 successful threads, got {successful_threads}. 
" - f"Errors: {errors}" + # Step 6: Validate socket timeout is still relaxed during commands 3-4 + self._validate_current_timeout( + 30, + "Execute a command with a connection extracted from the pool (after it has received MIGRATING)", ) - # Verify maintenance events were processed correctly across all threads - # Note: Different pool types may create different numbers of sockets - # The key is that we have at least 1 socket and all threads succeeded + # Command 5: This SET command will receive + # MIGRATED push message before actual response + key_migrated = "key_receive_migrated" + value_migrated = "value3" + result5 = test_redis_client.set(key_migrated, value_migrated) + + # Validate Command 5 result + assert result5 is True, "Command 5 (SET key_receive_migrated) failed" + + # Step 8: Validate socket timeout is reversed back to original after MIGRATED + self._validate_current_timeout(None) + + # Verify maintenance events were processed correctly + # The key is that we have at least 1 socket and all operations succeeded assert len(self.mock_sockets) >= 1, ( f"Expected at least 1 socket for operations, got {len(self.mock_sockets)}" ) @@ -640,80 +718,37 @@ def test_migrating_event_with_disabled_relax_timeout(self, pool_class): ) try: - # Results storage for thread operations - results = [] - errors = [] - - def redis_operations_with_disabled_relax(thread_id): - """Perform Redis operations with disabled relax timeout in a thread.""" - try: - # Command 1: Initial command - key1 = f"key1_{thread_id}" - value1 = f"value1_{thread_id}" - result1 = test_redis_client.set(key1, value1) - - # Validate Command 1 result - errors_msg = f"Thread {thread_id}: Command 1 (SET key1) failed" - assert result1 is True, errors_msg - - # Command 2: This SET command will receive MIGRATING push message before response - key_migrating = f"key_receive_migrating_{thread_id}" - value_migrating = f"value2_{thread_id}" - result2 = test_redis_client.set(key_migrating, value_migrating) - - # 
Validate Command 2 result - errors_msg = f"Thread {thread_id}: Command 2 (SET key_receive_migrating) failed" - assert result2 is True, errors_msg - - # Validate timeout was NOT updated (relax is disabled) - # Should remain at default timeout (None), not relaxed to 30s - self._validate_current_timeout_for_thread(thread_id, None) - - # Command 3: Another command to verify timeout remains unchanged - result3 = test_redis_client.get(key1) - - # Validate Command 3 result - expected_value3 = value1.encode() - errors_msg = ( - f"Thread {thread_id}: Command 3 (GET key1) failed. " - f"Expected: {expected_value3}, Got: {result3}" - ) - assert result3 == expected_value3, errors_msg + # Command 1: Initial command + key1 = "key1" + value1 = "value1" + result1 = test_redis_client.set(key1, value1) - results.append( - { - "thread_id": thread_id, - "success": True, - } - ) + # Validate Command 1 result + assert result1 is True, "Command 1 (SET key1) failed" - except Exception as e: - errors.append(f"Thread {thread_id}: {str(e)}") + # Command 2: This SET command will receive MIGRATING push message before response + key_migrating = "key_receive_migrating" + value_migrating = "value2" + result2 = test_redis_client.set(key_migrating, value_migrating) - # Run operations in multiple threads to test concurrent behavior - threads = [] - for i in range(3): - thread = threading.Thread( - target=redis_operations_with_disabled_relax, args=(i,) - ) - threads.append(thread) - thread.start() + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_migrating) failed" - # Wait for all threads to complete - for thread in threads: - thread.join() + # Validate timeout was NOT updated (relax is disabled) + # Should remain at default timeout (None), not relaxed to 30s + self._validate_current_timeout(None) - # Verify no errors occurred - assert len(errors) == 0, f"Errors occurred: {errors}" + # Command 3: Another command to verify timeout remains unchanged + result3 = 
test_redis_client.get(key1) - # Verify all operations completed successfully - assert len(results) == 3, ( - f"Expected 3 successful threads, got {len(results)}" + # Validate Command 3 result + expected_value3 = value1.encode() + assert result3 == expected_value3, ( + f"Command 3 (GET key1) failed. Expected: {expected_value3}, Got: {result3}" ) - # Verify maintenance events were processed correctly across all threads - # Note: Different pool types may create different numbers of sockets - # The key is that we have at least 1 socket and all threads succeeded + # Verify maintenance events were processed correctly + # The key is that we have at least 1 socket and all operations succeeded assert len(self.mock_sockets) >= 1, ( f"Expected at least 1 socket for operations, got {len(self.mock_sockets)}" ) @@ -726,6 +761,13 @@ def redis_operations_with_disabled_relax(thread_id): def test_moving_related_events_handling_integration(self, pool_class): """ Test full integration of moving-related events (MOVING) handling with Redis commands. + + This test validates the complete MOVING event lifecycle: + 1. Creates multiple connections in the pool + 2. Executes a Redis command that triggers a MOVING push message + 3. Validates that pool configuration is updated with temporary address and timeout + 4. Validates that existing connections are marked for disconnection + 5. Tests both ConnectionPool and BlockingConnectionPool implementations """ # Create a pool and Redis client with maintenance events and pool handler test_redis_client = self._get_client( @@ -956,8 +998,6 @@ def test_create_new_conn_after_moving_expires(self, pool_class): @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) def test_receive_migrated_after_moving(self, pool_class): - # TODO Refactor: when migrated comes after moving and - # moving hasn't yet expired - it should not decrease timeouts """ Test receiving MIGRATED event after MOVING event. 
@@ -966,6 +1006,9 @@ def test_receive_migrated_after_moving(self, pool_class): 2. MIGRATED event is received during command execution 3. Temporary settings are cleared after MIGRATED 4. Pool configuration is restored to original values + + Note: When MIGRATED comes after MOVING and MOVING hasn't yet expired, + it should not decrease timeouts (future refactoring consideration). """ # Create a pool and Redis client with maintenance events and pool handler test_redis_client = self._get_client( From 0744ee5927f07313ce3cf9635fc2d4e73fea8045 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 17 Jul 2025 15:38:30 +0300 Subject: [PATCH 09/16] Fixing linters --- redis/client.py | 1 - redis/maintenance_events.py | 1 - 2 files changed, 2 deletions(-) diff --git a/redis/client.py b/redis/client.py index 473b1e00f2..a1a053ddc6 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1,5 +1,4 @@ import copy -import logging import re import threading import time diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index bf0cd6bda8..5fe2e81de8 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -371,7 +371,6 @@ def handle_node_moving_event(self, event: NodeMovingEvent): # take care for the inactive connections in the pool # delete them and create new ones - start_time_2 = time.time() self.pool.disconnect_and_reconfigure_free_connections( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, From 4c536f391296cf2207ea2f613e4b3e26d69b0ddd Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 17 Jul 2025 15:42:41 +0300 Subject: [PATCH 10/16] Applying Copilot's comments --- redis/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redis/client.py b/redis/client.py index a1a053ddc6..a6c96c3882 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1662,7 +1662,7 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: lambda error: self._disconnect_raise_on_watching(conn, error), 
) finally: - # in reset() the connection is diconnected before returned to the pool if + # in reset() the connection is disconnected before returned to the pool if # it is marked for reconnect. self.reset() From 6768d5d4c07798f9ebccc44aee44af6e13848cc2 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 17 Jul 2025 15:49:20 +0300 Subject: [PATCH 11/16] Fixed type annotations not compatible with older python versions --- redis/connection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index a096b045b2..57e8869e40 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -824,8 +824,8 @@ def update_parser_buffer_timeout(self, timeout: Optional[float] = None): def update_tmp_settings( self, - tmp_host_address: Optional[str | object] = SENTINEL, - tmp_relax_timeout: Optional[float | object] = SENTINEL, + tmp_host_address: Optional[Union[str, object]] = SENTINEL, + tmp_relax_timeout: Optional[Union[float, object]] = SENTINEL, ): """ The value of SENTINEL is used to indicate that the property should not be updated. From ce31ec76a6c069a5596022ac9d85acc37d9e8116 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 17 Jul 2025 19:12:07 +0300 Subject: [PATCH 12/16] Add a few more tests and fix pool mock for python 3.9 --- redis/connection.py | 1 - tests/test_maintenance_events.py | 8 +- tests/test_maintenance_events_handling.py | 126 ++++++++++++++++++++++ 3 files changed, 131 insertions(+), 4 deletions(-) diff --git a/redis/connection.py b/redis/connection.py index 57e8869e40..7e9ad95b21 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -2155,7 +2155,6 @@ def release(self, connection): connection.disconnect() # Put the connection back into the pool. try: - print("Releasing connection - in the pool") self.pool.put_nowait(connection) except Full: # perhaps the pool has been reset() after a fork? 
regardless, diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py index 69a6014fe1..ac7d10b51e 100644 --- a/tests/test_maintenance_events.py +++ b/tests/test_maintenance_events.py @@ -1,5 +1,6 @@ import threading -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock +import pytest from redis.maintenance_events import ( MaintenanceEvent, @@ -17,8 +18,6 @@ class TestMaintenanceEvent: def test_abstract_class_cannot_be_instantiated(self): """Test that MaintenanceEvent cannot be instantiated directly.""" - import pytest - with patch("time.monotonic", return_value=1000): with pytest.raises(TypeError): MaintenanceEvent(id=1, ttl=10) # type: ignore @@ -347,6 +346,9 @@ class TestMaintenanceEventPoolHandler: def setup_method(self): """Set up test fixtures.""" self.mock_pool = Mock() + self.mock_pool._lock = MagicMock() + self.mock_pool._lock.__enter__.return_value = None + self.mock_pool._lock.__exit__.return_value = None self.config = MaintenanceEventsConfig( enabled=True, proactive_reconnect=True, relax_timeout=20 ) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 1620471ea7..6ce74ebc3b 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -1102,3 +1102,129 @@ def test_receive_migrated_after_moving(self, pool_class): finally: if hasattr(test_redis_client.connection_pool, "disconnect"): test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_overlapping_moving_events(self, pool_class): + """ + Test handling of overlapping/duplicate MOVING events (e.g., two MOVING events before the first expires). + Ensures that the second MOVING event updates the pool and connections as expected, and that expiry/cleanup works. 
+ """ + test_redis_client = self._get_client( + pool_class, max_connections=5, setup_pool_handler=True + ) + try: + # Create and release some connections + for _ in range(3): + conn = test_redis_client.connection_pool.get_connection() + test_redis_client.connection_pool.release(conn) + + # Take 2 connections to be in use + in_use_connections = [] + for _ in range(2): + conn = test_redis_client.connection_pool.get_connection() + in_use_connections.append(conn) + + # Trigger first MOVING event + key_moving1 = "key_receive_moving_0" + value_moving1 = "value3_0" + result1 = test_redis_client.set(key_moving1, value_moving1) + assert result1 is True + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + ) + # Validate all connections reflect the first MOVING event + self._validate_in_use_connections_state(in_use_connections) + self._validate_free_connections_state( + test_redis_client.connection_pool, + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + should_be_connected_count=1, + connected_to_tmp_addres=True, + ) + + # Before the first MOVING expires, trigger a second MOVING event (simulate new address) + # Patch MockSocket to use a new address for the second event + new_address = "5.6.7.8:6380" + orig_after_moving = MockSocket.AFTER_MOVING_ADDRESS + MockSocket.AFTER_MOVING_ADDRESS = new_address + try: + key_moving2 = "key_receive_moving_1" + value_moving2 = "value3_1" + result2 = test_redis_client.set(key_moving2, value_moving2) + assert result2 is True + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + new_address.split(":")[0], + self.config.relax_timeout, + ) + # Validate all connections reflect the second MOVING event + 
self._validate_in_use_connections_state(in_use_connections) + self._validate_free_connections_state( + test_redis_client.connection_pool, + new_address.split(":")[0], + self.config.relax_timeout, + should_be_connected_count=1, + connected_to_tmp_addres=True, + ) + finally: + MockSocket.AFTER_MOVING_ADDRESS = orig_after_moving + + # Wait for both MOVING timeouts to expire + sleep(MockSocket.MOVING_TIMEOUT + 0.5) + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + None, + -1, + ) + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_thread_safety_concurrent_event_handling(self, pool_class): + """ + Test thread-safety under concurrent maintenance event handling. + Simulates multiple threads triggering MOVING events and performing operations concurrently. 
+ """ + import threading + + test_redis_client = self._get_client( + pool_class, max_connections=5, setup_pool_handler=True + ) + results = [] + errors = [] + + def worker(idx): + try: + key = f"key_receive_moving_{idx}" + value = f"value3_{idx}" + result = test_redis_client.set(key, value) + results.append(result) + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + assert all(results), f"Not all threads succeeded: {results}" + assert not errors, f"Errors occurred in threads: {errors}" + # After all threads, MOVING event should have been handled safely + self._validate_conn_kwargs( + test_redis_client.connection_pool, + MockSocket.DEFAULT_ADDRESS.split(":")[0], + int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + self.config.relax_timeout, + ) + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() From d73cd35fe96fa5e5c1bf6344cfa3ba97fddb75b0 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 18 Jul 2025 16:14:01 +0300 Subject: [PATCH 13/16] Adding maintenance state to connections. Migrating and Migrated are not processed in Moving state.
Tests are updated --- redis/_parsers/hiredis.py | 1 + redis/connection.py | 47 +++++- redis/maintenance_events.py | 32 +++- tests/test_connection_pool.py | 10 +- tests/test_maintenance_events_handling.py | 197 +++++++++++++++++----- 5 files changed, 237 insertions(+), 50 deletions(-) diff --git a/redis/_parsers/hiredis.py b/redis/_parsers/hiredis.py index e9df314a8c..d82fe99cd9 100644 --- a/redis/_parsers/hiredis.py +++ b/redis/_parsers/hiredis.py @@ -152,6 +152,7 @@ def read_response(self, disable_decoding=False, push_request=False): disable_decoding=disable_decoding, push_request=push_request, ) + return response if disable_decoding: response = self._reader.gets(False) diff --git a/redis/connection.py b/redis/connection.py index 7e9ad95b21..5646a745af 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -42,6 +42,7 @@ MaintenanceEventConnectionHandler, MaintenanceEventPoolHandler, MaintenanceEventsConfig, + MaintenanceState, ) from .retry import Retry from .utils import ( @@ -285,6 +286,7 @@ def __init__( maintenance_events_config: Optional[MaintenanceEventsConfig] = None, tmp_host_address: Optional[str] = None, tmp_relax_timeout: Optional[float] = -1, + maintenance_state: "MaintenanceState" = MaintenanceState.NONE, ): """ Initialize a new Connection. 
@@ -374,6 +376,7 @@ def __init__( self._should_reconnect = False self.tmp_host_address = tmp_host_address self.tmp_relax_timeout = tmp_relax_timeout + self.maintenance_state = maintenance_state def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) @@ -835,6 +838,9 @@ def update_tmp_settings( if tmp_relax_timeout is not SENTINEL: self.tmp_relax_timeout = tmp_relax_timeout + def set_maintenance_state(self, state: "MaintenanceState"): + self.maintenance_state = state + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -1724,11 +1730,18 @@ def make_connection(self) -> "ConnectionInterface": raise MaxConnectionsError("Too many connections") self._created_connections += 1 + # Pass current maintenance_state to new connections + maintenance_state = self.connection_kwargs.get( + "maintenance_state", MaintenanceState.NONE + ) + kwargs = dict(self.connection_kwargs) + kwargs["maintenance_state"] = maintenance_state + if self.cache is not None: return CacheProxyConnection( - self.connection_class(**self.connection_kwargs), self.cache, self._lock + self.connection_class(**kwargs), self.cache, self._lock ) - return self.connection_class(**self.connection_kwargs) + return self.connection_class(**kwargs) def release(self, connection: "Connection") -> None: "Releases the connection back to the pool" @@ -1953,6 +1966,16 @@ async def _mock(self, error: RedisError): """ pass + def set_maintenance_state_for_all(self, state: "MaintenanceState"): + with self._lock: + for conn in self._available_connections: + conn.set_maintenance_state(state) + for conn in self._in_use_connections: + conn.set_maintenance_state(state) + + def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): + self.connection_kwargs["maintenance_state"] = state + class BlockingConnectionPool(ConnectionPool): """ @@ -2047,15 +2070,20 @@ def make_connection(self): if self._in_maintenance: self._lock.acquire() self._locked = True + 
# Pass current maintenance_state to new connections + maintenance_state = self.connection_kwargs.get( + "maintenance_state", MaintenanceState.NONE + ) + kwargs = dict(self.connection_kwargs) + kwargs["maintenance_state"] = maintenance_state if self.cache is not None: connection = CacheProxyConnection( - self.connection_class(**self.connection_kwargs), + self.connection_class(**kwargs), self.cache, self._lock, ) else: - connection = self.connection_class(**self.connection_kwargs) - + connection = self.connection_class(**kwargs) self._connections.append(connection) return connection finally: @@ -2266,3 +2294,12 @@ def _update_maintenance_events_configs_for_connections( def set_in_maintenance(self, in_maintenance: bool): """Set the maintenance mode for the connection pool.""" self._in_maintenance = in_maintenance + + def set_maintenance_state_for_all(self, state: "MaintenanceState"): + with self._lock: + for conn in getattr(self, "_connections", []): + if conn: + conn.set_maintenance_state(state) + + def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): + self.connection_kwargs["maintenance_state"] = state diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index 5fe2e81de8..dd62602105 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -1,3 +1,4 @@ +import enum import logging import threading import time @@ -6,6 +7,13 @@ from redis.typing import Number + +class MaintenanceState(enum.Enum): + NONE = "none" + MOVING = "moving" + MIGRATING = "migrating" + + if TYPE_CHECKING: from redis.connection import ( BlockingConnectionPool, @@ -351,6 +359,9 @@ def handle_node_moving_event(self, event: NodeMovingEvent): ): if getattr(self.pool, "set_in_maintenance", False): self.pool.set_in_maintenance(True) + # Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance) + self.pool.set_maintenance_state_for_all(MaintenanceState.MOVING) + 
self.pool.set_maintenance_state_in_kwargs(MaintenanceState.MOVING) # edit the config for new connections until the notification expires self.pool.update_connection_kwargs_with_tmp_settings( tmp_host_address=event.new_node_host, @@ -368,7 +379,6 @@ def handle_node_moving_event(self, event: NodeMovingEvent): tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, ) - # take care for the inactive connections in the pool # delete them and create new ones self.pool.disconnect_and_reconfigure_free_connections( @@ -388,16 +398,19 @@ def handle_node_moved_event(self): tmp_host_address=None, tmp_relax_timeout=-1, ) + # Clear state to NONE in kwargs immediately after updating tmp kwargs + self.pool.set_maintenance_state_in_kwargs(MaintenanceState.NONE) with self.pool._lock: if self.config.is_relax_timeouts_enabled(): # reset the timeout for existing connections self.pool.update_connections_current_timeout( relax_timeout=-1, include_free_connections=True ) - self.pool.update_connections_tmp_settings( tmp_host_address=None, tmp_relax_timeout=-1 ) + # Clear state to NONE for all connections + self.pool.set_maintenance_state_for_all(MaintenanceState.NONE) class MaintenanceEventConnectionHandler: @@ -416,17 +429,24 @@ def handle_event(self, event: MaintenanceEvent): logging.error(f"Unhandled event type: {event}") def handle_migrating_event(self, notification: NodeMigratingEvent): - if not self.config.is_relax_timeouts_enabled(): + if ( + self.connection.maintenance_state == MaintenanceState.MOVING + or not self.config.is_relax_timeouts_enabled() + ): return - + self.connection.set_maintenance_state(MaintenanceState.MIGRATING) # extend the timeout for all created connections self.connection.update_current_socket_timeout(self.config.relax_timeout) self.connection.update_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): - if not self.config.is_relax_timeouts_enabled(): + 
# Only reset timeouts if state is not MOVING and relax timeouts are enabled + if ( + self.connection.maintenance_state == MaintenanceState.MOVING + or not self.config.is_relax_timeouts_enabled() + ): return - + self.connection.set_maintenance_state(MaintenanceState.NONE) # Node migration completed - reset the connection # timeouts by providing -1 as the relax timeout self.connection.update_current_socket_timeout(-1) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 4518cd7290..880b6db27e 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -9,6 +9,7 @@ import redis from redis.cache import CacheConfig from redis.connection import CacheProxyConnection, Connection, to_bool +from redis.maintenance_events import MaintenanceState from redis.utils import SSL_AVAILABLE from .conftest import ( @@ -53,10 +54,15 @@ def get_pool( return pool def test_connection_creation(self): - connection_kwargs = {"foo": "bar", "biz": "baz"} + connection_kwargs = { + "foo": "bar", + "biz": "baz", + "maintenance_state": MaintenanceState.NONE, + } pool = self.get_pool( connection_kwargs=connection_kwargs, connection_class=DummyConnection ) + connection = pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs @@ -152,7 +158,9 @@ def test_connection_creation(self, master_host): "host": master_host[0], "port": master_host[1], } + pool = self.get_pool(connection_kwargs=connection_kwargs) + connection_kwargs["maintenance_state"] = MaintenanceState.NONE connection = pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index 6ce74ebc3b..b573a55e5f 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -6,11 +6,18 @@ from time import sleep from redis import Redis -from 
redis.connection import AbstractConnection, ConnectionPool, BlockingConnectionPool +from redis.connection import ( + AbstractConnection, + ConnectionPool, + BlockingConnectionPool, + MaintenanceState, +) from redis.maintenance_events import ( MaintenanceEventsConfig, NodeMigratingEvent, MaintenanceEventPoolHandler, + NodeMovingEvent, + NodeMigratedEvent, ) @@ -326,24 +333,25 @@ def _validate_connected(self, expected_count): assert connected_sockets_count == expected_count def _validate_in_use_connections_state( - self, in_use_connections: List[AbstractConnection] + self, + in_use_connections: List[AbstractConnection], + expected_state=MaintenanceState.NONE, + expected_tmp_host_address=None, + expected_tmp_relax_timeout=-1, + expected_current_socket_timeout=None, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ): """Helper method to validate state of in-use connections.""" # validate in use connections are still working with set flag for reconnect # and timeout is updated for connection in in_use_connections: assert connection._should_reconnect is True - assert ( - connection.tmp_host_address - == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] - ) - assert connection.tmp_relax_timeout == self.config.relax_timeout - assert connection._sock.gettimeout() == self.config.relax_timeout + assert connection.tmp_host_address == expected_tmp_host_address + assert connection.tmp_relax_timeout == expected_tmp_relax_timeout + assert connection._sock.gettimeout() == expected_current_socket_timeout assert connection._sock.connected is True - assert ( - connection._sock.getpeername()[0] - == MockSocket.DEFAULT_ADDRESS.split(":")[0] - ) + assert connection.maintenance_state == expected_state + assert connection._sock.getpeername()[0] == expected_current_peername def _validate_free_connections_state( self, @@ -352,39 +360,30 @@ def _validate_free_connections_state( relax_timeout, should_be_connected_count, connected_to_tmp_addres=False, + 
expected_state=MaintenanceState.MOVING, ): """Helper method to validate state of free/available connections.""" if isinstance(pool, BlockingConnectionPool): - # BlockingConnectionPool uses _connections list where created connections are stored - # but we need to get the ones in the queue - these are the free ones - # the uninitialized connections are filtered out free_connections = [conn for conn in pool.pool.queue if conn is not None] elif isinstance(pool, ConnectionPool): - # Regular ConnectionPool uses _available_connections for free connections free_connections = pool._available_connections else: raise ValueError(f"Unsupported pool type: {type(pool)}") connected_count = 0 - # Validate fields that are validated in the validation of the active connections for connection in free_connections: - # Validate the same fields as in _validate_in_use_connections_state assert connection._should_reconnect is False assert connection.tmp_host_address == tmp_host_address assert connection.tmp_relax_timeout == relax_timeout + assert connection.maintenance_state == expected_state if connection._sock is not None: - connected_count += 1 - + assert connection._sock.connected is True if connected_to_tmp_addres: assert ( connection._sock.getpeername()[0] == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] ) - else: - assert ( - connection._sock.getpeername()[0] - == MockSocket.DEFAULT_ADDRESS.split(":")[0] - ) + connected_count += 1 assert connected_count == should_be_connected_count def _validate_all_timeouts(self, expected_timeout): @@ -804,7 +803,6 @@ def test_moving_related_events_handling_integration(self, pool_class): assert result2 is True, "Command 2 (SET key_receive_moving) failed" # Validate pool and connections settings were updated according to MOVING event - # handling expectations self._validate_conn_kwargs( test_redis_client.connection_pool, MockSocket.DEFAULT_ADDRESS.split(":")[0], @@ -812,25 +810,28 @@ def test_moving_related_events_handling_integration(self, 
pool_class): MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], self.config.relax_timeout, ) - # 5 disconnects has happened, 1 of them is with reconnect self._validate_disconnected(5) - # 5 in use connected + 1 after reconnect self._validate_connected(6) - self._validate_in_use_connections_state(in_use_connections) - # Validate there is 1 free connection that is connected - # the one that has handled the MOVING should reconnect after parsing the response + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[ + 0 + ], # the in use connections reconnect when they complete their current task + ) self._validate_free_connections_state( test_redis_client.connection_pool, MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], self.config.relax_timeout, should_be_connected_count=1, connected_to_tmp_addres=True, + expected_state=MaintenanceState.MOVING, ) - # Wait for MOVING timeout to expire and the moving completed handler to run - print("Waiting for MOVING timeout to expire...") sleep(MockSocket.MOVING_TIMEOUT + 0.5) - self._validate_all_timeouts(None) self._validate_conn_kwargs( test_redis_client.connection_pool, @@ -845,8 +846,8 @@ def test_moving_related_events_handling_integration(self, pool_class): -1, should_be_connected_count=1, connected_to_tmp_addres=True, + expected_state=MaintenanceState.NONE, ) - finally: if hasattr(test_redis_client.connection_pool, "disconnect"): test_redis_client.connection_pool.disconnect() @@ -972,7 +973,6 @@ def test_create_new_conn_after_moving_expires(self, pool_class): assert result is True, "SET key_receive_moving command failed" # Wait for MOVING timeout to expire - print("Waiting for MOVING timeout to expire...") 
sleep(MockSocket.MOVING_TIMEOUT + 0.5) # Now get several new connections after expiration @@ -1137,7 +1137,14 @@ def test_overlapping_moving_events(self, pool_class): self.config.relax_timeout, ) # Validate all connections reflect the first MOVING event - self._validate_in_use_connections_state(in_use_connections) + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) self._validate_free_connections_state( test_redis_client.connection_pool, MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], @@ -1164,7 +1171,14 @@ def test_overlapping_moving_events(self, pool_class): self.config.relax_timeout, ) # Validate all connections reflect the second MOVING event - self._validate_in_use_connections_state(in_use_connections) + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=new_address.split(":")[0], + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) self._validate_free_connections_state( test_redis_client.connection_pool, new_address.split(":")[0], @@ -1228,3 +1242,110 @@ def worker(idx): ) if hasattr(test_redis_client.connection_pool, "disconnect"): test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): + """ + Test moving configs are not lost if the per connection events get picked up after moving is handled. 
+ MOVING → MIGRATING → MIGRATED → MOVED + Checks the state after each event for all connections and for new connections created during each state. + """ + # Setup + test_redis_client = self._get_client( + pool_class, max_connections=5, setup_pool_handler=True + ) + pool = test_redis_client.connection_pool + pool_handler = pool.connection_kwargs["maintenance_events_pool_handler"] + + # Create and release some connections + in_use_connections = [] + for _ in range(3): + in_use_connections.append(pool.get_connection()) + while len(in_use_connections) > 0: + pool.release(in_use_connections.pop()) + + # Take 2 connections to be in use + in_use_connections = [] + for _ in range(2): + conn = pool.get_connection() + in_use_connections.append(conn) + + # 1. MOVING event + tmp_address = "22.23.24.25" + moving_event = NodeMovingEvent( + id=1, new_node_host=tmp_address, new_node_port=6379, ttl=1 + ) + pool_handler.handle_event(moving_event) + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=tmp_address, + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + self._validate_free_connections_state( + pool, + tmp_address, + self.config.relax_timeout, + should_be_connected_count=0, + connected_to_tmp_addres=False, + expected_state=MaintenanceState.MOVING, + ) + + # 2. 
MIGRATING event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeMigratingEvent(id=2, ttl=1) + ) + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=tmp_address, + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + + # 3. MIGRATED event (simulate direct connection handler call) + for conn in in_use_connections: + conn._maintenance_event_connection_handler.handle_event( + NodeMigratedEvent(id=2) + ) + # State should not change for connections that are in MOVING state + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_tmp_host_address=tmp_address, + expected_tmp_relax_timeout=self.config.relax_timeout, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + + # 4. 
MOVED event (simulate timer expiry) + pool_handler.handle_node_moved_event() + self._validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.NONE, + expected_tmp_host_address=None, + expected_tmp_relax_timeout=-1, + expected_current_socket_timeout=None, + expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + ) + self._validate_free_connections_state( + pool, + None, + -1, + should_be_connected_count=0, + connected_to_tmp_addres=False, + expected_state=MaintenanceState.NONE, + ) + # New connection after MOVED + new_conn_none = pool.get_connection() + assert new_conn_none.maintenance_state == MaintenanceState.NONE + pool.release(new_conn_none) + # Cleanup + for conn in in_use_connections: + pool.release(conn) + if hasattr(pool, "disconnect"): + pool.disconnect() From 788cf524398589558e59ac269a4fe345924773ad Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Tue, 22 Jul 2025 19:27:57 +0300 Subject: [PATCH 14/16] Refactored the tmp host address and timeout storing and the way to apply them during connect --- redis/_parsers/base.py | 2 - redis/connection.py | 383 ++++++++++++++-------- redis/maintenance_events.py | 71 ++-- tests/test_maintenance_events.py | 21 +- tests/test_maintenance_events_handling.py | 360 ++++++++++++-------- 5 files changed, 536 insertions(+), 301 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index f2670e43b0..77d0188092 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -130,8 +130,6 @@ def on_connect(self, connection): "Called when the socket connects" self._sock = connection._sock timeout = connection.socket_timeout - if connection.tmp_relax_timeout != -1: - timeout = connection.tmp_relax_timeout self._buffer = SocketBuffer(self._sock, self.socket_read_size, timeout) self.encoder = connection.encoder diff --git a/redis/connection.py b/redis/connection.py index 5646a745af..c20c89dd9d 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1,5 
+1,4 @@ import copy -import logging import os import socket import sys @@ -236,22 +235,62 @@ def re_auth(self): @abstractmethod def mark_for_reconnect(self): + """ + Mark the connection to be reconnected on the next command. + This is useful when a connection is moved to a different node. + """ pass @abstractmethod def should_reconnect(self): + """ + Returns True if the connection should be reconnected. + """ + pass + + @property + @abstractmethod + def maintenance_state(self) -> MaintenanceState: + """ + Returns the current maintenance state of the connection. + """ + pass + + @maintenance_state.setter + @abstractmethod + def maintenance_state(self, state: "MaintenanceState"): + """ + Sets the current maintenance state of the connection. + """ pass @abstractmethod def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): + """ + Update the timeout for the current socket. + """ pass @abstractmethod - def update_tmp_settings( + def set_tmp_settings( self, tmp_host_address: Optional[str] = None, tmp_relax_timeout: Optional[float] = None, ): + """ + Updates temporary host address and timeout settings for the connection. + """ + pass + + @abstractmethod + def reset_tmp_settings( + self, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): + """ + Resets temporary host address and timeout settings for the connection. 
+ """ pass @@ -284,8 +323,9 @@ def __init__( event_dispatcher: Optional[EventDispatcher] = None, maintenance_events_pool_handler: Optional[MaintenanceEventPoolHandler] = None, maintenance_events_config: Optional[MaintenanceEventsConfig] = None, - tmp_host_address: Optional[str] = None, - tmp_relax_timeout: Optional[float] = -1, + orig_host_address: Optional[str] = None, + orig_socket_timeout: Optional[float] = None, + orig_socket_connect_timeout: Optional[float] = None, maintenance_state: "MaintenanceState" = MaintenanceState.NONE, ): """ @@ -374,8 +414,9 @@ def __init__( self._command_packer = self._construct_command_packer(command_packer) self._should_reconnect = False - self.tmp_host_address = tmp_host_address - self.tmp_relax_timeout = tmp_relax_timeout + self.orig_host_address = orig_host_address + self.orig_socket_timeout = orig_socket_timeout + self.orig_socket_connect_timeout = orig_socket_connect_timeout self.maintenance_state = maintenance_state def __repr__(self): @@ -809,6 +850,14 @@ def re_auth(self): self.read_response() self._re_auth_token = None + @property + def maintenance_state(self) -> MaintenanceState: + return self._maintenance_state + + @maintenance_state.setter + def maintenance_state(self, state: "MaintenanceState"): + self._maintenance_state = state + def mark_for_reconnect(self): self._should_reconnect = True @@ -825,21 +874,40 @@ def update_parser_buffer_timeout(self, timeout: Optional[float] = None): if self._parser and self._parser._buffer: self._parser._buffer.socket_timeout = timeout - def update_tmp_settings( + def set_tmp_settings( self, tmp_host_address: Optional[Union[str, object]] = SENTINEL, - tmp_relax_timeout: Optional[Union[float, object]] = SENTINEL, + tmp_relax_timeout: Optional[float] = None, + skip_original_data_update: bool = False, ): """ The value of SENTINEL is used to indicate that the property should not be updated. 
""" if tmp_host_address is not SENTINEL: - self.tmp_host_address = tmp_host_address - if tmp_relax_timeout is not SENTINEL: - self.tmp_relax_timeout = tmp_relax_timeout - - def set_maintenance_state(self, state: "MaintenanceState"): - self.maintenance_state = state + if not skip_original_data_update: + self.orig_host_address = self.host + self.host = tmp_host_address + if tmp_relax_timeout != -1: + if not skip_original_data_update: + self.orig_socket_timeout = self.socket_timeout + self.orig_socket_connect_timeout = self.socket_connect_timeout + + self.socket_timeout = tmp_relax_timeout + self.socket_connect_timeout = tmp_relax_timeout + + def reset_tmp_settings( + self, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): + if reset_host_address: + self.host = self.orig_host_address + self.orig_host_address = None + if reset_relax_timeout: + self.socket_timeout = self.orig_socket_timeout + self.socket_connect_timeout = self.orig_socket_connect_timeout + self.orig_socket_timeout = None + self.orig_socket_connect_timeout = None class Connection(AbstractConnection): @@ -873,10 +941,9 @@ def _connect(self): # ipv4/ipv6, but we want to set options prior to calling # socket.connect() err = None - host = self.tmp_host_address or self.host for res in socket.getaddrinfo( - host, self.port, self.socket_type, socket.SOCK_STREAM + self.host, self.port, self.socket_type, socket.SOCK_STREAM ): family, socktype, proto, canonname, socket_address = res sock = None @@ -892,19 +959,13 @@ def _connect(self): sock.setsockopt(socket.IPPROTO_TCP, k, v) # set the socket_connect_timeout before we connect - if self.tmp_relax_timeout != -1: - sock.settimeout(self.tmp_relax_timeout) - else: - sock.settimeout(self.socket_connect_timeout) + sock.settimeout(self.socket_connect_timeout) # connect sock.connect(socket_address) # set the socket_timeout now that we're connected - if self.tmp_relax_timeout != -1: - sock.settimeout(self.tmp_relax_timeout) - else: - 
sock.settimeout(self.socket_timeout) + sock.settimeout(self.socket_timeout) return sock except OSError as _: @@ -1818,54 +1879,128 @@ def re_auth_callback(self, token: TokenInterface): for conn in self._in_use_connections: conn.set_re_auth_token(token) - def update_connection_kwargs_with_tmp_settings( + def set_maintenance_state_for_all_connections(self, state: "MaintenanceState"): + for conn in self._available_connections: + conn.maintenance_state = state + for conn in self._in_use_connections: + conn.maintenance_state = state + + def set_maintenance_state_in_connection_kwargs(self, state: "MaintenanceState"): + self.connection_kwargs["maintenance_state"] = state + + def add_tmp_config_to_connection_kwargs( self, - tmp_host_address: Optional[str] = None, + tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, + skip_original_data_update: bool = False, ): """ - Update the connection kwargs with the temporary host address and the - relax timeout(if enabled). - This is used when a cluster node is rebind to a different address. + Store original connection configuration and apply temporary settings. + + This method saves the current host, socket_timeout, and socket_connect_timeout values + in temporary storage fields (orig_*), then applies the provided temporary values + as the active connection configuration. + + This is used when a cluster node is rebound to a different address during + maintenance operations. New connections created after this call will use the + temporary configuration until remove_tmp_config_from_connection_kwargs() is called. + + When this method is called the pool will already be locked, so getting the pool + lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for new connections. + This parameter is required and will replace the current host. + :param tmp_relax_timeout: The temporary timeout value to use for both socket_timeout + and socket_connect_timeout. 
If -1 is provided, the timeout + settings are not modified (relax timeout is disabled). + :param skip_original_data_update: Whether to skip updating the original data. + This is used when we are already in MOVING state + and the original data is already stored in the connection kwargs. + """ + if not skip_original_data_update: + # Store original values in temporary storage + original_host = self.connection_kwargs.get("host") + original_socket_timeout = self.connection_kwargs.get("socket_timeout") + original_connect_timeout = self.connection_kwargs.get( + "socket_connect_timeout" + ) - When this method is called the pool will already be locked, so getting the pool lock inside is not needed. - This new address will be used to create new connections until the old node is decomissioned. + self.connection_kwargs.update( + { + "orig_host_address": original_host, + "orig_socket_timeout": original_socket_timeout, + "orig_socket_connect_timeout": original_connect_timeout, + } + ) - :param tmp_host_address: The temporary host address to use for the connection. - :param tmp_relax_timeout: The relax timeout to use for the connection. - If -1 is provided - the relax timeout is disabled, so the tmp property is not set + # Apply temporary values as active configuration + self.connection_kwargs.update({"host": tmp_host_address}) + + if tmp_relax_timeout != -1: + self.connection_kwargs.update( + { + "socket_timeout": tmp_relax_timeout, + "socket_connect_timeout": tmp_relax_timeout, + } + ) + + def remove_tmp_config_from_connection_kwargs(self): """ - self.connection_kwargs.update({"tmp_host_address": tmp_host_address}) - self.connection_kwargs.update({"tmp_relax_timeout": tmp_relax_timeout}) + Remove temporary configuration from connection kwargs and restore original values. 
- def update_connections_tmp_settings( - self, - tmp_host_address: Optional[str] = None, - tmp_relax_timeout: Optional[float] = None, - ): + This method restores the original host address, socket timeout, and connect timeout + from their temporary storage back to the main connection kwargs, then clears the + temporary storage fields. + + This is typically called when a cluster node maintenance operation is complete + and the connection should revert to its original configuration. + + When this method is called the pool will already be locked, so getting the pool + lock inside is not needed. """ - Update the tmp settings for all connections in the pool. - This is used when a cluster node is rebind to a different address. + orig_host = self.connection_kwargs.get("orig_host_address") + orig_socket_timeout = self.connection_kwargs.get("orig_socket_timeout") + orig_connect_timeout = self.connection_kwargs.get("orig_socket_connect_timeout") - When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + self.connection_kwargs.update( + { + "orig_host_address": None, + "orig_socket_timeout": None, + "orig_socket_connect_timeout": None, + "host": orig_host, + "socket_timeout": orig_socket_timeout, + "socket_connect_timeout": orig_connect_timeout, + } + ) + + def reset_connections_tmp_settings(self): + """ + Restore original settings from temporary configuration for all connections in the pool. - :param tmp_host_address: The temporary host address to use for the connection. - :param tmp_relax_timeout: The relax timeout to use for the connection. + This method restores each connection's original host, socket_timeout, and socket_connect_timeout + values from their orig_* attributes back to the active connection configuration, then clears + the temporary storage attributes. 
+ + This is used to restore connections to their original configuration after maintenance operations + that required temporary address/timeout changes are complete. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. """ with self._lock: for conn in self._available_connections: - self._update_connection_tmp_settings( - conn, tmp_host_address, tmp_relax_timeout + conn.reset_tmp_settings( + reset_host_address=True, reset_relax_timeout=True ) for conn in self._in_use_connections: - self._update_connection_tmp_settings( - conn, tmp_host_address, tmp_relax_timeout + conn.reset_tmp_settings( + reset_host_address=True, reset_relax_timeout=True ) def update_active_connections_for_reconnect( self, - tmp_host_address: Optional[str] = None, + tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, + skip_original_data_update: bool = False, ): """ Mark all active connections for reconnect. @@ -1873,17 +2008,18 @@ def update_active_connections_for_reconnect( When this method is called the pool will already be locked, so getting the pool lock inside is not needed. - :param tmp_host_address: The temporary host address to use for the connection. + :param orig_host_address: The temporary host address to use for the connection. """ for conn in self._in_use_connections: self._update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout + conn, tmp_host_address, tmp_relax_timeout, skip_original_data_update ) def disconnect_and_reconfigure_free_connections( self, - tmp_host_address: Optional[str] = None, + tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, + skip_original_data_update: bool = False, ): """ Disconnect all free/available connections. @@ -1891,13 +2027,13 @@ def disconnect_and_reconfigure_free_connections( When this method is called the pool will already be locked, so getting the pool lock inside is not needed. 
- :param tmp_host_address: The temporary host address to use for the connection. - :param tmp_relax_timeout: The relax timeout to use for the connection. + :param orig_host_address: The temporary host address to use for the connection. + :param orig_relax_timeout: The relax timeout to use for the connection. """ for conn in self._available_connections: self._disconnect_and_update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout + conn, tmp_host_address, tmp_relax_timeout, skip_original_data_update ) def update_connections_current_timeout( @@ -1916,48 +2052,40 @@ def update_connections_current_timeout( :param include_available_connections: Whether to include available connections in the update. """ for conn in self._in_use_connections: - self._update_connection_timeout(conn, relax_timeout) + conn.update_current_socket_timeout(relax_timeout) if include_free_connections: for conn in self._available_connections: - self._update_connection_timeout(conn, relax_timeout) + conn.update_current_socket_timeout(relax_timeout) def _update_connection_for_reconnect( self, connection: "Connection", - tmp_host_address: Optional[str] = None, + tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, + skip_original_data_update: bool = False, ): connection.mark_for_reconnect() - self._update_connection_tmp_settings( - connection, tmp_host_address, tmp_relax_timeout + connection.set_tmp_settings( + tmp_host_address=tmp_host_address, + tmp_relax_timeout=tmp_relax_timeout, + skip_original_data_update=skip_original_data_update, ) def _disconnect_and_update_connection_for_reconnect( self, connection: "Connection", - tmp_host_address: Optional[str] = None, + tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, + skip_original_data_update: bool = False, ): connection.disconnect() - self._update_connection_tmp_settings( - connection, tmp_host_address, tmp_relax_timeout + connection.set_tmp_settings( + tmp_host_address=tmp_host_address, + 
tmp_relax_timeout=tmp_relax_timeout, + skip_original_data_update=skip_original_data_update, ) - def _update_connection_tmp_settings( - self, - connection: "Connection", - tmp_host_address: Optional[str] = None, - tmp_relax_timeout: Optional[float] = None, - ): - connection.tmp_host_address = tmp_host_address - connection.tmp_relax_timeout = tmp_relax_timeout - - def _update_connection_timeout( - self, connection: "Connection", relax_timeout: Optional[Number] - ): - connection.update_current_socket_timeout(relax_timeout) - async def _mock(self, error: RedisError): """ Dummy functions, needs to be passed as error callback to retry object. @@ -1966,16 +2094,6 @@ async def _mock(self, error: RedisError): """ pass - def set_maintenance_state_for_all(self, state: "MaintenanceState"): - with self._lock: - for conn in self._available_connections: - conn.set_maintenance_state(state) - for conn in self._in_use_connections: - conn.set_maintenance_state(state) - - def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): - self.connection_kwargs["maintenance_state"] = state - class BlockingConnectionPool(ConnectionPool): """ @@ -2215,67 +2333,54 @@ def disconnect(self): def update_active_connections_for_reconnect( self, - tmp_host_address: Optional[str] = None, + tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, + skip_original_data_update: bool = False, ): with self._lock: connections_in_queue = {conn for conn in self.pool.queue if conn} for conn in self._connections: if conn not in connections_in_queue: self._update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout + conn, + tmp_host_address, + tmp_relax_timeout, + skip_original_data_update, ) def disconnect_and_reconfigure_free_connections( self, - tmp_host_address: Optional[str] = None, + tmp_host_address: str, tmp_relax_timeout: Optional[Number] = None, + skip_original_data_update: bool = False, ): - with self._lock: - existing_connections = self.pool.queue + 
existing_connections = self.pool.queue - for conn in existing_connections: - if conn: - self._disconnect_and_update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout - ) + for conn in existing_connections: + if conn: + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout, skip_original_data_update + ) def update_connections_current_timeout( self, relax_timeout: Optional[float] = None, include_free_connections: bool = False, ): - logging.debug( - f"***** Blocking Pool --> Updating timeouts. relax_timeout: {relax_timeout}" - ) - - with self._lock: - if include_free_connections: - for conn in tuple(self._connections): - self._update_connection_timeout(conn, relax_timeout) - else: - connections_in_queue = {conn for conn in self.pool.queue if conn} - for conn in self._connections: - if conn not in connections_in_queue: - self._update_connection_timeout(conn, relax_timeout) - - def update_connections_tmp_settings( - self, - tmp_host_address: Optional[str] = None, - tmp_relax_timeout: Optional[float] = None, - ): - with self._lock: + if include_free_connections: for conn in tuple(self._connections): - self._update_connection_tmp_settings( - conn, tmp_host_address, tmp_relax_timeout - ) + conn.update_current_socket_timeout(relax_timeout) + else: + connections_in_queue = {conn for conn in self.pool.queue if conn} + for conn in self._connections: + if conn not in connections_in_queue: + conn.update_current_socket_timeout(relax_timeout) def _update_maintenance_events_config_for_connections( self, maintenance_events_config ): - with self._lock: - for conn in tuple(self._connections): - conn.maintenance_events_config = maintenance_events_config + for conn in tuple(self._connections): + conn.maintenance_events_config = maintenance_events_config def _update_maintenance_events_configs_for_connections( self, maintenance_events_pool_handler @@ -2283,23 +2388,25 @@ def _update_maintenance_events_configs_for_connections( 
"""Override base class method to work with BlockingConnectionPool's structure.""" with self._lock: for conn in tuple(self._connections): - if conn: # conn can be None in BlockingConnectionPool - conn.set_maintenance_event_pool_handler( - maintenance_events_pool_handler - ) - conn.maintenance_events_config = ( - maintenance_events_pool_handler.config - ) + conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) + conn.maintenance_events_config = maintenance_events_pool_handler.config + + def reset_connections_tmp_settings(self): + """ + Override base class method to work with BlockingConnectionPool's structure. + + Restore original settings from temporary configuration for all connections in the pool. + """ + for conn in tuple(self._connections): + conn.reset_tmp_settings(reset_host_address=True, reset_relax_timeout=True) def set_in_maintenance(self, in_maintenance: bool): """Set the maintenance mode for the connection pool.""" self._in_maintenance = in_maintenance - def set_maintenance_state_for_all(self, state: "MaintenanceState"): - with self._lock: - for conn in getattr(self, "_connections", []): - if conn: - conn.set_maintenance_state(state) + def set_maintenance_state_for_all_connections(self, state: "MaintenanceState"): + for conn in self._connections: + conn.maintenance_state = state - def set_maintenance_state_in_kwargs(self, state: "MaintenanceState"): + def set_maintenance_state_in_connection_kwargs(self, state: "MaintenanceState"): self.connection_kwargs["maintenance_state"] = state diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index dd62602105..479a4ba090 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -359,15 +359,35 @@ def handle_node_moving_event(self, event: NodeMovingEvent): ): if getattr(self.pool, "set_in_maintenance", False): self.pool.set_in_maintenance(True) - # Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance) - 
self.pool.set_maintenance_state_for_all(MaintenanceState.MOVING) - self.pool.set_maintenance_state_in_kwargs(MaintenanceState.MOVING) + + prev_moving_in_progress = False + if ( + self.pool.connection_kwargs.get("maintenance_state") + == MaintenanceState.MOVING + ): + # The pool is already in MOVING state, update just the new host information + prev_moving_in_progress = True + + if not prev_moving_in_progress: + # Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance) + self.pool.set_maintenance_state_for_all_connections( + MaintenanceState.MOVING + ) + self.pool.set_maintenance_state_in_connection_kwargs( + MaintenanceState.MOVING + ) # edit the config for new connections until the notification expires - self.pool.update_connection_kwargs_with_tmp_settings( + # skip original data update if we are already in MOVING state + # as the original data is already stored in the connection kwargs + self.pool.add_tmp_config_to_connection_kwargs( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, + skip_original_data_update=prev_moving_in_progress, ) - if self.config.is_relax_timeouts_enabled(): + if ( + self.config.is_relax_timeouts_enabled() + and not prev_moving_in_progress + ): # extend the timeout for all connections that are currently in use self.pool.update_connections_current_timeout( self.config.relax_timeout @@ -375,42 +395,53 @@ def handle_node_moving_event(self, event: NodeMovingEvent): if self.config.proactive_reconnect: # take care for the active connections in the pool # mark them for reconnect after they complete the current command + # skip original data update if we are already in MOVING state + # as the original data is already stored in the connection self.pool.update_active_connections_for_reconnect( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, + skip_original_data_update=prev_moving_in_progress, ) # take care for the inactive connections in 
the pool # delete them and create new ones + # skip original data update if we are already in MOVING state + # as the original data is already stored in the connection self.pool.disconnect_and_reconfigure_free_connections( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, + skip_original_data_update=prev_moving_in_progress, ) if getattr(self.pool, "set_in_maintenance", False): self.pool.set_in_maintenance(False) - threading.Timer(event.ttl, self.handle_node_moved_event).start() + threading.Timer( + event.ttl, self.handle_node_moved_event, args=(event,) + ).start() self._processed_events.add(event) - def handle_node_moved_event(self): + def handle_node_moved_event(self, event: NodeMovingEvent): with self._lock: - self.pool.update_connection_kwargs_with_tmp_settings( - tmp_host_address=None, - tmp_relax_timeout=-1, - ) + if self.pool.connection_kwargs.get("host") != event.new_node_host: + # if the current host is not matching the event + # it means there has been a new moving event after this one + # so we don't need to handle this one anymore + # the settings will be reverted by the moved handler of the next event + return + self.pool.remove_tmp_config_from_connection_kwargs() # Clear state to NONE in kwargs immediately after updating tmp kwargs - self.pool.set_maintenance_state_in_kwargs(MaintenanceState.NONE) + self.pool.set_maintenance_state_in_connection_kwargs(MaintenanceState.NONE) with self.pool._lock: + self.pool.reset_connections_tmp_settings() if self.config.is_relax_timeouts_enabled(): # reset the timeout for existing connections self.pool.update_connections_current_timeout( relax_timeout=-1, include_free_connections=True ) - self.pool.update_connections_tmp_settings( - tmp_host_address=None, tmp_relax_timeout=-1 - ) # Clear state to NONE for all connections - self.pool.set_maintenance_state_for_all(MaintenanceState.NONE) + self.pool.set_maintenance_state_for_all_connections( + MaintenanceState.NONE + ) class 
MaintenanceEventConnectionHandler: @@ -434,10 +465,10 @@ def handle_migrating_event(self, notification: NodeMigratingEvent): or not self.config.is_relax_timeouts_enabled() ): return - self.connection.set_maintenance_state(MaintenanceState.MIGRATING) + self.connection.maintenance_state = MaintenanceState.MIGRATING + self.connection.set_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) # extend the timeout for all created connections self.connection.update_current_socket_timeout(self.config.relax_timeout) - self.connection.update_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): # Only reset timeouts if state is not MOVING and relax timeouts are enabled @@ -446,8 +477,8 @@ def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): or not self.config.is_relax_timeouts_enabled() ): return - self.connection.set_maintenance_state(MaintenanceState.NONE) + self.connection.reset_tmp_settings(reset_relax_timeout=True) # Node migration completed - reset the connection # timeouts by providing -1 as the relax timeout self.connection.update_current_socket_timeout(-1) - self.connection.update_tmp_settings(tmp_relax_timeout=-1) + self.connection.maintenance_state = MaintenanceState.NONE diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py index ac7d10b51e..37ef869100 100644 --- a/tests/test_maintenance_events.py +++ b/tests/test_maintenance_events.py @@ -438,7 +438,7 @@ def test_handle_node_moving_event_success(self): # Verify timer was started mock_timer.assert_called_once_with( - event.ttl, self.handler.handle_node_moved_event + event.ttl, self.handler.handle_node_moved_event, args=(event,) ) mock_timer.return_value.start.assert_called_once() @@ -446,17 +446,18 @@ def test_handle_node_moving_event_success(self): assert event in self.handler._processed_events # Verify pool methods were called - 
self.mock_pool.update_connection_kwargs_with_tmp_settings.assert_called_once() + self.mock_pool.add_tmp_config_to_connection_kwargs.assert_called_once() def test_handle_node_moved_event(self): """Test handling of node moved event (cleanup).""" - self.handler.handle_node_moved_event() + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + self.mock_pool.connection_kwargs = {"host": "localhost"} + self.handler.handle_node_moved_event(event) # Verify cleanup methods were called - self.mock_pool.update_connection_kwargs_with_tmp_settings.assert_called_once_with( - tmp_host_address=None, - tmp_relax_timeout=-1, - ) + self.mock_pool.remove_tmp_config_from_connection_kwargs.assert_called_once() class TestMaintenanceEventConnectionHandler: @@ -519,7 +520,7 @@ def test_handle_migrating_event_success(self): self.handler.handle_migrating_event(event) self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) - self.mock_connection.update_tmp_settings.assert_called_once_with( + self.mock_connection.set_tmp_settings.assert_called_once_with( tmp_relax_timeout=20 ) @@ -540,6 +541,6 @@ def test_handle_migration_completed_event_success(self): self.handler.handle_migration_completed_event(event) self.mock_connection.update_current_socket_timeout.assert_called_once_with(-1) - self.mock_connection.update_tmp_settings.assert_called_once_with( - tmp_relax_timeout=-1 + self.mock_connection.reset_tmp_settings.assert_called_once_with( + reset_relax_timeout=True ) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index b573a55e5f..fe0b529fdb 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -280,25 +280,6 @@ def _validate_connection_handlers(self, conn, pool_handler, config): # Validate that the connection's maintenance handler has the same config object assert conn._maintenance_event_connection_handler.config is config - def 
_validate_current_timeout_for_thread( - self, thread_id, expected_timeout, error_msg=None - ): - """Helper method to validate the current timeout for the calling thread.""" - actual_timeout = None - # Get the actual thread ID from the current thread - current_thread_id = threading.current_thread().ident - for sock in self.mock_sockets: - if current_thread_id in sock.thread_timeouts: - actual_timeout = sock.thread_timeouts[current_thread_id] - break - - assert actual_timeout == expected_timeout, ( - error_msg, - f"Thread {thread_id}: Expected timeout ({expected_timeout}), " - f"but found timeout: {actual_timeout} for thread {thread_id}. " - f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}", - ) - def _validate_current_timeout(self, expected_timeout, error_msg=None): """Helper method to validate the current timeout for the calling thread.""" actual_timeout = None @@ -336,8 +317,12 @@ def _validate_in_use_connections_state( self, in_use_connections: List[AbstractConnection], expected_state=MaintenanceState.NONE, - expected_tmp_host_address=None, - expected_tmp_relax_timeout=-1, + expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=None, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=None, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ): @@ -346,21 +331,33 @@ def _validate_in_use_connections_state( # and timeout is updated for connection in in_use_connections: assert connection._should_reconnect is True - assert connection.tmp_host_address == expected_tmp_host_address - assert connection.tmp_relax_timeout == expected_tmp_relax_timeout - assert connection._sock.gettimeout() == expected_current_socket_timeout - assert connection._sock.connected is True + assert connection.host == expected_host_address + assert connection.socket_timeout == 
expected_socket_timeout + assert connection.socket_connect_timeout == expected_socket_connect_timeout + assert connection.orig_host_address == expected_orig_host_address + assert connection.orig_socket_timeout == expected_orig_socket_timeout + assert ( + connection.orig_socket_connect_timeout + == expected_orig_socket_connect_timeout + ) + if connection._sock is not None: + assert connection._sock.gettimeout() == expected_current_socket_timeout + assert connection._sock.connected is True + assert connection._sock.getpeername()[0] == expected_current_peername assert connection.maintenance_state == expected_state - assert connection._sock.getpeername()[0] == expected_current_peername def _validate_free_connections_state( self, pool, - tmp_host_address, - relax_timeout, should_be_connected_count, connected_to_tmp_addres=False, expected_state=MaintenanceState.MOVING, + expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=None, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ): """Helper method to validate state of free/available connections.""" if isinstance(pool, BlockingConnectionPool): @@ -373,8 +370,15 @@ def _validate_free_connections_state( connected_count = 0 for connection in free_connections: assert connection._should_reconnect is False - assert connection.tmp_host_address == tmp_host_address - assert connection.tmp_relax_timeout == relax_timeout + assert connection.host == expected_host_address + assert connection.socket_timeout == expected_socket_timeout + assert connection.socket_connect_timeout == expected_socket_connect_timeout + assert connection.orig_host_address == expected_orig_host_address + assert connection.orig_socket_timeout == expected_orig_socket_timeout + assert ( + connection.orig_socket_connect_timeout + == expected_orig_socket_connect_timeout + ) assert connection.maintenance_state == 
expected_state if connection._sock is not None: assert connection._sock.connected is True @@ -401,14 +405,29 @@ def _validate_conn_kwargs( pool, expected_host_address, expected_port, - expected_tmp_host_address, - expected_tmp_relax_timeout, + expected_socket_timeout, + expected_socket_connect_timeout, + expected_orig_host_address, + expected_orig_socket_timeout, + expected_orig_socket_connect_timeout, ): """Helper method to validate connection kwargs.""" assert pool.connection_kwargs["host"] == expected_host_address assert pool.connection_kwargs["port"] == expected_port - assert pool.connection_kwargs["tmp_host_address"] == expected_tmp_host_address - assert pool.connection_kwargs["tmp_relax_timeout"] == expected_tmp_relax_timeout + assert pool.connection_kwargs["socket_timeout"] == expected_socket_timeout + assert ( + pool.connection_kwargs["socket_connect_timeout"] + == expected_socket_connect_timeout + ) + assert pool.connection_kwargs["orig_host_address"] == expected_orig_host_address + assert ( + pool.connection_kwargs["orig_socket_timeout"] + == expected_orig_socket_timeout + ) + assert ( + pool.connection_kwargs["orig_socket_connect_timeout"] + == expected_orig_socket_connect_timeout + ) def test_client_initialization(self): """Test that Redis client is created with maintenance events configuration.""" @@ -427,8 +446,8 @@ def test_client_initialization(self): conn = test_redis_client.connection_pool.get_connection() assert conn._should_reconnect is False - assert conn.tmp_host_address is None - assert conn.tmp_relax_timeout == -1 + assert conn.orig_host_address is None + assert conn.orig_socket_timeout is None # Test that the node moving handler function is correctly set by # comparing the underlying function and instance @@ -764,7 +783,8 @@ def test_moving_related_events_handling_integration(self, pool_class): This test validates the complete MOVING event lifecycle: 1. Creates multiple connections in the pool 2. 
Executes a Redis command that triggers a MOVING push message - 3. Validates that pool configuration is updated with temporary address and timeout + 3. Validates that pool configuration is updated with temporary + address and timeout - for new connections creation 4. Validates that existing connections are marked for disconnection 5. Tests both ConnectionPool and BlockingConnectionPool implementations """ @@ -804,46 +824,64 @@ def test_moving_related_events_handling_integration(self, pool_class): # Validate pool and connections settings were updated according to MOVING event self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, ) self._validate_disconnected(5) self._validate_connected(6) self._validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - expected_tmp_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - expected_tmp_relax_timeout=self.config.relax_timeout, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, 
expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[ 0 ], # the in use connections reconnect when they complete their current task ) self._validate_free_connections_state( - test_redis_client.connection_pool, - MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, + expected_state=MaintenanceState.MOVING, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, should_be_connected_count=1, connected_to_tmp_addres=True, - expected_state=MaintenanceState.MOVING, ) # Wait for MOVING timeout to expire and the moving completed handler to run sleep(MockSocket.MOVING_TIMEOUT + 0.5) self._validate_all_timeouts(None) self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - None, - -1, + pool=test_redis_client.connection_pool, + expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=None, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) self._validate_free_connections_state( - test_redis_client.connection_pool, - None, - -1, + pool=test_redis_client.connection_pool, + expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=None, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, should_be_connected_count=1, connected_to_tmp_addres=True, expected_state=MaintenanceState.NONE, @@ 
-896,11 +934,14 @@ def test_create_new_conn_while_moving_not_expired(self, pool_class): # Validate pool and connections settings were updated according to MOVING event self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, ) # Now get several more connections to force creation of new ones @@ -915,11 +956,8 @@ def test_create_new_conn_while_moving_not_expired(self, pool_class): # Validate that new connections are created with temporary address and relax timeout # and when connecting those configs are used # get_connection() returns a connection that is already connected - assert ( - new_connection.tmp_host_address - == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] - ) - assert new_connection.tmp_relax_timeout == self.config.relax_timeout + assert new_connection.host == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + assert new_connection.socket_timeout is self.config.relax_timeout # New connections should be connected to the temporary address assert new_connection._sock is not None assert new_connection._sock.connected is True @@ -984,8 +1022,8 @@ def test_create_new_conn_after_moving_expires(self, pool_class): new_connection = test_redis_client.connection_pool.get_connection() # Validate that new connections are created with original address (no temporary settings) - assert new_connection.tmp_host_address is None - assert 
new_connection.tmp_relax_timeout == -1 + assert new_connection.orig_host_address is None + assert new_connection.orig_socket_timeout is None # New connections should be connected to the original address assert new_connection._sock is not None assert new_connection._sock.connected is True @@ -1044,13 +1082,18 @@ def test_receive_migrated_after_moving(self, pool_class): # Validate pool and connections settings were updated according to MOVING event self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, ) + # TODO validate current socket timeout + # Step 2: Run command that will receive and handle MIGRATED event # This should clear the temporary settings key_migrated = "key_receive_migrated_0" @@ -1062,14 +1105,17 @@ def test_receive_migrated_after_moving(self, pool_class): # Step 3: Validate that MIGRATED event was processed but MOVING settings remain # (MIGRATED doesn't automatically clear MOVING settings - they are separate events) + # MOVING settings should still be active + # MOVING timeout should still be active self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - MockSocket.AFTER_MOVING_ADDRESS.split(":")[ - 0 - ], # MOVING settings still active - self.config.relax_timeout, # MOVING timeout still active + pool=test_redis_client.connection_pool, 
+ expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, ) # Step 4: Create new connections after MIGRATED to verify they still use MOVING settings @@ -1081,10 +1127,7 @@ def test_receive_migrated_after_moving(self, pool_class): # Validate that new connections are created with MOVING settings (still active) for connection in new_connections: - assert ( - connection.tmp_host_address - == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] - ) + assert connection.host == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] # Note: New connections may not inherit the exact relax timeout value # but they should have the temporary host address # New connections should be connected @@ -1130,61 +1173,85 @@ def test_overlapping_moving_events(self, pool_class): result1 = test_redis_client.set(key_moving1, value_moving1) assert result1 is True self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) # Validate all connections reflect the first MOVING event self._validate_in_use_connections_state( in_use_connections, 
expected_state=MaintenanceState.MOVING, - expected_tmp_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - expected_tmp_relax_timeout=self.config.relax_timeout, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ) self._validate_free_connections_state( - test_redis_client.connection_pool, - MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, should_be_connected_count=1, connected_to_tmp_addres=True, + expected_state=MaintenanceState.MOVING, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) # Before the first MOVING expires, trigger a second MOVING event (simulate new address) - # Patch MockSocket to use a new address for the second event - new_address = "5.6.7.8:6380" + # Validate the orig properties are not changed! 
+ second_moving_address = "5.6.7.8:6380" orig_after_moving = MockSocket.AFTER_MOVING_ADDRESS - MockSocket.AFTER_MOVING_ADDRESS = new_address + MockSocket.AFTER_MOVING_ADDRESS = second_moving_address try: key_moving2 = "key_receive_moving_1" value_moving2 = "value3_1" result2 = test_redis_client.set(key_moving2, value_moving2) assert result2 is True self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - new_address.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, + expected_host_address=second_moving_address.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) # Validate all connections reflect the second MOVING event self._validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - expected_tmp_host_address=new_address.split(":")[0], - expected_tmp_relax_timeout=self.config.relax_timeout, + expected_host_address=second_moving_address.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ) self._validate_free_connections_state( test_redis_client.connection_pool, - new_address.split(":")[0], - self.config.relax_timeout, should_be_connected_count=1, connected_to_tmp_addres=True, + expected_state=MaintenanceState.MOVING, + 
expected_host_address=second_moving_address.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) finally: MockSocket.AFTER_MOVING_ADDRESS = orig_after_moving @@ -1192,11 +1259,14 @@ def test_overlapping_moving_events(self, pool_class): # Wait for both MOVING timeouts to expire sleep(MockSocket.MOVING_TIMEOUT + 0.5) self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - None, - -1, + pool=test_redis_client.connection_pool, + expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=None, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) finally: if hasattr(test_redis_client.connection_pool, "disconnect"): @@ -1234,12 +1304,16 @@ def worker(idx): assert not errors, f"Errors occurred in threads: {errors}" # After all threads, MOVING event should have been handled safely self._validate_conn_kwargs( - test_redis_client.connection_pool, - MockSocket.DEFAULT_ADDRESS.split(":")[0], - int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - self.config.relax_timeout, + pool=test_redis_client.connection_pool, + expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) 
+ if hasattr(test_redis_client.connection_pool, "disconnect"): test_redis_client.connection_pool.disconnect() @@ -1279,18 +1353,26 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): self._validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - expected_tmp_host_address=tmp_address, - expected_tmp_relax_timeout=self.config.relax_timeout, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ) self._validate_free_connections_state( - pool, - tmp_address, - self.config.relax_timeout, + pool=pool, should_be_connected_count=0, connected_to_tmp_addres=False, expected_state=MaintenanceState.MOVING, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) # 2. 
MIGRATING event (simulate direct connection handler call) @@ -1301,8 +1383,12 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): self._validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - expected_tmp_host_address=tmp_address, - expected_tmp_relax_timeout=self.config.relax_timeout, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ) @@ -1316,29 +1402,41 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): self._validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - expected_tmp_host_address=tmp_address, - expected_tmp_relax_timeout=self.config.relax_timeout, + expected_host_address=tmp_address, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ) # 4. 
MOVED event (simulate timer expiry) - pool_handler.handle_node_moved_event() + pool_handler.handle_node_moved_event(moving_event) self._validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.NONE, - expected_tmp_host_address=None, - expected_tmp_relax_timeout=-1, + expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=None, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=None, expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], ) self._validate_free_connections_state( - pool, - None, - -1, + pool=pool, should_be_connected_count=0, connected_to_tmp_addres=False, expected_state=MaintenanceState.NONE, + expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=None, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) # New connection after MOVED new_conn_none = pool.get_connection() From 6d496f0b3ce2a3bafbed47eea65f93718b34555f Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Thu, 24 Jul 2025 16:38:31 +0300 Subject: [PATCH 15/16] Apply review comments --- redis/_parsers/base.py | 20 +-- redis/connection.py | 186 +++++++++++----------- redis/maintenance_events.py | 7 - tests/test_connection_pool.py | 2 - tests/test_maintenance_events.py | 23 ++- tests/test_maintenance_events_handling.py | 38 +++-- 6 files changed, 134 insertions(+), 142 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 77d0188092..d5e4add661 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -129,8 +129,9 @@ def __del__(self): def on_connect(self, connection): "Called when the socket connects" self._sock = connection._sock - timeout = connection.socket_timeout - self._buffer = 
SocketBuffer(self._sock, self.socket_read_size, timeout) + self._buffer = SocketBuffer( + self._sock, self.socket_read_size, connection.socket_timeout + ) self.encoder = connection.encoder def on_disconnect(self): @@ -201,19 +202,18 @@ def handle_push_response(self, response, **kwargs): if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: return self.invalidation_push_handler_func(response) if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: - if msg_type in _MOVING_MESSAGE: - host, port = response[2].decode().split(":") - ttl = response[1] - id = 1 # Hardcoded value for sync parser - notification = NodeMovingEvent(id, host, port, ttl) - return self.node_moving_push_handler_func(notification) + host, port = response[2].decode().split(":") + ttl = response[1] + id = 1 # Hardcoded value until the notification starts including the id + notification = NodeMovingEvent(id, host, port, ttl) + return self.node_moving_push_handler_func(notification) if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: if msg_type in _MIGRATING_MESSAGE: ttl = response[1] - id = 2 # Hardcoded value for sync parser + id = 2 # Hardcoded value until the notification starts including the id notification = NodeMigratingEvent(id, ttl) elif msg_type in _MIGRATED_MESSAGE: - id = 3 # Hardcoded value for sync parser + id = 3 # Hardcoded value until the notification starts including the id notification = NodeMigratedEvent(id) else: notification = None diff --git a/redis/connection.py b/redis/connection.py index c20c89dd9d..3f1b54e26c 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -233,34 +233,34 @@ def set_re_auth_token(self, token: TokenInterface): def re_auth(self): pass + @property @abstractmethod - def mark_for_reconnect(self): + def maintenance_state(self) -> MaintenanceState: """ - Mark the connection to be reconnected on the next command. - This is useful when a connection is moved to a different node. 
+ Returns the current maintenance state of the connection. """ pass + @maintenance_state.setter @abstractmethod - def should_reconnect(self): + def maintenance_state(self, state: "MaintenanceState"): """ - Returns True if the connection should be reconnected. + Sets the current maintenance state of the connection. """ pass - @property @abstractmethod - def maintenance_state(self) -> MaintenanceState: + def mark_for_reconnect(self): """ - Returns the current maintenance state of the connection. + Mark the connection to be reconnected on the next command. + This is useful when a connection is moved to a different node. """ pass - @maintenance_state.setter @abstractmethod - def maintenance_state(self, state: "MaintenanceState"): + def should_reconnect(self): """ - Sets the current maintenance state of the connection. + Returns True if the connection should be reconnected. """ pass @@ -323,10 +323,10 @@ def __init__( event_dispatcher: Optional[EventDispatcher] = None, maintenance_events_pool_handler: Optional[MaintenanceEventPoolHandler] = None, maintenance_events_config: Optional[MaintenanceEventsConfig] = None, + maintenance_state: "MaintenanceState" = MaintenanceState.NONE, orig_host_address: Optional[str] = None, orig_socket_timeout: Optional[float] = None, orig_socket_connect_timeout: Optional[float] = None, - maintenance_state: "MaintenanceState" = MaintenanceState.NONE, ): """ Initialize a new Connection. 
@@ -412,13 +412,22 @@ def __init__( self._maintenance_event_connection_handler.handle_event ) - self._command_packer = self._construct_command_packer(command_packer) + self.orig_host_address = ( + orig_host_address if orig_host_address else self.host + ) + self.orig_socket_timeout = ( + orig_socket_timeout if orig_socket_timeout else self.socket_timeout + ) + self.orig_socket_connect_timeout = ( + orig_socket_connect_timeout + if orig_socket_connect_timeout + else self.socket_connect_timeout + ) self._should_reconnect = False - self.orig_host_address = orig_host_address - self.orig_socket_timeout = orig_socket_timeout - self.orig_socket_connect_timeout = orig_socket_connect_timeout self.maintenance_state = maintenance_state + self._command_packer = self._construct_command_packer(command_packer) + def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) return f"<{self.__class__.__module__}.{self.__class__.__name__}({repr_args})>" @@ -878,20 +887,13 @@ def set_tmp_settings( self, tmp_host_address: Optional[Union[str, object]] = SENTINEL, tmp_relax_timeout: Optional[float] = None, - skip_original_data_update: bool = False, ): """ The value of SENTINEL is used to indicate that the property should not be updated. 
""" if tmp_host_address is not SENTINEL: - if not skip_original_data_update: - self.orig_host_address = self.host self.host = tmp_host_address if tmp_relax_timeout != -1: - if not skip_original_data_update: - self.orig_socket_timeout = self.socket_timeout - self.orig_socket_connect_timeout = self.socket_connect_timeout - self.socket_timeout = tmp_relax_timeout self.socket_connect_timeout = tmp_relax_timeout @@ -902,12 +904,9 @@ def reset_tmp_settings( ): if reset_host_address: self.host = self.orig_host_address - self.orig_host_address = None if reset_relax_timeout: self.socket_timeout = self.orig_socket_timeout self.socket_connect_timeout = self.orig_socket_connect_timeout - self.orig_socket_timeout = None - self.orig_socket_connect_timeout = None class Connection(AbstractConnection): @@ -1600,6 +1599,24 @@ def __init__( raise RedisError( "Push handlers on connection are only supported with RESP version 3" ) + config = connection_kwargs.get("maintenance_events_config", None) or ( + connection_kwargs.get("maintenance_events_pool_handler").config + if connection_kwargs.get("maintenance_events_pool_handler") + else None + ) + + if config and config.enabled: + connection_kwargs.update( + { + "orig_host_address": connection_kwargs.get("host"), + "orig_socket_timeout": connection_kwargs.get( + "socket_timeout", None + ), + "orig_socket_connect_timeout": connection_kwargs.get( + "socket_connect_timeout", None + ), + } + ) self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) if self._event_dispatcher is None: @@ -1641,7 +1658,7 @@ def maintenance_events_pool_handler_enabled(self): True if the maintenance events pool handler is enabled, False otherwise. 
""" maintenance_events_config = self.connection_kwargs.get( - "maintenance_events_config", False + "maintenance_events_config", None ) return maintenance_events_config and maintenance_events_config.enabled @@ -1663,6 +1680,7 @@ def set_maintenance_events_pool_handler( def _update_maintenance_events_configs_for_connections( self, maintenance_events_pool_handler ): + """Update the maintenance events config for all connections in the pool.""" with self._lock: for conn in self._available_connections: conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) @@ -1791,12 +1809,7 @@ def make_connection(self) -> "ConnectionInterface": raise MaxConnectionsError("Too many connections") self._created_connections += 1 - # Pass current maintenance_state to new connections - maintenance_state = self.connection_kwargs.get( - "maintenance_state", MaintenanceState.NONE - ) kwargs = dict(self.connection_kwargs) - kwargs["maintenance_state"] = maintenance_state if self.cache is not None: return CacheProxyConnection( @@ -1892,7 +1905,6 @@ def add_tmp_config_to_connection_kwargs( self, tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, - skip_original_data_update: bool = False, ): """ Store original connection configuration and apply temporary settings. @@ -1913,26 +1925,7 @@ def add_tmp_config_to_connection_kwargs( :param tmp_relax_timeout: The temporary timeout value to use for both socket_timeout and socket_connect_timeout. If -1 is provided, the timeout settings are not modified (relax timeout is disabled). - :param skip_original_data_update: Whether to skip updating the original data. - This is used when we are already in MOVING state - and the original data is already stored in the connection kwargs. 
""" - if not skip_original_data_update: - # Store original values in temporary storage - original_host = self.connection_kwargs.get("host") - original_socket_timeout = self.connection_kwargs.get("socket_timeout") - original_connect_timeout = self.connection_kwargs.get( - "socket_connect_timeout" - ) - - self.connection_kwargs.update( - { - "orig_host_address": original_host, - "orig_socket_timeout": original_socket_timeout, - "orig_socket_connect_timeout": original_connect_timeout, - } - ) - # Apply temporary values as active configuration self.connection_kwargs.update({"host": tmp_host_address}) @@ -1964,9 +1957,6 @@ def remove_tmp_config_from_connection_kwargs(self): self.connection_kwargs.update( { - "orig_host_address": None, - "orig_socket_timeout": None, - "orig_socket_connect_timeout": None, "host": orig_host, "socket_timeout": orig_socket_timeout, "socket_connect_timeout": orig_connect_timeout, @@ -1997,10 +1987,7 @@ def reset_connections_tmp_settings(self): ) def update_active_connections_for_reconnect( - self, - tmp_host_address: str, - tmp_relax_timeout: Optional[float] = None, - skip_original_data_update: bool = False, + self, tmp_host_address: str, tmp_relax_timeout: Optional[float] = None ): """ Mark all active connections for reconnect. @@ -2008,18 +1995,18 @@ def update_active_connections_for_reconnect( When this method is called the pool will already be locked, so getting the pool lock inside is not needed. - :param orig_host_address: The temporary host address to use for the connection. + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. 
""" for conn in self._in_use_connections: self._update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout, skip_original_data_update + conn, tmp_host_address, tmp_relax_timeout ) def disconnect_and_reconfigure_free_connections( self, tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, - skip_original_data_update: bool = False, ): """ Disconnect all free/available connections. @@ -2033,7 +2020,7 @@ def disconnect_and_reconfigure_free_connections( for conn in self._available_connections: self._disconnect_and_update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout, skip_original_data_update + conn, tmp_host_address, tmp_relax_timeout ) def update_connections_current_timeout( @@ -2063,13 +2050,10 @@ def _update_connection_for_reconnect( connection: "Connection", tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, - skip_original_data_update: bool = False, ): connection.mark_for_reconnect() connection.set_tmp_settings( - tmp_host_address=tmp_host_address, - tmp_relax_timeout=tmp_relax_timeout, - skip_original_data_update=skip_original_data_update, + tmp_host_address=tmp_host_address, tmp_relax_timeout=tmp_relax_timeout ) def _disconnect_and_update_connection_for_reconnect( @@ -2077,13 +2061,10 @@ def _disconnect_and_update_connection_for_reconnect( connection: "Connection", tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, - skip_original_data_update: bool = False, ): connection.disconnect() connection.set_tmp_settings( - tmp_host_address=tmp_host_address, - tmp_relax_timeout=tmp_relax_timeout, - skip_original_data_update=skip_original_data_update, + tmp_host_address=tmp_host_address, tmp_relax_timeout=tmp_relax_timeout ) async def _mock(self, error: RedisError): @@ -2188,20 +2169,15 @@ def make_connection(self): if self._in_maintenance: self._lock.acquire() self._locked = True - # Pass current maintenance_state to new connections - maintenance_state = self.connection_kwargs.get( - 
"maintenance_state", MaintenanceState.NONE - ) - kwargs = dict(self.connection_kwargs) - kwargs["maintenance_state"] = maintenance_state + if self.cache is not None: connection = CacheProxyConnection( - self.connection_class(**kwargs), + self.connection_class(**self.connection_kwargs), self.cache, self._lock, ) else: - connection = self.connection_class(**kwargs) + connection = self.connection_class(**self.connection_kwargs) self._connections.append(connection) return connection finally: @@ -2332,34 +2308,45 @@ def disconnect(self): self._locked = False def update_active_connections_for_reconnect( - self, - tmp_host_address: str, - tmp_relax_timeout: Optional[float] = None, - skip_original_data_update: bool = False, + self, tmp_host_address: str, tmp_relax_timeout: Optional[float] = None ): + """ + Mark all active connections for reconnect. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + """ with self._lock: connections_in_queue = {conn for conn in self.pool.queue if conn} for conn in self._connections: if conn not in connections_in_queue: self._update_connection_for_reconnect( - conn, - tmp_host_address, - tmp_relax_timeout, - skip_original_data_update, + conn, tmp_host_address, tmp_relax_timeout ) def disconnect_and_reconfigure_free_connections( self, tmp_host_address: str, tmp_relax_timeout: Optional[Number] = None, - skip_original_data_update: bool = False, ): + """ + Disconnect all free/available connections. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. 
+ :param tmp_relax_timeout: The relax timeout to use for the connection. + """ existing_connections = self.pool.queue for conn in existing_connections: if conn: self._disconnect_and_update_connection_for_reconnect( - conn, tmp_host_address, tmp_relax_timeout, skip_original_data_update + conn, tmp_host_address, tmp_relax_timeout ) def update_connections_current_timeout( @@ -2367,6 +2354,15 @@ def update_connections_current_timeout( relax_timeout: Optional[float] = None, include_free_connections: bool = False, ): + """ + Update the timeout for the current socket. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param relax_timeout: The relax timeout to use for the connection. + :param include_free_connections: Whether to include available connections in the update. + """ if include_free_connections: for conn in tuple(self._connections): conn.update_current_socket_timeout(relax_timeout) @@ -2385,7 +2381,7 @@ def _update_maintenance_events_config_for_connections( def _update_maintenance_events_configs_for_connections( self, maintenance_events_pool_handler ): - """Override base class method to work with BlockingConnectionPool's structure.""" + """Update the maintenance events config for all connections in the pool.""" with self._lock: for conn in tuple(self._connections): conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) @@ -2401,12 +2397,14 @@ def reset_connections_tmp_settings(self): conn.reset_tmp_settings(reset_host_address=True, reset_relax_timeout=True) def set_in_maintenance(self, in_maintenance: bool): - """Set the maintenance mode for the connection pool.""" + """ + Sets a flag that this Blocking ConnectionPool is in maintenance mode. + + This is used to prevent new connections from being created while we are in maintenance mode. 
+ The pool will be in maintenance mode only when we are processing a MOVING event. + """ self._in_maintenance = in_maintenance def set_maintenance_state_for_all_connections(self, state: "MaintenanceState"): for conn in self._connections: conn.maintenance_state = state - - def set_maintenance_state_in_connection_kwargs(self, state: "MaintenanceState"): - self.connection_kwargs["maintenance_state"] = state diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index 479a4ba090..d4b4e06231 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -382,7 +382,6 @@ def handle_node_moving_event(self, event: NodeMovingEvent): self.pool.add_tmp_config_to_connection_kwargs( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, - skip_original_data_update=prev_moving_in_progress, ) if ( self.config.is_relax_timeouts_enabled() @@ -395,21 +394,15 @@ def handle_node_moving_event(self, event: NodeMovingEvent): if self.config.proactive_reconnect: # take care for the active connections in the pool # mark them for reconnect after they complete the current command - # skip original data update if we are already in MOVING state - # as the original data is already stored in the connection self.pool.update_active_connections_for_reconnect( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, - skip_original_data_update=prev_moving_in_progress, ) # take care for the inactive connections in the pool # delete them and create new ones - # skip original data update if we are already in MOVING state - # as the original data is already stored in the connection self.pool.disconnect_and_reconfigure_free_connections( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, - skip_original_data_update=prev_moving_in_progress, ) if getattr(self.pool, "set_in_maintenance", False): self.pool.set_in_maintenance(False) diff --git a/tests/test_connection_pool.py 
b/tests/test_connection_pool.py index 880b6db27e..282aec567d 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -57,7 +57,6 @@ def test_connection_creation(self): connection_kwargs = { "foo": "bar", "biz": "baz", - "maintenance_state": MaintenanceState.NONE, } pool = self.get_pool( connection_kwargs=connection_kwargs, connection_class=DummyConnection @@ -160,7 +159,6 @@ def test_connection_creation(self, master_host): } pool = self.get_pool(connection_kwargs=connection_kwargs) - connection_kwargs["maintenance_state"] = MaintenanceState.NONE connection = pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py index 37ef869100..3eb648f079 100644 --- a/tests/test_maintenance_events.py +++ b/tests/test_maintenance_events.py @@ -33,25 +33,22 @@ def test_init_through_subclass(self): assert event.creation_time == 1000 assert event.expire_at == 1010 - def test_is_expired_false(self): + @pytest.mark.parametrize( + ("current_time", "expected_expired_state"), + [ + (1005, False), + (1015, True), + ], + ) + def test_is_expired(self, current_time, expected_expired_state): """Test is_expired returns False for non-expired event.""" with patch("time.monotonic", return_value=1000): event = NodeMovingEvent( id=1, new_node_host="localhost", new_node_port=6379, ttl=10 ) - with patch("time.monotonic", return_value=1005): # 5 seconds later - assert not event.is_expired() - - def test_is_expired_true(self): - """Test is_expired returns True for expired event.""" - with patch("time.monotonic", return_value=1000): - event = NodeMovingEvent( - id=1, new_node_host="localhost", new_node_port=6379, ttl=10 - ) - - with patch("time.monotonic", return_value=1015): # 15 seconds later - assert event.is_expired() + with patch("time.monotonic", return_value=current_time): + assert event.is_expired() == expected_expired_state def 
test_is_expired_exact_boundary(self): """Test is_expired at exact expiration boundary.""" diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index fe0b529fdb..dc4e850a50 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -320,7 +320,7 @@ def _validate_in_use_connections_state( expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=None, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=None, @@ -355,7 +355,7 @@ def _validate_free_connections_state( expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=None, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ): @@ -419,13 +419,16 @@ def _validate_conn_kwargs( pool.connection_kwargs["socket_connect_timeout"] == expected_socket_connect_timeout ) - assert pool.connection_kwargs["orig_host_address"] == expected_orig_host_address assert ( - pool.connection_kwargs["orig_socket_timeout"] + pool.connection_kwargs.get("orig_host_address", None) + == expected_orig_host_address + ) + assert ( + pool.connection_kwargs.get("orig_socket_timeout", None) == expected_orig_socket_timeout ) assert ( - pool.connection_kwargs["orig_socket_connect_timeout"] + pool.connection_kwargs.get("orig_socket_connect_timeout", None) == expected_orig_socket_connect_timeout ) @@ -446,7 +449,7 @@ def test_client_initialization(self): conn = test_redis_client.connection_pool.get_connection() assert conn._should_reconnect is False - assert conn.orig_host_address is None + assert conn.orig_host_address == "localhost" 
assert conn.orig_socket_timeout is None # Test that the node moving handler function is correctly set by @@ -825,13 +828,13 @@ def test_moving_related_events_handling_integration(self, pool_class): # Validate pool and connections settings were updated according to MOVING event self._validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), - expected_orig_socket_timeout=None, - expected_orig_socket_connect_timeout=None, expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, ) self._validate_disconnected(5) self._validate_connected(6) @@ -870,7 +873,7 @@ def test_moving_related_events_handling_integration(self, pool_class): expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=None, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) @@ -879,7 +882,7 @@ def test_moving_related_events_handling_integration(self, pool_class): expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=None, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, should_be_connected_count=1, @@ -1022,7 +1025,10 @@ def test_create_new_conn_after_moving_expires(self, pool_class): new_connection = 
test_redis_client.connection_pool.get_connection() # Validate that new connections are created with original address (no temporary settings) - assert new_connection.orig_host_address is None + assert ( + new_connection.orig_host_address + == MockSocket.DEFAULT_ADDRESS.split(":")[0] + ) assert new_connection.orig_socket_timeout is None # New connections should be connected to the original address assert new_connection._sock is not None @@ -1264,7 +1270,7 @@ def test_overlapping_moving_events(self, pool_class): expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=None, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) @@ -1420,7 +1426,7 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=None, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=None, @@ -1434,7 +1440,7 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=None, + expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) From 2d3731f435bd4804959151deba9500a78618c399 Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Sat, 26 Jul 2025 13:07:58 +0300 Subject: [PATCH 16/16] Applying moving/moved only on connections to the same proxy. 
--- redis/_parsers/base.py | 8 +- redis/connection.py | 128 +++- redis/maintenance_events.py | 86 +-- tests/test_connection_pool.py | 1 - tests/test_maintenance_events_handling.py | 739 +++++++++++++++------- 5 files changed, 702 insertions(+), 260 deletions(-) diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index d5e4add661..c3d4c136d2 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -202,6 +202,7 @@ def handle_push_response(self, response, **kwargs): if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: return self.invalidation_push_handler_func(response) if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: + # TODO: PARSE latest format when available host, port = response[2].decode().split(":") ttl = response[1] id = 1 # Hardcoded value until the notification starts including the id @@ -209,10 +210,12 @@ def handle_push_response(self, response, **kwargs): return self.node_moving_push_handler_func(notification) if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: if msg_type in _MIGRATING_MESSAGE: + # TODO: PARSE latest format when available ttl = response[1] id = 2 # Hardcoded value until the notification starts including the id notification = NodeMigratingEvent(id, ttl) elif msg_type in _MIGRATED_MESSAGE: + # TODO: PARSE latest format when available id = 3 # Hardcoded value until the notification starts including the id notification = NodeMigratedEvent(id) else: @@ -260,6 +263,7 @@ async def handle_push_response(self, response, **kwargs): return await self.invalidation_push_handler_func(response) if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: # push notification from enterprise cluster for node moving + # TODO: PARSE latest format when available host, port = response[2].split(":") ttl = response[1] id = 1 # Hardcoded value for async parser @@ -267,10 +271,12 @@ async def handle_push_response(self, response, **kwargs): return await 
self.node_moving_push_handler_func(notification) if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: if msg_type in _MIGRATING_MESSAGE: + # TODO: PARSE latest format when available ttl = response[1] id = 2 # Hardcoded value for async parser notification = NodeMigratingEvent(id, ttl) elif msg_type in _MIGRATED_MESSAGE: + # TODO: PARSE latest format when available id = 3 # Hardcoded value for async parser notification = NodeMigratedEvent(id) return await self.maintenance_push_handler_func(notification) @@ -283,7 +289,7 @@ def set_invalidation_push_handler(self, invalidation_push_handler_func): """Set the invalidation push handler function""" self.invalidation_push_handler_func = invalidation_push_handler_func - def set_node_moving_push_handler_func(self, node_moving_push_handler_func): + def set_node_moving_push_handler(self, node_moving_push_handler_func): self.node_moving_push_handler_func = node_moving_push_handler_func def set_maintenance_push_handler(self, maintenance_push_handler_func): diff --git a/redis/connection.py b/redis/connection.py index 3f1b54e26c..0d8a3983e8 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -8,7 +8,7 @@ from abc import abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union from urllib.parse import parse_qs, unquote, urlparse from redis.cache import ( @@ -249,6 +249,13 @@ def maintenance_state(self, state: "MaintenanceState"): """ pass + @abstractmethod + def getpeername(self): + """ + Returns the peer name of the connection. 
+ """ + pass + @abstractmethod def mark_for_reconnect(self): """ @@ -402,6 +409,7 @@ def __init__( if maintenance_events_config and maintenance_events_config.enabled: if maintenance_events_pool_handler: + maintenance_events_pool_handler.set_connection(self) self._parser.set_node_moving_push_handler( maintenance_events_pool_handler.handle_event ) @@ -484,6 +492,7 @@ def set_parser(self, parser_class): def set_maintenance_event_pool_handler( self, maintenance_event_pool_handler: MaintenanceEventPoolHandler ): + maintenance_event_pool_handler.set_connection(self) self._parser.set_node_moving_push_handler( maintenance_event_pool_handler.handle_event ) @@ -867,6 +876,11 @@ def maintenance_state(self) -> MaintenanceState: def maintenance_state(self, state: "MaintenanceState"): self._maintenance_state = state + def getpeername(self): + if not self._sock: + return None + return self._sock.getpeername()[0] + def mark_for_reconnect(self): self._should_reconnect = True @@ -1892,10 +1906,27 @@ def re_auth_callback(self, token: TokenInterface): for conn in self._in_use_connections: conn.set_re_auth_token(token) - def set_maintenance_state_for_all_connections(self, state: "MaintenanceState"): + def set_maintenance_state_for_connections( + self, + state: "MaintenanceState", + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", + ): for conn in self._available_connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue conn.maintenance_state = state for conn in self._in_use_connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue conn.maintenance_state = state def set_maintenance_state_in_connection_kwargs(self, state: 
"MaintenanceState"): @@ -1963,7 +1994,12 @@ def remove_tmp_config_from_connection_kwargs(self): } ) - def reset_connections_tmp_settings(self): + def reset_connections_tmp_settings( + self, + moving_address: Optional[str] = None, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): """ Restore original settings from temporary configuration for all connections in the pool. @@ -1978,16 +2014,25 @@ def reset_connections_tmp_settings(self): """ with self._lock: for conn in self._available_connections: + if moving_address and conn.host != moving_address: + continue conn.reset_tmp_settings( - reset_host_address=True, reset_relax_timeout=True + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, ) for conn in self._in_use_connections: + if moving_address and conn.host != moving_address: + continue conn.reset_tmp_settings( - reset_host_address=True, reset_relax_timeout=True + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, ) def update_active_connections_for_reconnect( - self, tmp_host_address: str, tmp_relax_timeout: Optional[float] = None + self, + tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + moving_address_src: Optional[str] = None, ): """ Mark all active connections for reconnect. @@ -1999,6 +2044,8 @@ def update_active_connections_for_reconnect( :param tmp_relax_timeout: The relax timeout to use for the connection. """ for conn in self._in_use_connections: + if moving_address_src and conn.getpeername() != moving_address_src: + continue self._update_connection_for_reconnect( conn, tmp_host_address, tmp_relax_timeout ) @@ -2007,6 +2054,7 @@ def disconnect_and_reconfigure_free_connections( self, tmp_host_address: str, tmp_relax_timeout: Optional[float] = None, + moving_address_src: Optional[str] = None, ): """ Disconnect all free/available connections. 
@@ -2019,6 +2067,8 @@ def disconnect_and_reconfigure_free_connections( """ for conn in self._available_connections: + if moving_address_src and conn.getpeername() != moving_address_src: + continue self._disconnect_and_update_connection_for_reconnect( conn, tmp_host_address, tmp_relax_timeout ) @@ -2026,6 +2076,8 @@ def disconnect_and_reconfigure_free_connections( def update_connections_current_timeout( self, relax_timeout: Optional[float], + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", include_free_connections: bool = False, ): """ @@ -2039,10 +2091,22 @@ def update_connections_current_timeout( :param include_available_connections: Whether to include available connections in the update. """ for conn in self._in_use_connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue conn.update_current_socket_timeout(relax_timeout) if include_free_connections: for conn in self._available_connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue conn.update_current_socket_timeout(relax_timeout) def _update_connection_for_reconnect( @@ -2308,7 +2372,10 @@ def disconnect(self): self._locked = False def update_active_connections_for_reconnect( - self, tmp_host_address: str, tmp_relax_timeout: Optional[float] = None + self, + tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + moving_address_src: Optional[str] = None, ): """ Mark all active connections for reconnect. 
@@ -2323,6 +2390,8 @@ def update_active_connections_for_reconnect( connections_in_queue = {conn for conn in self.pool.queue if conn} for conn in self._connections: if conn not in connections_in_queue: + if moving_address_src and conn.getpeername() != moving_address_src: + continue self._update_connection_for_reconnect( conn, tmp_host_address, tmp_relax_timeout ) @@ -2331,6 +2400,7 @@ def disconnect_and_reconfigure_free_connections( self, tmp_host_address: str, tmp_relax_timeout: Optional[Number] = None, + moving_address_src: Optional[str] = None, ): """ Disconnect all free/available connections. @@ -2345,6 +2415,8 @@ def disconnect_and_reconfigure_free_connections( for conn in existing_connections: if conn: + if moving_address_src and conn.getpeername() != moving_address_src: + continue self._disconnect_and_update_connection_for_reconnect( conn, tmp_host_address, tmp_relax_timeout ) @@ -2352,6 +2424,8 @@ def disconnect_and_reconfigure_free_connections( def update_connections_current_timeout( self, relax_timeout: Optional[float] = None, + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", include_free_connections: bool = False, ): """ @@ -2365,11 +2439,23 @@ def update_connections_current_timeout( """ if include_free_connections: for conn in tuple(self._connections): + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue conn.update_current_socket_timeout(relax_timeout) else: connections_in_queue = {conn for conn in self.pool.queue if conn} for conn in self._connections: if conn not in connections_in_queue: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue conn.update_current_socket_timeout(relax_timeout) def 
_update_maintenance_events_config_for_connections( @@ -2387,14 +2473,24 @@ def _update_maintenance_events_configs_for_connections( conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) conn.maintenance_events_config = maintenance_events_pool_handler.config - def reset_connections_tmp_settings(self): + def reset_connections_tmp_settings( + self, + moving_address: Optional[str] = None, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): """ Override base class method to work with BlockingConnectionPool's structure. Restore original settings from temporary configuration for all connections in the pool. """ for conn in tuple(self._connections): - conn.reset_tmp_settings(reset_host_address=True, reset_relax_timeout=True) + if moving_address and conn.host != moving_address: + continue + conn.reset_tmp_settings( + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) def set_in_maintenance(self, in_maintenance: bool): """ @@ -2405,6 +2501,18 @@ def set_in_maintenance(self, in_maintenance: bool): """ self._in_maintenance = in_maintenance - def set_maintenance_state_for_all_connections(self, state: "MaintenanceState"): + def set_maintenance_state_for_connections( + self, + state: "MaintenanceState", + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", + ): for conn in self._connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue + conn.maintenance_state = state diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py index d4b4e06231..3b83da9e02 100644 --- a/redis/maintenance_events.py +++ b/redis/maintenance_events.py @@ -324,6 +324,10 @@ def __init__( self.config = config self._processed_events = set() self._lock = threading.RLock() + self.connection = None + + def 
set_connection(self, connection: "ConnectionInterface"): + self.connection = connection def remove_expired_notifications(self): with self._lock: @@ -357,25 +361,20 @@ def handle_node_moving_event(self, event: NodeMovingEvent): self.config.proactive_reconnect or self.config.is_relax_timeouts_enabled() ): + moving_address_src = ( + self.connection.getpeername() if self.connection else None + ) + if getattr(self.pool, "set_in_maintenance", False): self.pool.set_in_maintenance(True) - prev_moving_in_progress = False - if ( - self.pool.connection_kwargs.get("maintenance_state") - == MaintenanceState.MOVING - ): - # The pool is already in MOVING state, update just the new host information - prev_moving_in_progress = True - - if not prev_moving_in_progress: - # Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance) - self.pool.set_maintenance_state_for_all_connections( - MaintenanceState.MOVING - ) - self.pool.set_maintenance_state_in_connection_kwargs( - MaintenanceState.MOVING - ) + # Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance) + self.pool.set_maintenance_state_for_connections( + MaintenanceState.MOVING, moving_address_src + ) + self.pool.set_maintenance_state_in_connection_kwargs( + MaintenanceState.MOVING + ) # edit the config for new connections until the notification expires # skip original data update if we are already in MOVING state # as the original data is already stored in the connection kwargs @@ -383,13 +382,12 @@ def handle_node_moving_event(self, event: NodeMovingEvent): tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, ) - if ( - self.config.is_relax_timeouts_enabled() - and not prev_moving_in_progress - ): + if self.config.is_relax_timeouts_enabled(): # extend the timeout for all connections that are currently in use self.pool.update_connections_current_timeout( - self.config.relax_timeout + 
relax_timeout=self.config.relax_timeout, + matching_address=moving_address_src, + address_type_to_match="connected", ) if self.config.proactive_reconnect: # take care for the active connections in the pool @@ -397,16 +395,18 @@ def handle_node_moving_event(self, event: NodeMovingEvent): self.pool.update_active_connections_for_reconnect( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, + moving_address_src=moving_address_src, ) # take care for the inactive connections in the pool # delete them and create new ones self.pool.disconnect_and_reconfigure_free_connections( tmp_host_address=event.new_node_host, tmp_relax_timeout=self.config.relax_timeout, + moving_address_src=moving_address_src, ) if getattr(self.pool, "set_in_maintenance", False): self.pool.set_in_maintenance(False) - + print(f"Starting timer for {event} for {event.ttl} seconds") threading.Timer( event.ttl, self.handle_node_moved_event, args=(event,) ).start() @@ -415,25 +415,39 @@ def handle_node_moving_event(self, event: NodeMovingEvent): def handle_node_moved_event(self, event: NodeMovingEvent): with self._lock: - if self.pool.connection_kwargs.get("host") != event.new_node_host: - # if the current host is not matching the event - # it means there has been a new moving event after this one - # so we don't need to handle this one anymore - # the settings will be reverted by the moved handler of the next event - return - self.pool.remove_tmp_config_from_connection_kwargs() - # Clear state to NONE in kwargs immediately after updating tmp kwargs - self.pool.set_maintenance_state_in_connection_kwargs(MaintenanceState.NONE) + # if the current host in kwargs is not matching the event + # it means there has been a new moving event after this one + # and we don't need to revert the kwargs + if self.pool.connection_kwargs.get("host") == event.new_node_host: + self.pool.remove_tmp_config_from_connection_kwargs() + # Clear state to NONE in kwargs immediately after updating tmp 
kwargs + self.pool.set_maintenance_state_in_connection_kwargs( + MaintenanceState.NONE + ) with self.pool._lock: - self.pool.reset_connections_tmp_settings() + moving_address = event.new_node_host if self.config.is_relax_timeouts_enabled(): + self.pool.reset_connections_tmp_settings( + moving_address, reset_relax_timeout=True + ) # reset the timeout for existing connections self.pool.update_connections_current_timeout( - relax_timeout=-1, include_free_connections=True + relax_timeout=-1, + matching_address=moving_address, + address_type_to_match="configured", + include_free_connections=True, ) - # Clear state to NONE for all connections - self.pool.set_maintenance_state_for_all_connections( - MaintenanceState.NONE + + # Clear maintenance state to NONE for all matching connections + self.pool.set_maintenance_state_for_connections( + state=MaintenanceState.NONE, + matching_address=moving_address, + address_type_to_match="configured", + ) + # reset the host address after all other operations that + # compare against tmp host are completed + self.pool.reset_connections_tmp_settings( + moving_address, reset_host_address=True ) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 282aec567d..1eb68d3775 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -9,7 +9,6 @@ import redis from redis.cache import CacheConfig from redis.connection import CacheProxyConnection, Connection, to_bool -from redis.maintenance_events import MaintenanceState from redis.utils import SSL_AVAILABLE from .conftest import ( diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py index dc4e850a50..8ea5488aa8 100644 --- a/tests/test_maintenance_events_handling.py +++ b/tests/test_maintenance_events_handling.py @@ -1,7 +1,8 @@ import socket import threading -from typing import List +from typing import List, Union from unittest.mock import patch + import pytest from time import sleep @@ -21,13 +22,132 
@@ ) +AFTER_MOVING_ADDRESS = "1.2.3.4:6379" +DEFAULT_ADDRESS = "12.45.34.56:6379" +MOVING_TIMEOUT = 1 + + +class Helpers: + """Helper class containing static methods for validation in maintenance events tests.""" + + @staticmethod + def validate_in_use_connections_state( + in_use_connections: List[AbstractConnection], + expected_state=MaintenanceState.NONE, + expected_should_reconnect: Union[bool, str] = True, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=None, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ): + """Helper method to validate state of in-use connections.""" + + # validate in use connections are still working with set flag for reconnect + # and timeout is updated + for connection in in_use_connections: + if expected_should_reconnect != "any": + assert connection._should_reconnect == expected_should_reconnect + assert connection.host == expected_host_address + assert connection.socket_timeout == expected_socket_timeout + assert connection.socket_connect_timeout == expected_socket_connect_timeout + assert connection.orig_host_address == expected_orig_host_address + assert connection.orig_socket_timeout == expected_orig_socket_timeout + assert ( + connection.orig_socket_connect_timeout + == expected_orig_socket_connect_timeout + ) + if connection._sock is not None: + assert connection._sock.gettimeout() == expected_current_socket_timeout + assert connection._sock.connected is True + if expected_current_peername != "any": + assert ( + connection._sock.getpeername()[0] == expected_current_peername + ) + assert connection.maintenance_state == expected_state + + @staticmethod + def validate_free_connections_state( + pool, + should_be_connected_count=0, + connected_to_tmp_addres=False, + 
tmp_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_state=MaintenanceState.MOVING, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ): + """Helper method to validate state of free/available connections.""" + + if isinstance(pool, BlockingConnectionPool): + free_connections = [conn for conn in pool.pool.queue if conn is not None] + elif isinstance(pool, ConnectionPool): + free_connections = pool._available_connections + else: + raise ValueError(f"Unsupported pool type: {type(pool)}") + + connected_count = 0 + for connection in free_connections: + assert connection._should_reconnect is False + assert connection.host == expected_host_address + assert connection.socket_timeout == expected_socket_timeout + assert connection.socket_connect_timeout == expected_socket_connect_timeout + assert connection.orig_host_address == expected_orig_host_address + assert connection.orig_socket_timeout == expected_orig_socket_timeout + assert ( + connection.orig_socket_connect_timeout + == expected_orig_socket_connect_timeout + ) + assert connection.maintenance_state == expected_state + if connection._sock is not None: + assert connection._sock.connected is True + if connected_to_tmp_addres and tmp_address != "any": + assert connection._sock.getpeername()[0] == tmp_address + connected_count += 1 + assert connected_count == should_be_connected_count + + @staticmethod + def validate_conn_kwargs( + pool, + expected_host_address, + expected_port, + expected_socket_timeout, + expected_socket_connect_timeout, + expected_orig_host_address, + expected_orig_socket_timeout, + expected_orig_socket_connect_timeout, + ): + """Helper method to validate connection kwargs.""" + assert pool.connection_kwargs["host"] == expected_host_address + assert 
pool.connection_kwargs["port"] == expected_port + assert pool.connection_kwargs["socket_timeout"] == expected_socket_timeout + assert ( + pool.connection_kwargs["socket_connect_timeout"] + == expected_socket_connect_timeout + ) + assert ( + pool.connection_kwargs.get("orig_host_address", None) + == expected_orig_host_address + ) + assert ( + pool.connection_kwargs.get("orig_socket_timeout", None) + == expected_orig_socket_timeout + ) + assert ( + pool.connection_kwargs.get("orig_socket_connect_timeout", None) + == expected_orig_socket_connect_timeout + ) + + class MockSocket: """Mock socket that simulates Redis protocol responses.""" - AFTER_MOVING_ADDRESS = "1.2.3.4:6379" - DEFAULT_ADDRESS = "12.45.34.56:6379" - MOVING_TIMEOUT = 1 - def __init__(self): self.connected = False self.address = None @@ -73,7 +193,7 @@ def send(self, data): # MOVING push message before SET key_receive_moving_X response # Format: >3\r\n$6\r\nMOVING\r\n:15\r\n+localhost:6379\r\n (3 elements: MOVING, ttl, host:port) # Note: Using + instead of $ to send as simple string instead of bulk string - moving_push = f">3\r\n$6\r\nMOVING\r\n:{MockSocket.MOVING_TIMEOUT}\r\n+{MockSocket.AFTER_MOVING_ADDRESS}\r\n" + moving_push = f">3\r\n$6\r\nMOVING\r\n:{MOVING_TIMEOUT}\r\n+{AFTER_MOVING_ADDRESS}\r\n" response = moving_push.encode() + response self.pending_responses.append(response) @@ -164,7 +284,7 @@ def shutdown(self, how): pass -class TestMaintenanceEventsHandling: +class TestMaintenanceEventsHandlingSingleProxy: """Integration tests for maintenance events handling with real connection pool.""" def setup_method(self): @@ -233,8 +353,8 @@ def _get_client( ) test_pool = pool_class( - host=MockSocket.DEFAULT_ADDRESS.split(":")[0], - port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + host=DEFAULT_ADDRESS.split(":")[0], + port=int(DEFAULT_ADDRESS.split(":")[1]), max_connections=max_connections, protocol=3, # Required for maintenance events maintenance_events_config=config, @@ -313,124 +433,12 @@ def 
_validate_connected(self, expected_count): connected_sockets_count += 1 assert connected_sockets_count == expected_count - def _validate_in_use_connections_state( - self, - in_use_connections: List[AbstractConnection], - expected_state=MaintenanceState.NONE, - expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_socket_timeout=None, - expected_socket_connect_timeout=None, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_orig_socket_timeout=None, - expected_orig_socket_connect_timeout=None, - expected_current_socket_timeout=None, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], - ): - """Helper method to validate state of in-use connections.""" - # validate in use connections are still working with set flag for reconnect - # and timeout is updated - for connection in in_use_connections: - assert connection._should_reconnect is True - assert connection.host == expected_host_address - assert connection.socket_timeout == expected_socket_timeout - assert connection.socket_connect_timeout == expected_socket_connect_timeout - assert connection.orig_host_address == expected_orig_host_address - assert connection.orig_socket_timeout == expected_orig_socket_timeout - assert ( - connection.orig_socket_connect_timeout - == expected_orig_socket_connect_timeout - ) - if connection._sock is not None: - assert connection._sock.gettimeout() == expected_current_socket_timeout - assert connection._sock.connected is True - assert connection._sock.getpeername()[0] == expected_current_peername - assert connection.maintenance_state == expected_state - - def _validate_free_connections_state( - self, - pool, - should_be_connected_count, - connected_to_tmp_addres=False, - expected_state=MaintenanceState.MOVING, - expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_socket_timeout=None, - expected_socket_connect_timeout=None, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], 
- expected_orig_socket_timeout=None, - expected_orig_socket_connect_timeout=None, - ): - """Helper method to validate state of free/available connections.""" - if isinstance(pool, BlockingConnectionPool): - free_connections = [conn for conn in pool.pool.queue if conn is not None] - elif isinstance(pool, ConnectionPool): - free_connections = pool._available_connections - else: - raise ValueError(f"Unsupported pool type: {type(pool)}") - - connected_count = 0 - for connection in free_connections: - assert connection._should_reconnect is False - assert connection.host == expected_host_address - assert connection.socket_timeout == expected_socket_timeout - assert connection.socket_connect_timeout == expected_socket_connect_timeout - assert connection.orig_host_address == expected_orig_host_address - assert connection.orig_socket_timeout == expected_orig_socket_timeout - assert ( - connection.orig_socket_connect_timeout - == expected_orig_socket_connect_timeout - ) - assert connection.maintenance_state == expected_state - if connection._sock is not None: - assert connection._sock.connected is True - if connected_to_tmp_addres: - assert ( - connection._sock.getpeername()[0] - == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] - ) - connected_count += 1 - assert connected_count == should_be_connected_count - def _validate_all_timeouts(self, expected_timeout): """Helper method to validate state of in-use connections.""" # validate in use connections are still working with set flag for reconnect # and timeout is updated for mock_socket in self.mock_sockets: - if expected_timeout is None: - assert mock_socket.gettimeout() is None - else: - assert mock_socket.gettimeout() == expected_timeout - - def _validate_conn_kwargs( - self, - pool, - expected_host_address, - expected_port, - expected_socket_timeout, - expected_socket_connect_timeout, - expected_orig_host_address, - expected_orig_socket_timeout, - expected_orig_socket_connect_timeout, - ): - """Helper method to validate 
connection kwargs.""" - assert pool.connection_kwargs["host"] == expected_host_address - assert pool.connection_kwargs["port"] == expected_port - assert pool.connection_kwargs["socket_timeout"] == expected_socket_timeout - assert ( - pool.connection_kwargs["socket_connect_timeout"] - == expected_socket_connect_timeout - ) - assert ( - pool.connection_kwargs.get("orig_host_address", None) - == expected_orig_host_address - ) - assert ( - pool.connection_kwargs.get("orig_socket_timeout", None) - == expected_orig_socket_timeout - ) - assert ( - pool.connection_kwargs.get("orig_socket_connect_timeout", None) - == expected_orig_socket_connect_timeout - ) + assert mock_socket.gettimeout() == expected_timeout def test_client_initialization(self): """Test that Redis client is created with maintenance events configuration.""" @@ -826,63 +834,75 @@ def test_moving_related_events_handling_integration(self, pool_class): assert result2 is True, "Command 2 (SET key_receive_moving) failed" # Validate pool and connections settings were updated according to MOVING event - self._validate_conn_kwargs( + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) self._validate_disconnected(5) self._validate_connected(6) - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - 
expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[ + expected_current_peername=DEFAULT_ADDRESS.split(":")[ 0 ], # the in use connections reconnect when they complete their current task ) - self._validate_free_connections_state( + Helpers.validate_free_connections_state( pool=test_redis_client.connection_pool, expected_state=MaintenanceState.MOVING, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, should_be_connected_count=1, connected_to_tmp_addres=True, ) # Wait for MOVING timeout to expire and the moving completed handler to run - sleep(MockSocket.MOVING_TIMEOUT + 0.5) - self._validate_all_timeouts(None) - self._validate_conn_kwargs( + sleep(MOVING_TIMEOUT + 0.5) + + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.NONE, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + 
expected_current_socket_timeout=None, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ) + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) - self._validate_free_connections_state( + Helpers.validate_free_connections_state( pool=test_redis_client.connection_pool, - expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_host_address=DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, should_be_connected_count=1, @@ -936,13 +956,13 @@ def test_create_new_conn_while_moving_not_expired(self, pool_class): assert result is True, "SET key_receive_moving command failed" # Validate pool and connections settings were updated according to MOVING event - self._validate_conn_kwargs( + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + 
expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, ) @@ -959,14 +979,14 @@ def test_create_new_conn_while_moving_not_expired(self, pool_class): # Validate that new connections are created with temporary address and relax timeout # and when connecting those configs are used # get_connection() returns a connection that is already connected - assert new_connection.host == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + assert new_connection.host == AFTER_MOVING_ADDRESS.split(":")[0] assert new_connection.socket_timeout is self.config.relax_timeout # New connections should be connected to the temporary address assert new_connection._sock is not None assert new_connection._sock.connected is True assert ( new_connection._sock.getpeername()[0] - == MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + == AFTER_MOVING_ADDRESS.split(":")[0] ) assert new_connection._sock.gettimeout() == self.config.relax_timeout @@ -1014,7 +1034,7 @@ def test_create_new_conn_after_moving_expires(self, pool_class): assert result is True, "SET key_receive_moving command failed" # Wait for MOVING timeout to expire - sleep(MockSocket.MOVING_TIMEOUT + 0.5) + sleep(MOVING_TIMEOUT + 0.5) # Now get several new connections after expiration old_connections = [] @@ -1025,10 +1045,7 @@ def test_create_new_conn_after_moving_expires(self, pool_class): new_connection = test_redis_client.connection_pool.get_connection() # Validate that new connections are created with original address (no temporary settings) - assert ( - new_connection.orig_host_address - == MockSocket.DEFAULT_ADDRESS.split(":")[0] - ) + assert new_connection.orig_host_address == DEFAULT_ADDRESS.split(":")[0] assert new_connection.orig_socket_timeout is None # New connections should be connected to the original address assert new_connection._sock is not None @@ -1087,13 +1104,13 @@ def test_receive_migrated_after_moving(self, 
pool_class): assert result_moving is True, "SET key_receive_moving command failed" # Validate pool and connections settings were updated according to MOVING event - self._validate_conn_kwargs( + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, ) @@ -1113,13 +1130,13 @@ def test_receive_migrated_after_moving(self, pool_class): # (MIGRATED doesn't automatically clear MOVING settings - they are separate events) # MOVING settings should still be active # MOVING timeout should still be active - self._validate_conn_kwargs( + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, ) @@ -1133,7 +1150,7 @@ def test_receive_migrated_after_moving(self, pool_class): # Validate that new connections are created with MOVING settings (still active) for connection in new_connections: - assert connection.host == 
MockSocket.AFTER_MOVING_ADDRESS.split(":")[0] + assert connection.host == AFTER_MOVING_ADDRESS.split(":")[0] # Note: New connections may not inherit the exact relax timeout value # but they should have the temporary host address # New connections should be connected @@ -1158,13 +1175,19 @@ def test_overlapping_moving_events(self, pool_class): Test handling of overlapping/duplicate MOVING events (e.g., two MOVING events before the first expires). Ensures that the second MOVING event updates the pool and connections as expected, and that expiry/cleanup works. """ + global AFTER_MOVING_ADDRESS test_redis_client = self._get_client( pool_class, max_connections=5, setup_pool_handler=True ) try: # Create and release some connections + in_use_connections = [] for _ in range(3): - conn = test_redis_client.connection_pool.get_connection() + in_use_connections.append( + test_redis_client.connection_pool.get_connection() + ) + + for conn in in_use_connections: test_redis_client.connection_pool.release(conn) # Take 2 connections to be in use @@ -1178,99 +1201,106 @@ def test_overlapping_moving_events(self, pool_class): value_moving1 = "value3_0" result1 = test_redis_client.set(key_moving1, value_moving1) assert result1 is True - self._validate_conn_kwargs( + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) # Validate all connections reflect the first MOVING event - 
self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) - self._validate_free_connections_state( + Helpers.validate_free_connections_state( pool=test_redis_client.connection_pool, should_be_connected_count=1, connected_to_tmp_addres=True, expected_state=MaintenanceState.MOVING, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) + # Reconnect in use connections + for conn in in_use_connections: + conn.disconnect() + conn.connect() # Before the first MOVING expires, trigger a second MOVING event (simulate new address) # Validate the orig properties are not changed! 
second_moving_address = "5.6.7.8:6380" - orig_after_moving = MockSocket.AFTER_MOVING_ADDRESS - MockSocket.AFTER_MOVING_ADDRESS = second_moving_address + orig_after_moving = AFTER_MOVING_ADDRESS + # Temporarily modify the global constant for this test + AFTER_MOVING_ADDRESS = second_moving_address try: key_moving2 = "key_receive_moving_1" value_moving2 = "value3_1" result2 = test_redis_client.set(key_moving2, value_moving2) assert result2 is True - self._validate_conn_kwargs( + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, expected_host_address=second_moving_address.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) # Validate all connections reflect the second MOVING event - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, expected_host_address=second_moving_address.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_current_peername=orig_after_moving.split(":")[0], ) - self._validate_free_connections_state( + # print(test_redis_client.connection_pool._available_connections) + Helpers.validate_free_connections_state( 
test_redis_client.connection_pool, should_be_connected_count=1, connected_to_tmp_addres=True, + tmp_address=second_moving_address.split(":")[0], expected_state=MaintenanceState.MOVING, expected_host_address=second_moving_address.split(":")[0], expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) finally: - MockSocket.AFTER_MOVING_ADDRESS = orig_after_moving + AFTER_MOVING_ADDRESS = orig_after_moving # Wait for both MOVING timeouts to expire - sleep(MockSocket.MOVING_TIMEOUT + 0.5) - self._validate_conn_kwargs( + sleep(MOVING_TIMEOUT + 0.5) + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) @@ -1309,13 +1339,13 @@ def worker(idx): assert all(results), f"Not all threads succeeded: {results}" assert not errors, f"Errors occurred in threads: {errors}" # After all threads, MOVING event should have been handled safely - self._validate_conn_kwargs( + Helpers.validate_conn_kwargs( pool=test_redis_client.connection_pool, - expected_host_address=MockSocket.AFTER_MOVING_ADDRESS.split(":")[0], - expected_port=int(MockSocket.DEFAULT_ADDRESS.split(":")[1]), + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), 
expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) @@ -1356,19 +1386,19 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): id=1, new_node_host=tmp_address, new_node_port=6379, ttl=1 ) pool_handler.handle_event(moving_event) - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, expected_host_address=tmp_address, expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) - self._validate_free_connections_state( + Helpers.validate_free_connections_state( pool=pool, should_be_connected_count=0, connected_to_tmp_addres=False, @@ -1376,7 +1406,7 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): expected_host_address=tmp_address, expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) @@ -1386,17 +1416,17 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): conn._maintenance_event_connection_handler.handle_event( 
NodeMigratingEvent(id=2, ttl=1) ) - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, expected_host_address=tmp_address, expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) # 3. MIGRATED event (simulate direct connection handler call) @@ -1405,42 +1435,42 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): NodeMigratedEvent(id=2) ) # State should not change for connections that are in MOVING state - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.MOVING, expected_host_address=tmp_address, expected_socket_timeout=self.config.relax_timeout, expected_socket_connect_timeout=self.config.relax_timeout, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=self.config.relax_timeout, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) # 4. 
MOVED event (simulate timer expiry) pool_handler.handle_node_moved_event(moving_event) - self._validate_in_use_connections_state( + Helpers.validate_in_use_connections_state( in_use_connections, expected_state=MaintenanceState.NONE, - expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_host_address=DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, expected_current_socket_timeout=None, - expected_current_peername=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], ) - self._validate_free_connections_state( + Helpers.validate_free_connections_state( pool=pool, should_be_connected_count=0, connected_to_tmp_addres=False, expected_state=MaintenanceState.NONE, - expected_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_host_address=DEFAULT_ADDRESS.split(":")[0], expected_socket_timeout=None, expected_socket_connect_timeout=None, - expected_orig_host_address=MockSocket.DEFAULT_ADDRESS.split(":")[0], + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], expected_orig_socket_timeout=None, expected_orig_socket_connect_timeout=None, ) @@ -1453,3 +1483,288 @@ def test_moving_migrating_migrated_moved_state_transitions(self, pool_class): pool.release(conn) if hasattr(pool, "disconnect"): pool.disconnect() + + +class TestMaintenanceEventsHandlingMultipleProxies: + """Integration tests for maintenance events handling with real connection pool.""" + + def setup_method(self): + """Set up test fixtures with mocked sockets.""" + self.mock_sockets = [] + self.original_socket = socket.socket + self.orig_host = "test.address.com" + + # Mock socket creation to return our mock sockets + def mock_socket_factory(*args, **kwargs): + mock_sock = 
MockSocket() + self.mock_sockets.append(mock_sock) + return mock_sock + + self.socket_patcher = patch("socket.socket", side_effect=mock_socket_factory) + self.socket_patcher.start() + + # Mock select.select to simulate data availability for reading + def mock_select(rlist, wlist, xlist, timeout=0): + # Check if any of the sockets in rlist have data available + ready_sockets = [] + for sock in rlist: + if hasattr(sock, "connected") and sock.connected and not sock.closed: + # Only return socket as ready if it actually has data to read + if hasattr(sock, "pending_responses") and sock.pending_responses: + ready_sockets.append(sock) + # Don't return socket as ready just because it received commands + # Only when there are actual responses available + return (ready_sockets, [], []) + + self.select_patcher = patch("select.select", side_effect=mock_select) + self.select_patcher.start() + + ips = ["1.2.3.4", "5.6.7.8", "9.10.11.12"] + ips = ips * 3 + + # Mock socket creation to return our mock sockets + def mock_socket_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0): + if host == self.orig_host: + ip_address = ips.pop(0) + else: + ip_address = host + + # Return the standard getaddrinfo format + # (family, type, proto, canonname, sockaddr) + return [ + ( + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + "", + (ip_address, port), + ) + ] + + self.getaddrinfo_patcher = patch( + "socket.getaddrinfo", side_effect=mock_socket_getaddrinfo + ) + self.getaddrinfo_patcher.start() + + # Create maintenance events config + self.config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, relax_timeout=30 + ) + + def teardown_method(self): + """Clean up test fixtures.""" + self.socket_patcher.stop() + self.select_patcher.stop() + self.getaddrinfo_patcher.stop() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_migrating_after_moving_multiple_proxies(self, pool_class): + """ """ + # Setup + + pool = 
pool_class( + host=self.orig_host, + port=12345, + max_connections=10, + protocol=3, # Required for maintenance events + maintenance_events_config=self.config, + ) + pool.set_maintenance_events_pool_handler( + MaintenanceEventPoolHandler(pool, self.config) + ) + pool_handler = pool.connection_kwargs["maintenance_events_pool_handler"] + + # Create and release some connections + key1 = "1.2.3.4" + key2 = "5.6.7.8" + key3 = "9.10.11.12" + in_use_connections = {key1: [], key2: [], key3: []} + # Create 7 connections + for _ in range(7): + conn = pool.get_connection() + in_use_connections[conn.getpeername()].append(conn) + + for _, conns in in_use_connections.items(): + while len(conns) > 1: + pool.release(conns.pop()) + + # Send MOVING event to con with ip = key1 + conn = in_use_connections[key1][0] + pool_handler.set_connection(conn) + new_ip = "13.14.15.16" + pool_handler.handle_event( + NodeMovingEvent(id=1, new_node_host=new_ip, new_node_port=6379, ttl=1) + ) + + # validate in use connection and ip1 + Helpers.validate_in_use_connections_state( + in_use_connections[key1], + expected_state=MaintenanceState.MOVING, + expected_host_address=new_ip, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=self.orig_host, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=key1, + ) + # validate free connections for ip1 + changed_free_connections = 0 + if isinstance(pool, BlockingConnectionPool): + free_connections = [conn for conn in pool.pool.queue if conn is not None] + elif isinstance(pool, ConnectionPool): + free_connections = pool._available_connections + for conn in free_connections: + if conn.host == new_ip: + changed_free_connections += 1 + assert conn.maintenance_state == MaintenanceState.MOVING + assert conn.host == new_ip + assert conn.socket_timeout == 
self.config.relax_timeout + assert conn.socket_connect_timeout == self.config.relax_timeout + assert conn.orig_host_address == self.orig_host + assert conn.orig_socket_timeout is None + assert conn.orig_socket_connect_timeout is None + else: + assert conn.maintenance_state == MaintenanceState.NONE + assert conn.host == self.orig_host + assert conn.socket_timeout is None + assert conn.socket_connect_timeout is None + assert conn.orig_host_address == self.orig_host + assert conn.orig_socket_timeout is None + assert conn.orig_socket_connect_timeout is None + assert changed_free_connections == 2 + assert len(free_connections) == 4 + + # Send second MOVING event to con with ip = key2 + conn = in_use_connections[key2][0] + pool_handler.set_connection(conn) + new_ip_2 = "17.18.19.20" + pool_handler.handle_event( + NodeMovingEvent(id=2, new_node_host=new_ip_2, new_node_port=6379, ttl=2) + ) + + # validate in use connection and ip2 + Helpers.validate_in_use_connections_state( + in_use_connections[key2], + expected_state=MaintenanceState.MOVING, + expected_host_address=new_ip_2, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=self.orig_host, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=key2, + ) + # validate free connections for ip2 + changed_free_connections = 0 + if isinstance(pool, BlockingConnectionPool): + free_connections = [conn for conn in pool.pool.queue if conn is not None] + elif isinstance(pool, ConnectionPool): + free_connections = pool._available_connections + for conn in free_connections: + if conn.host == new_ip_2: + changed_free_connections += 1 + assert conn.maintenance_state == MaintenanceState.MOVING + assert conn.host == new_ip_2 + assert conn.socket_timeout == self.config.relax_timeout + assert conn.socket_connect_timeout == 
self.config.relax_timeout + assert conn.orig_host_address == self.orig_host + assert conn.orig_socket_timeout is None + assert conn.orig_socket_connect_timeout is None + # here I can't validate the other connections since some of + # them are in MOVING state from the first event + # and some are in NONE state + assert changed_free_connections == 1 + + # MIGRATING event on connection that has already been marked as MOVING + conn = in_use_connections[key2][0] + conn_event_handler = conn._maintenance_event_connection_handler + conn_event_handler.handle_event(NodeMigratingEvent(id=3, ttl=1)) + # validate connection does not lose its MOVING state + assert conn.maintenance_state == MaintenanceState.MOVING + # MIGRATED event + conn_event_handler.handle_event(NodeMigratedEvent(id=3)) + # validate connection does not lose its MOVING state and relax timeout + assert conn.maintenance_state == MaintenanceState.MOVING + assert conn.socket_timeout == self.config.relax_timeout + + # Send Migrating event to con with ip = key3 + conn = in_use_connections[key3][0] + conn_event_handler = conn._maintenance_event_connection_handler + conn_event_handler.handle_event(NodeMigratingEvent(id=3, ttl=1)) + # validate connection is in MIGRATING state + assert conn.maintenance_state == MaintenanceState.MIGRATING + assert conn.socket_timeout == self.config.relax_timeout + + # Send MIGRATED event to con with ip = key3 + conn_event_handler.handle_event(NodeMigratedEvent(id=3)) + # validate connection is in MOVING state + assert conn.maintenance_state == MaintenanceState.NONE + assert conn.socket_timeout is None + + # sleep to expire only the first MOVING events + sleep(1.3) + # validate only the connections affected by the first MOVING event + # have lost their MOVING state + Helpers.validate_in_use_connections_state( + in_use_connections[key1], + expected_state=MaintenanceState.NONE, + expected_host_address=self.orig_host, + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + 
expected_orig_host_address=self.orig_host, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=None, + expected_current_peername=key1, + ) + Helpers.validate_in_use_connections_state( + in_use_connections[key2], + expected_state=MaintenanceState.MOVING, + expected_host_address=new_ip_2, + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=self.orig_host, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=key2, + ) + Helpers.validate_in_use_connections_state( + in_use_connections[key3], + expected_state=MaintenanceState.NONE, + expected_should_reconnect=False, + expected_host_address=self.orig_host, + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=self.orig_host, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=None, + expected_current_peername=key3, + ) + # TODO validate free connections + + # sleep to expire the second MOVING events + sleep(1) + # validate all connections have lost their MOVING state + Helpers.validate_in_use_connections_state( + [ + *in_use_connections[key1], + *in_use_connections[key2], + *in_use_connections[key3], + ], + expected_state=MaintenanceState.NONE, + expected_should_reconnect="any", + expected_host_address=self.orig_host, + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=self.orig_host, + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=None, + expected_current_peername="any", + ) + # TODO validate free connections