diff --git a/hitless.go b/hitless.go new file mode 100644 index 000000000..3c81512d4 --- /dev/null +++ b/hitless.go @@ -0,0 +1,416 @@ +package redis + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/hitless" + "github.com/redis/go-redis/v9/internal/pool" +) + +// HitlessUpgradeConfig provides configuration for hitless upgrades +type HitlessUpgradeConfig struct { + // Enabled controls whether hitless upgrades are active + Enabled bool + + // TransitionTimeout is the increased timeout for connections during transitions + // (MIGRATING/FAILING_OVER). This should be longer than normal operation timeouts + // to account for the time needed to complete the transition. + // Default: 60 seconds + TransitionTimeout time.Duration + + // CleanupInterval controls how often expired states are cleaned up + // Default: 30 seconds + CleanupInterval time.Duration +} + +// DefaultHitlessUpgradeConfig returns the default configuration for hitless upgrades +func DefaultHitlessUpgradeConfig() *HitlessUpgradeConfig { + return &HitlessUpgradeConfig{ + Enabled: true, + TransitionTimeout: 60 * time.Second, // Longer timeout for transitioning connections + CleanupInterval: 30 * time.Second, // How often to clean up expired states + } +} + +// HitlessUpgradeStatistics provides statistics about ongoing upgrade operations +type HitlessUpgradeStatistics struct { + ActiveConnections int // Total connections in transition + IsMoving bool // Whether pool is currently moving + MigratingConnections int // Connections in MIGRATING state + FailingOverConnections int // Connections in FAILING_OVER state + Timestamp time.Time // When these statistics were collected +} + +// HitlessUpgradeStatus provides detailed status of all ongoing upgrades +type HitlessUpgradeStatus struct { + ConnectionStates map[interface{}]interface{} + IsMoving bool + NewEndpoint string + Timestamp time.Time +} + +// HitlessIntegration provides the interface for hitless upgrade functionality +type HitlessIntegration interface { + // IsEnabled returns whether hitless upgrades are currently enabled + IsEnabled() bool + + // EnableHitlessUpgrades enables hitless upgrade functionality + EnableHitlessUpgrades() + + // DisableHitlessUpgrades disables hitless upgrade functionality + DisableHitlessUpgrades() + + // GetConnectionTimeout returns the appropriate timeout for a connection + // If the connection is transitioning, returns the longer TransitionTimeout + GetConnectionTimeout(conn interface{}, defaultTimeout time.Duration) time.Duration + + // GetConnectionTimeouts returns both read and write timeouts for a connection + // If the connection is transitioning, returns increased timeouts + GetConnectionTimeouts(conn interface{}, defaultReadTimeout, defaultWriteTimeout time.Duration) (time.Duration, time.Duration) + + // MarkConnectionAsBlocking marks a connection as having blocking commands + MarkConnectionAsBlocking(conn interface{}, isBlocking bool) + + // IsConnectionMarkedForClosing checks if a connection should be closed + IsConnectionMarkedForClosing(conn interface{}) bool + + // ShouldRedirectBlockingConnection checks if a blocking connection should be redirected + ShouldRedirectBlockingConnection(conn interface{}) (bool, string) + + // GetUpgradeStatistics returns current upgrade statistics + GetUpgradeStatistics() *HitlessUpgradeStatistics + + // GetUpgradeStatus returns detailed upgrade status + GetUpgradeStatus() *HitlessUpgradeStatus + + // UpdateConfig updates 
the hitless upgrade configuration + UpdateConfig(config *HitlessUpgradeConfig) error + + // GetConfig returns the current configuration + GetConfig() *HitlessUpgradeConfig + + // Close shuts down the hitless integration + Close() error +} + +// hitlessIntegrationImpl implements the HitlessIntegration interface +type hitlessIntegrationImpl struct { + integration *hitless.RedisClientIntegration + mu sync.RWMutex +} + +// newHitlessIntegration creates a new hitless integration instance +func newHitlessIntegration(config *HitlessUpgradeConfig) *hitlessIntegrationImpl { + if config == nil { + config = DefaultHitlessUpgradeConfig() + } + + // Convert to internal config format + internalConfig := &hitless.HitlessUpgradeConfig{ + Enabled: config.Enabled, + TransitionTimeout: config.TransitionTimeout, + CleanupInterval: config.CleanupInterval, + } + + integration := hitless.NewRedisClientIntegration(internalConfig, 3*time.Second, 3*time.Second) + + return &hitlessIntegrationImpl{ + integration: integration, + } +} + +// newHitlessIntegrationWithTimeouts creates a new hitless integration instance with timeout configuration +func newHitlessIntegrationWithTimeouts(config *HitlessUpgradeConfig, defaultReadTimeout, defaultWriteTimeout time.Duration) *hitlessIntegrationImpl { + // Start with defaults + defaults := DefaultHitlessUpgradeConfig() + + // If config is nil, use all defaults + if config == nil { + config = defaults + } + + // Ensure all fields are set with defaults if they are zero values + enabled := config.Enabled + transitionTimeout := config.TransitionTimeout + cleanupInterval := config.CleanupInterval + + // Apply defaults for zero values + if transitionTimeout == 0 { + transitionTimeout = defaults.TransitionTimeout + } + if cleanupInterval == 0 { + cleanupInterval = defaults.CleanupInterval + } + + // Convert to internal config format with all fields properly set + internalConfig := &hitless.HitlessUpgradeConfig{ + Enabled: enabled, + TransitionTimeout: transitionTimeout, + CleanupInterval: cleanupInterval, + } + + integration := hitless.NewRedisClientIntegration(internalConfig, defaultReadTimeout, defaultWriteTimeout) + + return &hitlessIntegrationImpl{ + integration: integration, + } +} + +// IsEnabled returns whether hitless upgrades are currently enabled +func (h *hitlessIntegrationImpl) IsEnabled() bool { + h.mu.RLock() + defer h.mu.RUnlock() + return h.integration.IsEnabled() +} + +// EnableHitlessUpgrades enables hitless upgrade functionality +func (h *hitlessIntegrationImpl) EnableHitlessUpgrades() { + h.mu.Lock() + defer h.mu.Unlock() + h.integration.EnableHitlessUpgrades() +} + +// DisableHitlessUpgrades disables hitless upgrade functionality +func (h *hitlessIntegrationImpl) DisableHitlessUpgrades() { + h.mu.Lock() + defer h.mu.Unlock() + h.integration.DisableHitlessUpgrades() +} + +// GetConnectionTimeout returns the appropriate timeout for a connection +func (h *hitlessIntegrationImpl) GetConnectionTimeout(conn interface{}, defaultTimeout time.Duration) time.Duration { + h.mu.RLock() + defer h.mu.RUnlock() + + // Convert interface{} to *pool.Conn + if poolConn, ok := conn.(*pool.Conn); ok { + return h.integration.GetConnectionTimeout(poolConn, defaultTimeout) + } + + // If not a pool connection, return default timeout + return defaultTimeout +} + +// GetConnectionTimeouts returns both read and write timeouts for a connection +func (h *hitlessIntegrationImpl) GetConnectionTimeouts(conn interface{}, defaultReadTimeout, defaultWriteTimeout time.Duration) (time.Duration, 
time.Duration) { + h.mu.RLock() + defer h.mu.RUnlock() + + // Convert interface{} to *pool.Conn + if poolConn, ok := conn.(*pool.Conn); ok { + return h.integration.GetConnectionTimeouts(poolConn, defaultReadTimeout, defaultWriteTimeout) + } + + // If not a pool connection, return default timeouts + return defaultReadTimeout, defaultWriteTimeout +} + +// MarkConnectionAsBlocking marks a connection as having blocking commands +func (h *hitlessIntegrationImpl) MarkConnectionAsBlocking(conn interface{}, isBlocking bool) { + h.mu.Lock() + defer h.mu.Unlock() + + // Convert interface{} to *pool.Conn + if poolConn, ok := conn.(*pool.Conn); ok { + h.integration.MarkConnectionAsBlocking(poolConn, isBlocking) + } +} + +// IsConnectionMarkedForClosing checks if a connection should be closed +func (h *hitlessIntegrationImpl) IsConnectionMarkedForClosing(conn interface{}) bool { + h.mu.RLock() + defer h.mu.RUnlock() + + // Convert interface{} to *pool.Conn + if poolConn, ok := conn.(*pool.Conn); ok { + return h.integration.IsConnectionMarkedForClosing(poolConn) + } + + return false +} + +// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected +func (h *hitlessIntegrationImpl) ShouldRedirectBlockingConnection(conn interface{}) (bool, string) { + h.mu.RLock() + defer h.mu.RUnlock() + + // Convert interface{} to *pool.Conn (can be nil for checking pool state) + var poolConn *pool.Conn + if conn != nil { + if pc, ok := conn.(*pool.Conn); ok { + poolConn = pc + } + } + + return h.integration.ShouldRedirectBlockingConnection(poolConn) +} + +// GetUpgradeStatistics returns current upgrade statistics +func (h *hitlessIntegrationImpl) GetUpgradeStatistics() *HitlessUpgradeStatistics { + h.mu.RLock() + defer h.mu.RUnlock() + + stats := h.integration.GetUpgradeStatistics() + if stats == nil { + return &HitlessUpgradeStatistics{Timestamp: time.Now()} + } + + return &HitlessUpgradeStatistics{ + ActiveConnections: stats.ActiveConnections, + IsMoving: stats.IsMoving, + MigratingConnections: stats.MigratingConnections, + FailingOverConnections: stats.FailingOverConnections, + Timestamp: stats.Timestamp, + } +} + +// GetUpgradeStatus returns detailed upgrade status +func (h *hitlessIntegrationImpl) GetUpgradeStatus() *HitlessUpgradeStatus { + h.mu.RLock() + defer h.mu.RUnlock() + + status := h.integration.GetUpgradeStatus() + if status == nil { + return &HitlessUpgradeStatus{ + ConnectionStates: make(map[interface{}]interface{}), + IsMoving: false, + NewEndpoint: "", + Timestamp: time.Now(), + } + } + + return &HitlessUpgradeStatus{ + ConnectionStates: convertToInterfaceMap(status.ConnectionStates), + IsMoving: status.IsMoving, + NewEndpoint: status.NewEndpoint, + Timestamp: status.Timestamp, + } +} + +// UpdateConfig updates the hitless upgrade configuration +func (h *hitlessIntegrationImpl) UpdateConfig(config *HitlessUpgradeConfig) error { + if config == nil { + return fmt.Errorf("config cannot be nil") + } + + h.mu.Lock() + defer h.mu.Unlock() + + // Start with defaults for any zero values + defaults := DefaultHitlessUpgradeConfig() + + // Ensure all fields are set with defaults if they are zero values + enabled := config.Enabled + transitionTimeout := config.TransitionTimeout + cleanupInterval := config.CleanupInterval + + // Apply defaults for zero values + if transitionTimeout == 0 { + transitionTimeout = defaults.TransitionTimeout + } + if cleanupInterval == 0 { + cleanupInterval = defaults.CleanupInterval + } + + // Convert to internal config format with all fields properly set 
+ internalConfig := &hitless.HitlessUpgradeConfig{ + Enabled: enabled, + TransitionTimeout: transitionTimeout, + CleanupInterval: cleanupInterval, + } + + return h.integration.UpdateConfig(internalConfig) +} + +// GetConfig returns the current configuration +func (h *hitlessIntegrationImpl) GetConfig() *HitlessUpgradeConfig { + h.mu.RLock() + defer h.mu.RUnlock() + + internalConfig := h.integration.GetConfig() + if internalConfig == nil { + return DefaultHitlessUpgradeConfig() + } + + return &HitlessUpgradeConfig{ + Enabled: internalConfig.Enabled, + TransitionTimeout: internalConfig.TransitionTimeout, + CleanupInterval: internalConfig.CleanupInterval, + } +} + +// Close shuts down the hitless integration +func (h *hitlessIntegrationImpl) Close() error { + h.mu.Lock() + defer h.mu.Unlock() + return h.integration.Close() +} + +// getInternalIntegration returns the internal integration for use by Redis clients +func (h *hitlessIntegrationImpl) getInternalIntegration() *hitless.RedisClientIntegration { + h.mu.RLock() + defer h.mu.RUnlock() + return h.integration +} + +// ClientTimeoutProvider interface for extracting timeout configuration from client options +type ClientTimeoutProvider interface { + GetReadTimeout() time.Duration + GetWriteTimeout() time.Duration +} + +// optionsTimeoutProvider implements ClientTimeoutProvider for Options struct +type optionsTimeoutProvider struct { + readTimeout time.Duration + writeTimeout time.Duration +} + +func (p *optionsTimeoutProvider) GetReadTimeout() time.Duration { + return p.readTimeout +} + +func (p *optionsTimeoutProvider) GetWriteTimeout() time.Duration { + return p.writeTimeout +} + +// newOptionsTimeoutProvider creates a timeout provider from Options +func newOptionsTimeoutProvider(readTimeout, writeTimeout time.Duration) ClientTimeoutProvider { + return &optionsTimeoutProvider{ + readTimeout: readTimeout, + writeTimeout: writeTimeout, + } +} + +// initializeHitlessIntegration initializes hitless integration for a client +func initializeHitlessIntegration(client interface{}, config *HitlessUpgradeConfig, timeoutProvider ClientTimeoutProvider) (*hitlessIntegrationImpl, error) { + if config == nil || !config.Enabled { + return nil, nil + } + + // Extract timeout configuration from client options + defaultReadTimeout := timeoutProvider.GetReadTimeout() + defaultWriteTimeout := timeoutProvider.GetWriteTimeout() + + // Create hitless integration - each client gets its own instance + integration := newHitlessIntegrationWithTimeouts(config, defaultReadTimeout, defaultWriteTimeout) + + // Push notification handlers are registered directly by the client + // No separate registration needed in simplified implementation + + internal.Logger.Printf(context.Background(), "hitless: initialized hitless upgrades for client") + + return integration, nil +} + +// convertToInterfaceMap converts a typed map to interface{} map for public API +func convertToInterfaceMap(input map[*pool.Conn]*hitless.ConnectionState) map[interface{}]interface{} { + result := make(map[interface{}]interface{}) + for k, v := range input { + result[k] = v + } + return result +} diff --git a/hitless_config_defaults_test.go b/hitless_config_defaults_test.go new file mode 100644 index 000000000..3b4d0aff4 --- /dev/null +++ b/hitless_config_defaults_test.go @@ -0,0 +1,197 @@ +package redis + +import ( + "testing" + "time" +) + +func TestHitlessUpgradeConfig_DefaultValues(t *testing.T) { + tests := []struct { + name string + inputConfig *HitlessUpgradeConfig + expectedConfig 
*HitlessUpgradeConfig + }{ + { + name: "nil config should use all defaults", + inputConfig: nil, + expectedConfig: &HitlessUpgradeConfig{ + Enabled: true, + TransitionTimeout: 60 * time.Second, + CleanupInterval: 30 * time.Second, + }, + }, + { + name: "zero TransitionTimeout should use default", + inputConfig: &HitlessUpgradeConfig{ + Enabled: false, + TransitionTimeout: 0, // Zero value + CleanupInterval: 45 * time.Second, + }, + expectedConfig: &HitlessUpgradeConfig{ + Enabled: false, + TransitionTimeout: 60 * time.Second, // Should use default + CleanupInterval: 45 * time.Second, + }, + }, + { + name: "zero CleanupInterval should use default", + inputConfig: &HitlessUpgradeConfig{ + Enabled: true, + TransitionTimeout: 90 * time.Second, + CleanupInterval: 0, // Zero value + }, + expectedConfig: &HitlessUpgradeConfig{ + Enabled: true, + TransitionTimeout: 90 * time.Second, + CleanupInterval: 30 * time.Second, // Should use default + }, + }, + { + name: "both timeouts zero should use defaults", + inputConfig: &HitlessUpgradeConfig{ + Enabled: true, + TransitionTimeout: 0, // Zero value + CleanupInterval: 0, // Zero value + }, + expectedConfig: &HitlessUpgradeConfig{ + Enabled: true, + TransitionTimeout: 60 * time.Second, // Should use default + CleanupInterval: 30 * time.Second, // Should use default + }, + }, + { + name: "all values set should be preserved", + inputConfig: &HitlessUpgradeConfig{ + Enabled: false, + TransitionTimeout: 120 * time.Second, + CleanupInterval: 60 * time.Second, + }, + expectedConfig: &HitlessUpgradeConfig{ + Enabled: false, + TransitionTimeout: 120 * time.Second, + CleanupInterval: 60 * time.Second, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test with a mock client that has hitless upgrades enabled + opt := &Options{ + Addr: "127.0.0.1:6379", + Protocol: 3, + HitlessUpgrades: true, + HitlessUpgradeConfig: tt.inputConfig, + ReadTimeout: 5 * time.Second, + WriteTimeout: 5 * time.Second, + } + + // Test the integration creation using the internal method directly + // since initializeHitlessIntegration requires a push processor + integration := newHitlessIntegrationWithTimeouts(tt.inputConfig, opt.ReadTimeout, opt.WriteTimeout) + if integration == nil { + t.Fatal("Integration should not be nil") + } + + // Get the config from the integration + actualConfig := integration.GetConfig() + if actualConfig == nil { + t.Fatal("Config should not be nil") + } + + // Verify all fields match expected values + if actualConfig.Enabled != tt.expectedConfig.Enabled { + t.Errorf("Enabled: expected %v, got %v", tt.expectedConfig.Enabled, actualConfig.Enabled) + } + if actualConfig.TransitionTimeout != tt.expectedConfig.TransitionTimeout { + t.Errorf("TransitionTimeout: expected %v, got %v", tt.expectedConfig.TransitionTimeout, actualConfig.TransitionTimeout) + } + if actualConfig.CleanupInterval != tt.expectedConfig.CleanupInterval { + t.Errorf("CleanupInterval: expected %v, got %v", tt.expectedConfig.CleanupInterval, actualConfig.CleanupInterval) + } + + // Test UpdateConfig as well + newConfig := &HitlessUpgradeConfig{ + Enabled: !tt.expectedConfig.Enabled, + TransitionTimeout: 0, // Zero value should use default + CleanupInterval: 0, // Zero value should use default + } + + err := integration.UpdateConfig(newConfig) + if err != nil { + t.Fatalf("Failed to update config: %v", err) + } + + // Verify updated config has defaults applied + updatedConfig := integration.GetConfig() + if updatedConfig.Enabled == tt.expectedConfig.Enabled { + 
t.Error("Enabled should have been toggled") + } + if updatedConfig.TransitionTimeout != 60*time.Second { + t.Errorf("TransitionTimeout should use default (60s), got %v", updatedConfig.TransitionTimeout) + } + if updatedConfig.CleanupInterval != 30*time.Second { + t.Errorf("CleanupInterval should use default (30s), got %v", updatedConfig.CleanupInterval) + } + }) + } +} + +func TestDefaultHitlessUpgradeConfig(t *testing.T) { + config := DefaultHitlessUpgradeConfig() + + if config == nil { + t.Fatal("Default config should not be nil") + } + + if !config.Enabled { + t.Error("Default config should have Enabled=true") + } + + if config.TransitionTimeout != 60*time.Second { + t.Errorf("Default TransitionTimeout should be 60s, got %v", config.TransitionTimeout) + } + + if config.CleanupInterval != 30*time.Second { + t.Errorf("Default CleanupInterval should be 30s, got %v", config.CleanupInterval) + } +} + +func TestHitlessUpgradeConfig_ZeroValueHandling(t *testing.T) { + // Test that zero values are properly handled in various scenarios + + // Test 1: Partial config with some zero values + partialConfig := &HitlessUpgradeConfig{ + Enabled: true, + // TransitionTimeout and CleanupInterval are zero values + } + + integration := newHitlessIntegrationWithTimeouts(partialConfig, 3*time.Second, 3*time.Second) + if integration == nil { + t.Fatal("Integration should not be nil") + } + + config := integration.GetConfig() + if config.TransitionTimeout == 0 { + t.Error("Zero TransitionTimeout should have been replaced with default") + } + if config.CleanupInterval == 0 { + t.Error("Zero CleanupInterval should have been replaced with default") + } + + // Test 2: Empty struct + emptyConfig := &HitlessUpgradeConfig{} + + integration2 := newHitlessIntegrationWithTimeouts(emptyConfig, 3*time.Second, 3*time.Second) + if integration2 == nil { + t.Fatal("Integration should not be nil") + } + + config2 := integration2.GetConfig() + if config2.TransitionTimeout == 0 { + t.Error("Zero TransitionTimeout in empty config should have been replaced with default") + } + if config2.CleanupInterval == 0 { + t.Error("Zero CleanupInterval in empty config should have been replaced with default") + } +} diff --git a/internal/hitless/README.md b/internal/hitless/README.md new file mode 100644 index 000000000..6b591bdbc --- /dev/null +++ b/internal/hitless/README.md @@ -0,0 +1,23 @@ +# Hitless Upgrade Package + +This package implements hitless upgrade functionality for Redis cluster clients using the push notification architecture. It provides handlers for managing connection and pool state during Redis cluster upgrades. + +## Quick Start + +To enable hitless upgrades in your Redis client, simply set the configuration option: + +```go +import "github.com/redis/go-redis/v9" + +// Enable hitless upgrades with a simple configuration option +client := redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: []string{"127.0.0.1:7000", "127.0.0.1:7001", "127.0.0.1:7002"}, + Protocol: 3, // RESP3 required for push notifications + HitlessUpgrades: true, // Enable hitless upgrades +}) +defer client.Close() + +// That's it! 
Use your client normally - hitless upgrades work automatically +ctx := context.Background() +client.Set(ctx, "key", "value", 0) +``` \ No newline at end of file diff --git a/internal/hitless/client_integration.go b/internal/hitless/client_integration.go new file mode 100644 index 000000000..1494bcd36 --- /dev/null +++ b/internal/hitless/client_integration.go @@ -0,0 +1,200 @@ +package hitless + +import ( + "context" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" +) + +// ClientIntegrator provides integration between hitless upgrade handlers and Redis clients +type ClientIntegrator struct { + upgradeHandler *UpgradeHandler + mu sync.RWMutex + + // Simple atomic state for pool redirection + isMoving int32 // atomic: 0 = not moving, 1 = moving + newEndpoint string // only written during MOVING, read-only after +} + +// NewClientIntegrator creates a new client integrator with client timeout configuration +func NewClientIntegrator(defaultReadTimeout, defaultWriteTimeout time.Duration) *ClientIntegrator { + return &ClientIntegrator{ + upgradeHandler: NewUpgradeHandler(defaultReadTimeout, defaultWriteTimeout), + } +} + +// GetUpgradeHandler returns the upgrade handler for direct access +func (ci *ClientIntegrator) GetUpgradeHandler() *UpgradeHandler { + return ci.upgradeHandler +} + +// HandlePushNotification is the main entry point for processing upgrade notifications +func (ci *ClientIntegrator) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + // Handle MOVING notifications for pool redirection + if len(notification) > 0 { + if notificationType, ok := notification[0].(string); ok && notificationType == "MOVING" { + if len(notification) >= 3 { + if newEndpoint, ok := notification[2].(string); ok { + // Simple atomic state update - no locks needed + ci.newEndpoint = newEndpoint + atomic.StoreInt32(&ci.isMoving, 1) + } + } + } + } + + return ci.upgradeHandler.HandlePushNotification(ctx, handlerCtx, notification) +} + +// Close shuts down the client integrator +func (ci *ClientIntegrator) Close() error { + ci.mu.Lock() + defer ci.mu.Unlock() + + // Reset atomic state + atomic.StoreInt32(&ci.isMoving, 0) + ci.newEndpoint = "" + + return nil +} + +// IsMoving returns true if the pool is currently moving to a new endpoint +// Uses atomic read - no locks needed +func (ci *ClientIntegrator) IsMoving() bool { + return atomic.LoadInt32(&ci.isMoving) == 1 +} + +// GetNewEndpoint returns the new endpoint if moving, empty string otherwise +// Safe to read without locks since it's only written during MOVING +func (ci *ClientIntegrator) GetNewEndpoint() string { + if ci.IsMoving() { + return ci.newEndpoint + } + return "" +} + +// PushNotificationHandlerInterface defines the interface for push notification handlers +// This implements the interface expected by the push notification system +type PushNotificationHandlerInterface interface { + HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error +} + +// Ensure ClientIntegrator implements the interface +var _ PushNotificationHandlerInterface = (*ClientIntegrator)(nil) + +// PoolRedirector provides pool redirection functionality for hitless upgrades +type PoolRedirector struct { + poolManager *PoolEndpointManager + mu sync.RWMutex +} + +// NewPoolRedirector creates a new pool redirector +func NewPoolRedirector() *PoolRedirector { + return 
&PoolRedirector{ + poolManager: NewPoolEndpointManager(), + } +} + +// RedirectPool redirects a connection pool to a new endpoint +func (pr *PoolRedirector) RedirectPool(ctx context.Context, pooler pool.Pooler, newEndpoint string, timeout time.Duration) error { + pr.mu.Lock() + defer pr.mu.Unlock() + + return pr.poolManager.RedirectPool(ctx, pooler, newEndpoint, timeout) +} + +// IsPoolRedirected checks if a pool is currently redirected +func (pr *PoolRedirector) IsPoolRedirected(pooler pool.Pooler) bool { + pr.mu.RLock() + defer pr.mu.RUnlock() + + return pr.poolManager.IsPoolRedirected(pooler) +} + +// GetRedirection returns redirection information for a pool +func (pr *PoolRedirector) GetRedirection(pooler pool.Pooler) (*EndpointRedirection, bool) { + pr.mu.RLock() + defer pr.mu.RUnlock() + + return pr.poolManager.GetRedirection(pooler) +} + +// Close shuts down the pool redirector +func (pr *PoolRedirector) Close() error { + pr.mu.Lock() + defer pr.mu.Unlock() + + // Clean up all redirections + ctx := context.Background() + pr.poolManager.CleanupExpiredRedirections(ctx) + + return nil +} + +// ConnectionStateTracker tracks connection states during upgrades +type ConnectionStateTracker struct { + upgradeHandler *UpgradeHandler + mu sync.RWMutex +} + +// NewConnectionStateTracker creates a new connection state tracker with timeout configuration +func NewConnectionStateTracker(defaultReadTimeout, defaultWriteTimeout time.Duration) *ConnectionStateTracker { + return &ConnectionStateTracker{ + upgradeHandler: NewUpgradeHandler(defaultReadTimeout, defaultWriteTimeout), + } +} + +// IsConnectionTransitioning checks if a connection is currently transitioning +func (cst *ConnectionStateTracker) IsConnectionTransitioning(conn *pool.Conn) bool { + cst.mu.RLock() + defer cst.mu.RUnlock() + + return cst.upgradeHandler.IsConnectionTransitioning(conn) +} + +// GetConnectionState returns the current state of a connection +func (cst *ConnectionStateTracker) GetConnectionState(conn *pool.Conn) (*ConnectionState, bool) { + cst.mu.RLock() + defer cst.mu.RUnlock() + + return cst.upgradeHandler.GetConnectionState(conn) +} + +// CleanupConnection removes tracking for a connection +func (cst *ConnectionStateTracker) CleanupConnection(conn *pool.Conn) { + cst.mu.Lock() + defer cst.mu.Unlock() + + cst.upgradeHandler.CleanupConnection(conn) +} + +// Close shuts down the connection state tracker +func (cst *ConnectionStateTracker) Close() error { + cst.mu.Lock() + defer cst.mu.Unlock() + + // Clean up all expired states + cst.upgradeHandler.CleanupExpiredStates() + + return nil +} + +// HitlessUpgradeConfig provides configuration for hitless upgrades +type HitlessUpgradeConfig struct { + Enabled bool + TransitionTimeout time.Duration + CleanupInterval time.Duration +} + +// DefaultHitlessUpgradeConfig returns default configuration for hitless upgrades +func DefaultHitlessUpgradeConfig() *HitlessUpgradeConfig { + return &HitlessUpgradeConfig{ + Enabled: true, + TransitionTimeout: 60 * time.Second, // Longer timeout for transitioning connections + CleanupInterval: 30 * time.Second, // How often to clean up expired states + } +} diff --git a/internal/hitless/pool_manager.go b/internal/hitless/pool_manager.go new file mode 100644 index 000000000..ac87b3b0c --- /dev/null +++ b/internal/hitless/pool_manager.go @@ -0,0 +1,205 @@ +package hitless + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pool" +) + +// PoolEndpointManager 
manages endpoint transitions for connection pools during hitless upgrades. +// It provides functionality to redirect new connections to new endpoints while maintaining +// existing connections until they can be gracefully transitioned. +type PoolEndpointManager struct { + mu sync.RWMutex + + // Map of pools to their endpoint redirections + redirections map[interface{}]*EndpointRedirection + + // Original dialers for pools (to restore after transition) + originalDialers map[interface{}]func(context.Context) (net.Conn, error) +} + +// EndpointRedirection represents an active endpoint redirection +type EndpointRedirection struct { + OriginalEndpoint string + NewEndpoint string + StartTime time.Time + Timeout time.Duration + + // Statistics + NewConnections int64 + FailedConnections int64 +} + +// NewPoolEndpointManager creates a new pool endpoint manager +func NewPoolEndpointManager() *PoolEndpointManager { + return &PoolEndpointManager{ + redirections: make(map[interface{}]*EndpointRedirection), + originalDialers: make(map[interface{}]func(context.Context) (net.Conn, error)), + } +} + +// RedirectPool redirects new connections from a pool to a new endpoint +func (m *PoolEndpointManager) RedirectPool(ctx context.Context, pooler pool.Pooler, newEndpoint string, timeout time.Duration) error { + m.mu.Lock() + defer m.mu.Unlock() + + // Check if pool is already being redirected + if _, exists := m.redirections[pooler]; exists { + return fmt.Errorf("pool is already being redirected") + } + + // Get the current dialer from the pool + connPool, ok := pooler.(*pool.ConnPool) + if !ok { + return fmt.Errorf("unsupported pool type: %T", pooler) + } + + // Store original dialer + originalDialer := m.getPoolDialer(connPool) + if originalDialer == nil { + return fmt.Errorf("could not get original dialer from pool") + } + + m.originalDialers[pooler] = originalDialer + + // Create new dialer that connects to the new endpoint + newDialer := m.createRedirectDialer(ctx, newEndpoint, originalDialer) + + // Replace the pool's dialer + if err := m.setPoolDialer(connPool, newDialer); err != nil { + delete(m.originalDialers, pooler) + return fmt.Errorf("failed to set new dialer: %w", err) + } + + // Record the redirection + m.redirections[pooler] = &EndpointRedirection{ + OriginalEndpoint: m.extractEndpointFromDialer(originalDialer), + NewEndpoint: newEndpoint, + StartTime: time.Now(), + Timeout: timeout, + } + + internal.Logger.Printf(ctx, "hitless: redirected pool to new endpoint %s", newEndpoint) + + return nil +} + +// IsPoolRedirected checks if a pool is currently being redirected +func (m *PoolEndpointManager) IsPoolRedirected(pooler pool.Pooler) bool { + m.mu.RLock() + defer m.mu.RUnlock() + + _, exists := m.redirections[pooler] + return exists +} + +// GetRedirection returns redirection information for a pool +func (m *PoolEndpointManager) GetRedirection(pooler pool.Pooler) (*EndpointRedirection, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + redirection, exists := m.redirections[pooler] + if !exists { + return nil, false + } + + // Return a copy to avoid race conditions + redirectionCopy := *redirection + return &redirectionCopy, true +} + +// CleanupExpiredRedirections removes expired redirections +func (m *PoolEndpointManager) CleanupExpiredRedirections(ctx context.Context) { + m.mu.Lock() + defer m.mu.Unlock() + + now := time.Now() + + for pooler, redirection := range m.redirections { + if now.Sub(redirection.StartTime) > redirection.Timeout { + // TODO: Here we should decide if we need to failback to 
the original dialer, + // i.e. if the new endpoint did not produce any active connections. + delete(m.redirections, pooler) + delete(m.originalDialers, pooler) + internal.Logger.Printf(ctx, "hitless: cleaned up expired redirection for pool") + } + } +} + +// createRedirectDialer creates a dialer that connects to the new endpoint +func (m *PoolEndpointManager) createRedirectDialer(ctx context.Context, newEndpoint string, originalDialer func(context.Context) (net.Conn, error)) func(context.Context) (net.Conn, error) { + return func(dialCtx context.Context) (net.Conn, error) { + // Try to connect to the new endpoint + conn, err := net.DialTimeout("tcp", newEndpoint, 10*time.Second) + if err != nil { + internal.Logger.Printf(ctx, "hitless: failed to connect to new endpoint %s: %v", newEndpoint, err) + + // Fallback to original dialer + return originalDialer(dialCtx) + } + + internal.Logger.Printf(ctx, "hitless: successfully connected to new endpoint %s", newEndpoint) + return conn, nil + } +} + +// getPoolDialer extracts the dialer from a connection pool +func (m *PoolEndpointManager) getPoolDialer(connPool *pool.ConnPool) func(context.Context) (net.Conn, error) { + return connPool.GetDialer() +} + +// setPoolDialer sets a new dialer for a connection pool +func (m *PoolEndpointManager) setPoolDialer(connPool *pool.ConnPool, dialer func(context.Context) (net.Conn, error)) error { + return connPool.SetDialer(dialer) +} + +// extractEndpointFromDialer extracts the endpoint address from a dialer +func (m *PoolEndpointManager) extractEndpointFromDialer(dialer func(context.Context) (net.Conn, error)) string { + // Try to extract endpoint by making a test connection + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + conn, err := dialer(ctx) + if err != nil { + return "unknown" + } + defer conn.Close() + + if conn.RemoteAddr() != nil { + return conn.RemoteAddr().String() + } + + return "unknown" +} + +// GetActiveRedirections returns all active redirections +func (m *PoolEndpointManager) GetActiveRedirections() map[interface{}]*EndpointRedirection { + m.mu.RLock() + defer m.mu.RUnlock() + + // Create copies to avoid race conditions + redirections := make(map[interface{}]*EndpointRedirection) + for pooler, redirection := range m.redirections { + redirectionCopy := *redirection + redirections[pooler] = &redirectionCopy + } + + return redirections +} + +// UpdateRedirectionStats updates statistics for a redirection +func (m *PoolEndpointManager) UpdateRedirectionStats(pooler pool.Pooler, newConnections, failedConnections int64) { + m.mu.Lock() + defer m.mu.Unlock() + + if redirection, exists := m.redirections[pooler]; exists { + redirection.NewConnections += newConnections + redirection.FailedConnections += failedConnections + } +} diff --git a/internal/hitless/redis_integration.go b/internal/hitless/redis_integration.go new file mode 100644 index 000000000..1c50ecbf1 --- /dev/null +++ b/internal/hitless/redis_integration.go @@ -0,0 +1,309 @@ +package hitless + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" +) + +// UpgradeStatus represents the current status of all upgrade operations +type UpgradeStatus struct { + ConnectionStates map[*pool.Conn]*ConnectionState + IsMoving bool + NewEndpoint string + Timestamp time.Time +} + +// UpgradeStatistics provides statistics about upgrade operations +type UpgradeStatistics struct { + 
ActiveConnections int + IsMoving bool + MigratingConnections int + FailingOverConnections int + Timestamp time.Time +} + +// RedisClientIntegration provides complete hitless upgrade integration for Redis clients +type RedisClientIntegration struct { + clientIntegrator *ClientIntegrator + connectionStateTracker *ConnectionStateTracker + config *HitlessUpgradeConfig + + mu sync.RWMutex + enabled bool +} + +// NewRedisClientIntegration creates a new Redis client integration for hitless upgrades +// This is used internally by the main hitless.go package +func NewRedisClientIntegration(config *HitlessUpgradeConfig, defaultReadTimeout, defaultWriteTimeout time.Duration) *RedisClientIntegration { + // Start with defaults + defaults := DefaultHitlessUpgradeConfig() + + // If config is nil, use all defaults + if config == nil { + config = defaults + } else { + // Ensure all fields are set with defaults if they are zero values + if config.TransitionTimeout == 0 { + config = &HitlessUpgradeConfig{ + Enabled: config.Enabled, + TransitionTimeout: defaults.TransitionTimeout, + CleanupInterval: config.CleanupInterval, + } + } + if config.CleanupInterval == 0 { + config = &HitlessUpgradeConfig{ + Enabled: config.Enabled, + TransitionTimeout: config.TransitionTimeout, + CleanupInterval: defaults.CleanupInterval, + } + } + } + + return &RedisClientIntegration{ + clientIntegrator: NewClientIntegrator(defaultReadTimeout, defaultWriteTimeout), + connectionStateTracker: NewConnectionStateTracker(defaultReadTimeout, defaultWriteTimeout), + config: config, + enabled: config.Enabled, + } +} + +// EnableHitlessUpgrades enables hitless upgrade functionality +func (rci *RedisClientIntegration) EnableHitlessUpgrades() { + rci.mu.Lock() + defer rci.mu.Unlock() + rci.enabled = true +} + +// DisableHitlessUpgrades disables hitless upgrade functionality +func (rci *RedisClientIntegration) DisableHitlessUpgrades() { + rci.mu.Lock() + defer rci.mu.Unlock() + rci.enabled = false +} + +// IsEnabled returns whether hitless upgrades are enabled +func (rci *RedisClientIntegration) IsEnabled() bool { + rci.mu.RLock() + defer rci.mu.RUnlock() + return rci.enabled +} + +// No client registration needed - each client has its own hitless integration instance + +// HandlePushNotification processes push notifications for hitless upgrades +func (rci *RedisClientIntegration) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if !rci.IsEnabled() { + // If disabled, just log and return without processing + internal.Logger.Printf(ctx, "hitless: received notification but hitless upgrades are disabled") + return nil + } + + return rci.clientIntegrator.HandlePushNotification(ctx, handlerCtx, notification) +} + +// IsConnectionTransitioning checks if a connection is currently transitioning +func (rci *RedisClientIntegration) IsConnectionTransitioning(conn *pool.Conn) bool { + if !rci.IsEnabled() { + return false + } + + return rci.connectionStateTracker.IsConnectionTransitioning(conn) +} + +// GetConnectionState returns the current state of a connection +func (rci *RedisClientIntegration) GetConnectionState(conn *pool.Conn) (*ConnectionState, bool) { + if !rci.IsEnabled() { + return nil, false + } + + return rci.connectionStateTracker.GetConnectionState(conn) +} + +// GetUpgradeStatus returns comprehensive status of all ongoing upgrades +func (rci *RedisClientIntegration) GetUpgradeStatus() *UpgradeStatus { + connStates := 
rci.clientIntegrator.GetUpgradeHandler().GetActiveTransitions() + + return &UpgradeStatus{ + ConnectionStates: connStates, + IsMoving: rci.clientIntegrator.IsMoving(), + NewEndpoint: rci.clientIntegrator.GetNewEndpoint(), + Timestamp: time.Now(), + } +} + +// GetUpgradeStatistics returns statistics about upgrade operations +func (rci *RedisClientIntegration) GetUpgradeStatistics() *UpgradeStatistics { + connStates := rci.clientIntegrator.GetUpgradeHandler().GetActiveTransitions() + + stats := &UpgradeStatistics{ + ActiveConnections: len(connStates), + IsMoving: rci.clientIntegrator.IsMoving(), + Timestamp: time.Now(), + } + + // Count by type + stats.MigratingConnections = 0 + stats.FailingOverConnections = 0 + for _, state := range connStates { + switch state.TransitionType { + case "MIGRATING": + stats.MigratingConnections++ + case "FAILING_OVER": + stats.FailingOverConnections++ + } + } + + return stats +} + +// GetConnectionTimeout returns the appropriate timeout for a connection +// If the connection is transitioning (MIGRATING/FAILING_OVER), returns the longer TransitionTimeout +// Otherwise returns the provided defaultTimeout +func (rci *RedisClientIntegration) GetConnectionTimeout(conn *pool.Conn, defaultTimeout time.Duration) time.Duration { + if !rci.IsEnabled() { + return defaultTimeout + } + + // Check if connection is transitioning + if rci.connectionStateTracker.IsConnectionTransitioning(conn) { + // Use longer timeout for transitioning connections + return rci.config.TransitionTimeout + } + + return defaultTimeout +} + +// GetConnectionTimeouts returns both read and write timeouts for a connection +func (rci *RedisClientIntegration) GetConnectionTimeouts(conn *pool.Conn, defaultReadTimeout, defaultWriteTimeout time.Duration) (time.Duration, time.Duration) { + if !rci.IsEnabled() { + return defaultReadTimeout, defaultWriteTimeout + } + + // Use the upgrade handler to get appropriate timeouts + upgradeHandler := rci.clientIntegrator.GetUpgradeHandler() + return upgradeHandler.GetConnectionTimeouts(conn, defaultReadTimeout, defaultWriteTimeout, rci.config.TransitionTimeout) +} + +// MarkConnectionAsBlocking marks a connection as having blocking commands +func (rci *RedisClientIntegration) MarkConnectionAsBlocking(conn *pool.Conn, isBlocking bool) { + if !rci.IsEnabled() { + return + } + + // Use the upgrade handler to mark connection as blocking + upgradeHandler := rci.clientIntegrator.GetUpgradeHandler() + upgradeHandler.MarkConnectionAsBlocking(conn, isBlocking) +} + +// IsConnectionMarkedForClosing checks if a connection should be closed +func (rci *RedisClientIntegration) IsConnectionMarkedForClosing(conn *pool.Conn) bool { + if !rci.IsEnabled() { + return false + } + + // Use the upgrade handler to check if connection is marked for closing + upgradeHandler := rci.clientIntegrator.GetUpgradeHandler() + return upgradeHandler.IsConnectionMarkedForClosing(conn) +} + +// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected +func (rci *RedisClientIntegration) ShouldRedirectBlockingConnection(conn *pool.Conn) (bool, string) { + if !rci.IsEnabled() { + return false, "" + } + + // Check client integrator's atomic state for pool-level redirection + if rci.clientIntegrator.IsMoving() { + return true, rci.clientIntegrator.GetNewEndpoint() + } + + // Check specific connection state + upgradeHandler := rci.clientIntegrator.GetUpgradeHandler() + return upgradeHandler.ShouldRedirectBlockingConnection(conn, rci.clientIntegrator) +} + +// CleanupConnection 
removes tracking for a connection (called when connection is closed) +func (rci *RedisClientIntegration) CleanupConnection(conn *pool.Conn) { + if rci.IsEnabled() { + rci.connectionStateTracker.CleanupConnection(conn) + } +} + +// CleanupPool removed - no pool state to clean up + +// Close shuts down the Redis client integration +func (rci *RedisClientIntegration) Close() error { + rci.mu.Lock() + defer rci.mu.Unlock() + + var firstErr error + + // Close all components + if err := rci.clientIntegrator.Close(); err != nil && firstErr == nil { + firstErr = err + } + + // poolRedirector removed in simplified implementation + + if err := rci.connectionStateTracker.Close(); err != nil && firstErr == nil { + firstErr = err + } + + rci.enabled = false + + return firstErr +} + +// GetConfig returns the current configuration +func (rci *RedisClientIntegration) GetConfig() *HitlessUpgradeConfig { + rci.mu.RLock() + defer rci.mu.RUnlock() + + // Return a copy to prevent modification + configCopy := *rci.config + return &configCopy +} + +// UpdateConfig updates the configuration +func (rci *RedisClientIntegration) UpdateConfig(config *HitlessUpgradeConfig) error { + if config == nil { + return fmt.Errorf("config cannot be nil") + } + + rci.mu.Lock() + defer rci.mu.Unlock() + + // Start with defaults for any zero values + defaults := DefaultHitlessUpgradeConfig() + + // Ensure all fields are set with defaults if they are zero values + enabled := config.Enabled + transitionTimeout := config.TransitionTimeout + cleanupInterval := config.CleanupInterval + + // Apply defaults for zero values + if transitionTimeout == 0 { + transitionTimeout = defaults.TransitionTimeout + } + if cleanupInterval == 0 { + cleanupInterval = defaults.CleanupInterval + } + + // Create properly configured config + finalConfig := &HitlessUpgradeConfig{ + Enabled: enabled, + TransitionTimeout: transitionTimeout, + CleanupInterval: cleanupInterval, + } + + rci.config = finalConfig + rci.enabled = finalConfig.Enabled + + return nil +} diff --git a/internal/hitless/upgrade_handler.go b/internal/hitless/upgrade_handler.go new file mode 100644 index 000000000..4a2191337 --- /dev/null +++ b/internal/hitless/upgrade_handler.go @@ -0,0 +1,500 @@ +package hitless + +import ( + "context" + "fmt" + "strconv" + "sync" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" +) + +// UpgradeHandler handles hitless upgrade push notifications for Redis cluster upgrades. 
+// It implements different strategies based on notification type: +// - MOVING: Changes pool state to use new endpoint for future connections +// - mark existing connections for closing, handle pubsub to change underlying connection, change pool dialer with the new endpoint +// +// - MIGRATING/FAILING_OVER: Marks specific connection as in transition +// - relaxing timeouts +// +// - MIGRATED/FAILED_OVER: Clears transition state for specific connection +// - return to original timeouts +type UpgradeHandler struct { + mu sync.RWMutex + + // Connection-specific state for MIGRATING/FAILING_OVER notifications + connectionStates map[*pool.Conn]*ConnectionState + + // Pool-level state removed - using atomic state in ClusterUpgradeManager instead + + // Client configuration for getting default timeouts + defaultReadTimeout time.Duration + defaultWriteTimeout time.Duration +} + +// ConnectionState tracks the state of a specific connection during upgrades +type ConnectionState struct { + IsTransitioning bool + TransitionType string + StartTime time.Time + ShardID string + TimeoutSeconds int + + // Timeout management + OriginalReadTimeout time.Duration // Original read timeout + OriginalWriteTimeout time.Duration // Original write timeout + + // MOVING state specific + MarkedForClosing bool // Connection should be closed after current commands + IsBlocking bool // Connection has blocking commands + NewEndpoint string // New endpoint for blocking commands + LastCommandTime time.Time // When the last command was sent +} + +// PoolState removed - using atomic state in ClusterUpgradeManager instead + +// NewUpgradeHandler creates a new hitless upgrade handler with client timeout configuration +func NewUpgradeHandler(defaultReadTimeout, defaultWriteTimeout time.Duration) *UpgradeHandler { + return &UpgradeHandler{ + connectionStates: make(map[*pool.Conn]*ConnectionState), + defaultReadTimeout: defaultReadTimeout, + defaultWriteTimeout: defaultWriteTimeout, + } +} + +// GetConnectionTimeouts returns the appropriate read and write timeouts for a connection +// If the connection is transitioning, returns increased timeouts +func (h *UpgradeHandler) GetConnectionTimeouts(conn *pool.Conn, defaultReadTimeout, defaultWriteTimeout, transitionTimeout time.Duration) (time.Duration, time.Duration) { + h.mu.RLock() + defer h.mu.RUnlock() + + state, exists := h.connectionStates[conn] + if !exists || !state.IsTransitioning { + return defaultReadTimeout, defaultWriteTimeout + } + + // For transitioning connections (MIGRATING/FAILING_OVER), use longer timeouts + switch state.TransitionType { + case "MIGRATING", "FAILING_OVER": + return transitionTimeout, transitionTimeout + case "MOVING": + // For MOVING connections, use default timeouts but mark for special handling + return defaultReadTimeout, defaultWriteTimeout + default: + return defaultReadTimeout, defaultWriteTimeout + } +} + +// MarkConnectionForClosing marks a connection to be closed after current commands complete +func (h *UpgradeHandler) MarkConnectionForClosing(conn *pool.Conn, newEndpoint string) { + h.mu.Lock() + defer h.mu.Unlock() + + state, exists := h.connectionStates[conn] + if !exists { + state = &ConnectionState{ + IsTransitioning: true, + TransitionType: "MOVING", + StartTime: time.Now(), + } + h.connectionStates[conn] = state + } + + state.MarkedForClosing = true + state.NewEndpoint = newEndpoint + state.LastCommandTime = time.Now() +} + +// IsConnectionMarkedForClosing checks if a connection should be closed +func (h *UpgradeHandler) 
IsConnectionMarkedForClosing(conn *pool.Conn) bool { + h.mu.RLock() + defer h.mu.RUnlock() + + state, exists := h.connectionStates[conn] + return exists && state.MarkedForClosing +} + +// MarkConnectionAsBlocking marks a connection as having blocking commands +func (h *UpgradeHandler) MarkConnectionAsBlocking(conn *pool.Conn, isBlocking bool) { + h.mu.Lock() + defer h.mu.Unlock() + + state, exists := h.connectionStates[conn] + if !exists { + state = &ConnectionState{ + IsTransitioning: false, + } + h.connectionStates[conn] = state + } + + state.IsBlocking = isBlocking +} + +// ShouldRedirectBlockingConnection checks if a blocking connection should be redirected +// Uses client integrator's atomic state for pool-level checks - minimal locking +func (h *UpgradeHandler) ShouldRedirectBlockingConnection(conn *pool.Conn, clientIntegrator interface{}) (bool, string) { + if conn != nil { + // Check specific connection - need lock only for connection state + h.mu.RLock() + state, exists := h.connectionStates[conn] + h.mu.RUnlock() + + if exists && state.IsBlocking && state.TransitionType == "MOVING" && state.NewEndpoint != "" { + return true, state.NewEndpoint + } + } + + // Check client integrator's atomic state - no locks needed + if ci, ok := clientIntegrator.(*ClientIntegrator); ok && ci != nil && ci.IsMoving() { + return true, ci.GetNewEndpoint() + } + + return false, "" +} + +// ShouldRedirectNewBlockingConnection removed - functionality merged into ShouldRedirectBlockingConnection + +// HandlePushNotification processes hitless upgrade push notifications +func (h *UpgradeHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) == 0 { + return fmt.Errorf("hitless: empty notification received") + } + + notificationType, ok := notification[0].(string) + if !ok { + return fmt.Errorf("hitless: notification type is not a string: %T", notification[0]) + } + + internal.Logger.Printf(ctx, "hitless: processing %s notification with %d elements", notificationType, len(notification)) + + switch notificationType { + case "MOVING": + return h.handleMovingNotification(ctx, handlerCtx, notification) + case "MIGRATING": + return h.handleMigratingNotification(ctx, handlerCtx, notification) + case "MIGRATED": + return h.handleMigratedNotification(ctx, handlerCtx, notification) + case "FAILING_OVER": + return h.handleFailingOverNotification(ctx, handlerCtx, notification) + case "FAILED_OVER": + return h.handleFailedOverNotification(ctx, handlerCtx, notification) + default: + internal.Logger.Printf(ctx, "hitless: unknown notification type: %s", notificationType) + return nil + } +} + +// handleMovingNotification processes MOVING notifications that affect the entire pool +// Format: ["MOVING", time_seconds, "new_endpoint"] +func (h *UpgradeHandler) handleMovingNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) < 3 { + return fmt.Errorf("hitless: MOVING notification requires at least 3 elements, got %d", len(notification)) + } + + // Parse timeout + timeoutSeconds, err := h.parseTimeoutSeconds(notification[1]) + if err != nil { + return fmt.Errorf("hitless: failed to parse timeout for MOVING notification: %w", err) + } + + // Parse new endpoint + newEndpoint, ok := notification[2].(string) + if !ok { + return fmt.Errorf("hitless: new endpoint is not a string: %T", notification[2]) + } + + internal.Logger.Printf(ctx, "hitless: MOVING 
notification - endpoint will move to %s in %d seconds", newEndpoint, timeoutSeconds) + h.mu.Lock() + // Mark all existing connections for closing after current commands complete + for _, state := range h.connectionStates { + state.MarkedForClosing = true + state.NewEndpoint = newEndpoint + state.TransitionType = "MOVING" + state.LastCommandTime = time.Now() + } + h.mu.Unlock() + + internal.Logger.Printf(ctx, "hitless: marked existing connections for closing, new blocking commands will use %s", newEndpoint) + + return nil +} + +// Removed complex helper methods - simplified to direct inline logic + +// handleMigratingNotification processes MIGRATING notifications for specific connections +// Format: ["MIGRATING", time_seconds, shard_id] +func (h *UpgradeHandler) handleMigratingNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) < 3 { + return fmt.Errorf("hitless: MIGRATING notification requires at least 3 elements, got %d", len(notification)) + } + + timeoutSeconds, err := h.parseTimeoutSeconds(notification[1]) + if err != nil { + return fmt.Errorf("hitless: failed to parse timeout for MIGRATING notification: %w", err) + } + + shardID, err := h.parseShardID(notification[2]) + if err != nil { + return fmt.Errorf("hitless: failed to parse shard ID for MIGRATING notification: %w", err) + } + + conn := handlerCtx.Conn + if conn == nil { + return fmt.Errorf("hitless: no connection available in handler context") + } + + internal.Logger.Printf(ctx, "hitless: MIGRATING notification - shard %s will migrate in %d seconds on connection %p", shardID, timeoutSeconds, conn) + + h.mu.Lock() + defer h.mu.Unlock() + + // Store original timeouts if not already stored + var originalReadTimeout, originalWriteTimeout time.Duration + if existingState, exists := h.connectionStates[conn]; exists { + originalReadTimeout = existingState.OriginalReadTimeout + originalWriteTimeout = existingState.OriginalWriteTimeout + } else { + // Get default timeouts from client configuration + originalReadTimeout = h.defaultReadTimeout + originalWriteTimeout = h.defaultWriteTimeout + } + + h.connectionStates[conn] = &ConnectionState{ + IsTransitioning: true, + TransitionType: "MIGRATING", + StartTime: time.Now(), + ShardID: shardID, + TimeoutSeconds: timeoutSeconds, + OriginalReadTimeout: originalReadTimeout, + OriginalWriteTimeout: originalWriteTimeout, + LastCommandTime: time.Now(), + } + + internal.Logger.Printf(ctx, "hitless: connection %p marked as MIGRATING with increased timeouts", conn) + + return nil +} + +// handleMigratedNotification processes MIGRATED notifications for specific connections +// Format: ["MIGRATED", shard_id] +func (h *UpgradeHandler) handleMigratedNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) < 2 { + return fmt.Errorf("hitless: MIGRATED notification requires at least 2 elements, got %d", len(notification)) + } + + shardID, err := h.parseShardID(notification[1]) + if err != nil { + return fmt.Errorf("hitless: failed to parse shard ID for MIGRATED notification: %w", err) + } + + conn := handlerCtx.Conn + if conn == nil { + return fmt.Errorf("hitless: no connection available in handler context") + } + + internal.Logger.Printf(ctx, "hitless: MIGRATED notification - shard %s migration completed on connection %p", shardID, conn) + + h.mu.Lock() + defer h.mu.Unlock() + + // Clear the transitioning state for this connection and restore 
original timeouts + if state, exists := h.connectionStates[conn]; exists && state.TransitionType == "MIGRATING" && state.ShardID == shardID { + internal.Logger.Printf(ctx, "hitless: restoring original timeouts for connection %p (read: %v, write: %v)", + conn, state.OriginalReadTimeout, state.OriginalWriteTimeout) + + // In a real implementation, this would restore the connection's original timeouts + // For now, we'll just log and delete the state + delete(h.connectionStates, conn) + internal.Logger.Printf(ctx, "hitless: cleared MIGRATING state and restored timeouts for connection %p", conn) + } + + return nil +} + +// handleFailingOverNotification processes FAILING_OVER notifications for specific connections +// Format: ["FAILING_OVER", time_seconds, shard_id] +func (h *UpgradeHandler) handleFailingOverNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) < 3 { + return fmt.Errorf("hitless: FAILING_OVER notification requires at least 3 elements, got %d", len(notification)) + } + + timeoutSeconds, err := h.parseTimeoutSeconds(notification[1]) + if err != nil { + return fmt.Errorf("hitless: failed to parse timeout for FAILING_OVER notification: %w", err) + } + + shardID, err := h.parseShardID(notification[2]) + if err != nil { + return fmt.Errorf("hitless: failed to parse shard ID for FAILING_OVER notification: %w", err) + } + + conn := handlerCtx.Conn + if conn == nil { + return fmt.Errorf("hitless: no connection available in handler context") + } + + internal.Logger.Printf(ctx, "hitless: FAILING_OVER notification - shard %s will failover in %d seconds on connection %p", shardID, timeoutSeconds, conn) + + h.mu.Lock() + defer h.mu.Unlock() + + // Store original timeouts if not already stored + var originalReadTimeout, originalWriteTimeout time.Duration + if existingState, exists := h.connectionStates[conn]; exists { + originalReadTimeout = existingState.OriginalReadTimeout + originalWriteTimeout = existingState.OriginalWriteTimeout + } else { + // Get default timeouts from client configuration + originalReadTimeout = h.defaultReadTimeout + originalWriteTimeout = h.defaultWriteTimeout + } + + h.connectionStates[conn] = &ConnectionState{ + IsTransitioning: true, + TransitionType: "FAILING_OVER", + StartTime: time.Now(), + ShardID: shardID, + TimeoutSeconds: timeoutSeconds, + OriginalReadTimeout: originalReadTimeout, + OriginalWriteTimeout: originalWriteTimeout, + LastCommandTime: time.Now(), + } + + internal.Logger.Printf(ctx, "hitless: connection %p marked as FAILING_OVER with increased timeouts", conn) + + return nil +} + +// handleFailedOverNotification processes FAILED_OVER notifications for specific connections +// Format: ["FAILED_OVER", shard_id] +func (h *UpgradeHandler) handleFailedOverNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) < 2 { + return fmt.Errorf("hitless: FAILED_OVER notification requires at least 2 elements, got %d", len(notification)) + } + + shardID, err := h.parseShardID(notification[1]) + if err != nil { + return fmt.Errorf("hitless: failed to parse shard ID for FAILED_OVER notification: %w", err) + } + + conn := handlerCtx.Conn + if conn == nil { + return fmt.Errorf("hitless: no connection available in handler context") + } + + internal.Logger.Printf(ctx, "hitless: FAILED_OVER notification - shard %s failover completed on connection %p", shardID, conn) + + h.mu.Lock() + defer h.mu.Unlock() + + // 
Clear the transitioning state for this connection and restore original timeouts + if state, exists := h.connectionStates[conn]; exists && state.TransitionType == "FAILING_OVER" && state.ShardID == shardID { + internal.Logger.Printf(ctx, "hitless: restoring original timeouts for connection %p (read: %v, write: %v)", + conn, state.OriginalReadTimeout, state.OriginalWriteTimeout) + + // In a real implementation, this would restore the connection's original timeouts + // For now, we'll just log and delete the state + delete(h.connectionStates, conn) + internal.Logger.Printf(ctx, "hitless: cleared FAILING_OVER state and restored timeouts for connection %p", conn) + } + + return nil +} + +// parseTimeoutSeconds parses timeout value from notification +func (h *UpgradeHandler) parseTimeoutSeconds(value interface{}) (int, error) { + switch v := value.(type) { + case int64: + return int(v), nil + case int: + return v, nil + case string: + return strconv.Atoi(v) + default: + return 0, fmt.Errorf("unsupported timeout type: %T", value) + } +} + +// parseShardID parses shard ID from notification +func (h *UpgradeHandler) parseShardID(value interface{}) (string, error) { + switch v := value.(type) { + case string: + return v, nil + case int64: + return strconv.FormatInt(v, 10), nil + case int: + return strconv.Itoa(v), nil + default: + return "", fmt.Errorf("unsupported shard ID type: %T", value) + } +} + +// GetConnectionState returns the current state of a connection +func (h *UpgradeHandler) GetConnectionState(conn *pool.Conn) (*ConnectionState, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + state, exists := h.connectionStates[conn] + if !exists { + return nil, false + } + + // Return a copy to avoid race conditions + stateCopy := *state + return &stateCopy, true +} + +// IsConnectionTransitioning checks if a connection is currently transitioning +func (h *UpgradeHandler) IsConnectionTransitioning(conn *pool.Conn) bool { + h.mu.RLock() + defer h.mu.RUnlock() + + state, exists := h.connectionStates[conn] + return exists && state.IsTransitioning +} + +// IsPoolMoving and GetNewEndpoint removed - using atomic state in ClusterUpgradeManager instead + +// CleanupExpiredStates removes expired connection and pool states +func (h *UpgradeHandler) CleanupExpiredStates() { + h.mu.Lock() + defer h.mu.Unlock() + + now := time.Now() + + // Cleanup expired connection states + for conn, state := range h.connectionStates { + timeout := time.Duration(state.TimeoutSeconds) * time.Second + if now.Sub(state.StartTime) > timeout { + delete(h.connectionStates, conn) + } + } + + // Pool state cleanup removed - using atomic state in ClusterUpgradeManager instead +} + +// CleanupConnection removes state for a specific connection (called when connection is closed) +func (h *UpgradeHandler) CleanupConnection(conn *pool.Conn) { + h.mu.Lock() + defer h.mu.Unlock() + + delete(h.connectionStates, conn) +} + +// GetActiveTransitions returns information about all active connection transitions +func (h *UpgradeHandler) GetActiveTransitions() map[*pool.Conn]*ConnectionState { + h.mu.RLock() + defer h.mu.RUnlock() + + // Create copies to avoid race conditions + connStates := make(map[*pool.Conn]*ConnectionState) + for conn, state := range h.connectionStates { + stateCopy := *state + connStates[conn] = &stateCopy + } + + return connStates +} diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 22f8ea6a7..507e1b61b 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -565,3 +565,24 @@ func (p *ConnPool) 
isHealthyConn(cn *Conn) bool { cn.SetUsedAt(now) return true } + +// GetDialer returns the current dialer function for the pool +func (p *ConnPool) GetDialer() func(context.Context) (net.Conn, error) { + p.connsMu.Lock() + defer p.connsMu.Unlock() + return p.cfg.Dialer +} + +// SetDialer sets a new dialer function for the pool +// This is used for hitless upgrades to redirect new connections to new endpoints +func (p *ConnPool) SetDialer(dialer func(context.Context) (net.Conn, error)) error { + if p.closed() { + return ErrClosed + } + + p.connsMu.Lock() + defer p.connsMu.Unlock() + + p.cfg.Dialer = dialer + return nil +} diff --git a/internal/proto/peek_push_notification_test.go b/internal/proto/peek_push_notification_test.go index 338826e7d..58a794b84 100644 --- a/internal/proto/peek_push_notification_test.go +++ b/internal/proto/peek_push_notification_test.go @@ -3,6 +3,7 @@ package proto import ( "bytes" "fmt" + "math/rand" "strings" "testing" ) @@ -215,9 +216,9 @@ func TestPeekPushNotificationName(t *testing.T) { // This is acceptable behavior for malformed input name, err := reader.PeekPushNotificationName() if err != nil { - t.Logf("PeekPushNotificationName errored for corrupted data %s: %v", tc.name, err) + t.Logf("PeekPushNotificationName errored for corrupted data %s: %v (DATA: %s)", tc.name, err, tc.data) } else { - t.Logf("PeekPushNotificationName returned '%s' for corrupted data %s", name, tc.name) + t.Logf("PeekPushNotificationName returned '%s' for corrupted data NAME: %s, DATA: %s", name, tc.name, tc.data) } }) } @@ -293,15 +294,27 @@ func TestPeekPushNotificationName(t *testing.T) { func createValidPushNotification(notificationName, data string) *bytes.Buffer { buf := &bytes.Buffer{} + simpleOrString := rand.Intn(2) == 0 + if data == "" { + // Single element notification buf.WriteString(">1\r\n") - buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName)) + if simpleOrString { + buf.WriteString(fmt.Sprintf("+%s\r\n", notificationName)) + } else { + buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName)) + } } else { // Two element notification buf.WriteString(">2\r\n") - buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName)) - buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(data), data)) + if simpleOrString { + buf.WriteString(fmt.Sprintf("+%s\r\n", notificationName)) + buf.WriteString(fmt.Sprintf("+%s\r\n", data)) + } else { + buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(notificationName), notificationName)) + buf.WriteString(fmt.Sprintf("$%d\r\n%s\r\n", len(data), data)) + } } return buf diff --git a/internal/proto/reader.go b/internal/proto/reader.go index fa63f9e29..86bd32d7c 100644 --- a/internal/proto/reader.go +++ b/internal/proto/reader.go @@ -116,26 +116,55 @@ func (r *Reader) PeekPushNotificationName() (string, error) { if buf[0] != RespPush { return "", fmt.Errorf("redis: can't parse push notification: %q", buf) } - // remove push notification type and length - buf = buf[2:] + + if len(buf) < 3 { + return "", fmt.Errorf("redis: can't parse push notification: %q", buf) + } + + // remove push notification type + buf = buf[1:] + // remove first line - e.g. 
>2\r\n for i := 0; i < len(buf)-1; i++ { if buf[i] == '\r' && buf[i+1] == '\n' { buf = buf[i+2:] break + } else { + if buf[i] < '0' || buf[i] > '9' { + return "", fmt.Errorf("redis: can't parse push notification: %q", buf) + } } } + if len(buf) < 2 { + return "", fmt.Errorf("redis: can't parse push notification: %q", buf) + } + // next line should be $\r\n or +\r\n // should have the type of the push notification name and it's length - if buf[0] != RespString { + if buf[0] != RespString && buf[0] != RespStatus { return "", fmt.Errorf("redis: can't parse push notification name: %q", buf) } - // skip the length of the string - for i := 0; i < len(buf)-1; i++ { - if buf[i] == '\r' && buf[i+1] == '\n' { - buf = buf[i+2:] - break + typeOfName := buf[0] + // remove the type of the push notification name + buf = buf[1:] + if typeOfName == RespString { + // remove the length of the string + if len(buf) < 2 { + return "", fmt.Errorf("redis: can't parse push notification name: %q", buf) + } + for i := 0; i < len(buf)-1; i++ { + if buf[i] == '\r' && buf[i+1] == '\n' { + buf = buf[i+2:] + break + } else { + if buf[i] < '0' || buf[i] > '9' { + return "", fmt.Errorf("redis: can't parse push notification name: %q", buf) + } + } } } + if len(buf) < 2 { + return "", fmt.Errorf("redis: can't parse push notification name: %q", buf) + } // keep only the notification name for i := 0; i < len(buf)-1; i++ { if buf[i] == '\r' && buf[i+1] == '\n' { @@ -143,6 +172,7 @@ func (r *Reader) PeekPushNotificationName() (string, error) { break } } + return util.BytesToString(buf), nil } diff --git a/options.go b/options.go index 00568c6c9..61448ac2c 100644 --- a/options.go +++ b/options.go @@ -224,6 +224,18 @@ type Options struct { // PushNotificationProcessor is the processor for handling push notifications. // If nil, a default processor will be created for RESP3 connections. PushNotificationProcessor push.NotificationProcessor + + // HitlessUpgrades enables hitless upgrade functionality for cluster upgrades. + // Requires Protocol: 3 (RESP3) for push notifications. + // When enabled, the client will automatically handle cluster upgrade notifications + // and manage connection/pool state transitions seamlessly. + // + // default: false + HitlessUpgrades bool + + // HitlessUpgradeConfig provides custom configuration for hitless upgrades. + // If nil, default configuration will be used when HitlessUpgrades is true. + HitlessUpgradeConfig *HitlessUpgradeConfig } func (opt *Options) init() { diff --git a/osscluster.go b/osscluster.go index bfcc39fcc..107386226 100644 --- a/osscluster.go +++ b/osscluster.go @@ -110,6 +110,18 @@ type ClusterOptions struct { // UnstableResp3 enables Unstable mode for Redis Search module with RESP3. UnstableResp3 bool + + // HitlessUpgrades enables hitless upgrade functionality for cluster upgrades. + // Requires Protocol: 3 (RESP3) for push notifications. + // When enabled, the client will automatically handle cluster upgrade notifications + // and manage connection/pool state transitions seamlessly. + // + // default: false + HitlessUpgrades bool + + // HitlessUpgradeConfig provides custom configuration for hitless upgrades. + // If nil, default configuration will be used when HitlessUpgrades is true. + HitlessUpgradeConfig *HitlessUpgradeConfig } func (opt *ClusterOptions) init() { @@ -327,8 +339,10 @@ func (opt *ClusterOptions) clientOptions() *Options { // much use for ClusterSlots config). 
This means we cannot execute the // READONLY command against that node -- setting readOnly to false in such // situations in the options below will prevent that from happening. - readOnly: opt.ReadOnly && opt.ClusterSlots == nil, - UnstableResp3: opt.UnstableResp3, + readOnly: opt.ReadOnly && opt.ClusterSlots == nil, + UnstableResp3: opt.UnstableResp3, + HitlessUpgrades: opt.HitlessUpgrades, + HitlessUpgradeConfig: opt.HitlessUpgradeConfig, } } @@ -943,6 +957,9 @@ type ClusterClient struct { cmdsInfoCache *cmdsInfoCache cmdable hooksMixin + + // hitlessIntegration provides hitless upgrade functionality + hitlessIntegration HitlessIntegration } // NewClusterClient returns a Redis Cluster client as described in @@ -969,6 +986,22 @@ func NewClusterClient(opt *ClusterOptions) *ClusterClient { txPipeline: c.processTxPipeline, }) + // Initialize hitless upgrades if enabled + if opt.HitlessUpgrades { + if opt.Protocol != 3 { + internal.Logger.Printf(context.Background(), "hitless: RESP3 protocol required for hitless upgrades, but Protocol is %d", opt.Protocol) + } else { + timeoutProvider := newOptionsTimeoutProvider(opt.ReadTimeout, opt.WriteTimeout) + integration, err := initializeHitlessIntegration(c, opt.HitlessUpgradeConfig, timeoutProvider) + if err != nil { + internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err) + } else { + c.hitlessIntegration = integration + internal.Logger.Printf(context.Background(), "hitless: successfully initialized hitless upgrades for cluster client") + } + } + } + return c } @@ -977,6 +1010,14 @@ func (c *ClusterClient) Options() *ClusterOptions { return c.opt } +// GetHitlessIntegration returns the hitless integration instance for monitoring and control. +// Returns nil if hitless upgrades are not enabled. +func (c *ClusterClient) GetHitlessIntegration() HitlessIntegration { + return c.hitlessIntegration +} + +// getPushNotificationProcessor removed - not needed in simplified implementation + // ReloadState reloads cluster state. If available it calls ClusterSlots func // to get cluster slots information. func (c *ClusterClient) ReloadState(ctx context.Context) { @@ -988,6 +1029,13 @@ func (c *ClusterClient) ReloadState(ctx context.Context) { // It is rare to Close a ClusterClient, as the ClusterClient is meant // to be long-lived and shared between many goroutines. func (c *ClusterClient) Close() error { + // Close hitless integration first + if c.hitlessIntegration != nil { + if err := c.hitlessIntegration.Close(); err != nil { + internal.Logger.Printf(context.Background(), "hitless: error closing hitless integration: %v", err) + } + } + return c.nodes.Close() } diff --git a/push/handler_context.go b/push/handler_context.go index 3bcf128f1..e763f9504 100644 --- a/push/handler_context.go +++ b/push/handler_context.go @@ -13,9 +13,6 @@ type NotificationHandlerContext struct { // circular dependencies. The developer is responsible for type assertion. // It can be one of the following types: // - *redis.baseClient - // - *redis.Client - // - *redis.ClusterClient - // - *redis.Conn Client interface{} // ConnPool is the connection pool from which the connection was obtained. @@ -25,7 +22,7 @@ type NotificationHandlerContext struct { // - *pool.ConnPool // - *pool.SingleConnPool // - *pool.StickyConnPool - ConnPool interface{} + ConnPool pool.Pooler // PubSub is the PubSub instance that received the notification. 
// It is interface to both allow for future expansion and to avoid diff --git a/redis.go b/redis.go index 43673863f..3dbc071f5 100644 --- a/redis.go +++ b/redis.go @@ -2,6 +2,7 @@ package redis import ( "context" + "crypto/tls" "errors" "fmt" "net" @@ -211,6 +212,9 @@ type baseClient struct { // Push notification processing pushProcessor push.NotificationProcessor + + // hitlessIntegration provides hitless upgrade functionality + hitlessIntegration HitlessIntegration } func (c *baseClient) clone() *baseClient { @@ -466,7 +470,16 @@ func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) // Push notification processing errors shouldn't break normal Redis operations internal.Logger.Printf(ctx, "push: error processing pending notifications before releasing connection: %v", err) } - c.connPool.Put(ctx, cn) + + // Check if connection is marked for closing due to hitless upgrades + if c.hitlessIntegration != nil && c.hitlessIntegration.IsConnectionMarkedForClosing(cn) { + // Connection is marked for closing (e.g., during MOVING state) + // Remove it instead of putting it back in the pool + internal.Logger.Printf(ctx, "hitless: closing connection marked for closure during upgrade") + c.connPool.Remove(ctx, cn, nil) + } else { + c.connPool.Put(ctx, cn) + } } } @@ -528,7 +541,33 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool } retryTimeout := uint32(0) - if err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { + + // Check if this is a blocking command that needs redirection + isBlockingCommand := cmd.readTimeout() != nil + var redirectedConn *pool.Conn + var shouldRedirect bool + var newEndpoint string + + if c.hitlessIntegration != nil && isBlockingCommand { + // For blocking commands, check if we need to redirect to a new endpoint + // This happens during MOVING state when the endpoint is changing + shouldRedirect, newEndpoint = c.hitlessIntegration.ShouldRedirectBlockingConnection(nil) + if shouldRedirect { + // Create a new connection to the new endpoint + var err error + redirectedConn, err = c.createConnectionToEndpoint(ctx, newEndpoint) + if err != nil { + internal.Logger.Printf(ctx, "hitless: failed to create redirected connection to %s: %v", newEndpoint, err) + // Fall back to normal connection if redirection fails + shouldRedirect = false + } else { + internal.Logger.Printf(ctx, "hitless: redirecting blocking command %s to new endpoint %s", cmd.Name(), newEndpoint) + } + } + } + + // Use redirected connection if available, otherwise use normal connection + connFunc := func(ctx context.Context, cn *pool.Conn) error { // Process any pending push notifications before executing the command if err := c.processPushNotifications(ctx, cn); err != nil { // Log the error but don't fail the command execution @@ -536,7 +575,19 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool internal.Logger.Printf(ctx, "push: error processing pending notifications before command: %v", err) } - if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { + // Mark connection as blocking if this is a blocking command + if c.hitlessIntegration != nil && isBlockingCommand { + c.hitlessIntegration.MarkConnectionAsBlocking(cn, true) + internal.Logger.Printf(ctx, "hitless: marked connection as blocking for command %s", cmd.Name()) + } + + // Get appropriate write timeout for this connection + writeTimeout := c.opt.WriteTimeout + if c.hitlessIntegration != nil { + _, writeTimeout = 
c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout) + } + + if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error { return writeCmd(wr, cmd) }); err != nil { atomic.StoreUint32(&retryTimeout, 1) @@ -547,7 +598,7 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool if c.opt.Protocol != 2 && c.assertUnstableCommand(cmd) { readReplyFunc = cmd.readRawReply } - if err := cn.WithReader(c.context(ctx), c.cmdTimeout(cmd), func(rd *proto.Reader) error { + if err := cn.WithReader(c.context(ctx), c.cmdTimeoutForConnection(cmd, cn), func(rd *proto.Reader) error { // To be sure there are no buffered push notifications, we process them before reading the reply if err := c.processPendingPushNotificationWithReader(ctx, cn, rd); err != nil { // Log the error but don't fail the command execution @@ -564,8 +615,31 @@ func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool return err } + // Unmark connection as blocking after command completes + if c.hitlessIntegration != nil && isBlockingCommand { + c.hitlessIntegration.MarkConnectionAsBlocking(cn, false) + internal.Logger.Printf(ctx, "hitless: unmarked connection as blocking after command %s completed", cmd.Name()) + } + return nil - }); err != nil { + } + + // Execute the command with either redirected or normal connection + var err error + if shouldRedirect && redirectedConn != nil { + // Use the redirected connection for blocking command + err = connFunc(ctx, redirectedConn) + // Close the redirected connection after use + defer func() { + redirectedConn.Close() + internal.Logger.Printf(ctx, "hitless: closed redirected connection to %s after blocking command completed", newEndpoint) + }() + } else { + // Use normal connection pool + err = c.withConn(ctx, connFunc) + } + + if err != nil { retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1) return retry, err } @@ -588,6 +662,70 @@ func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration { return c.opt.ReadTimeout } +// cmdTimeoutForConnection returns the appropriate read timeout for a specific connection +// taking into account hitless upgrade state +func (c *baseClient) cmdTimeoutForConnection(cmd Cmder, cn *pool.Conn) time.Duration { + baseTimeout := c.cmdTimeout(cmd) + + // If hitless upgrades are enabled, get dynamic timeout based on connection state + if c.hitlessIntegration != nil { + // For blocking commands, use the command's timeout but check if connection needs increased timeout + if cmd.readTimeout() != nil { + // For blocking commands, use the base timeout but apply hitless upgrade adjustments + adjustedTimeout := c.hitlessIntegration.GetConnectionTimeout(cn, baseTimeout) + return adjustedTimeout + } else { + // For regular commands, get both read and write timeouts (use read timeout for command) + readTimeout, _ := c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout) + return readTimeout + } + } + + return baseTimeout +} + +// createConnectionToEndpoint creates a new connection to a specific endpoint +// This is used for redirecting blocking commands during MOVING state +func (c *baseClient) createConnectionToEndpoint(ctx context.Context, endpoint string) (*pool.Conn, error) { + // Parse the endpoint to get host and port + addr := endpoint + if addr == "" { + return nil, fmt.Errorf("empty endpoint provided") + } + + // Create a temporary dialer for the new endpoint + dialer := func(ctx context.Context) (net.Conn, error) { + 
netDialer := &net.Dialer{ + Timeout: c.opt.DialTimeout, + KeepAlive: 30 * time.Second, + } + + if c.opt.TLSConfig == nil { + return netDialer.DialContext(ctx, c.opt.Network, addr) + } + + return tls.DialWithDialer(netDialer, c.opt.Network, addr, c.opt.TLSConfig) + } + + // Create a new connection using the dialer + netConn, err := dialer(ctx) + if err != nil { + return nil, fmt.Errorf("failed to dial new endpoint %s: %w", endpoint, err) + } + + // Wrap in pool.Conn + cn := pool.NewConn(netConn) + + // Initialize the connection (auth, select db, etc.) + if err := c.initConn(ctx, cn); err != nil { + cn.Close() + return nil, fmt.Errorf("failed to initialize connection to %s: %w", endpoint, err) + } + + internal.Logger.Printf(ctx, "hitless: created new connection to endpoint %s for blocking command redirection", endpoint) + return cn, nil +} + // context returns the context for the current connection. // If the context timeout is enabled, it returns the original context. // Otherwise, it returns a new background context. @@ -604,8 +742,19 @@ func (c *baseClient) context(ctx context.Context) context.Context { // long-lived and shared between many goroutines. func (c *baseClient) Close() error { var firstErr error + + // Close hitless integration first + if c.hitlessIntegration != nil { + if err := c.hitlessIntegration.Close(); err != nil { + internal.Logger.Printf(context.Background(), "hitless: error closing hitless integration: %v", err) + if firstErr == nil { + firstErr = err + } + } + } + if c.onClose != nil { - if err := c.onClose(); err != nil { + if err := c.onClose(); err != nil && firstErr == nil { firstErr = err } } @@ -678,14 +827,15 @@ func (c *baseClient) pipelineProcessCmds( internal.Logger.Printf(ctx, "push: error processing pending notifications before pipeline: %v", err) } - if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { + // Use transition-aware timeouts only when hitless upgrades are enabled; fall back to the configured timeouts otherwise + readTimeout, writeTimeout := c.opt.ReadTimeout, c.opt.WriteTimeout + if c.hitlessIntegration != nil { + readTimeout, writeTimeout = c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout) + } + if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error { return writeCmds(wr, cmds) }); err != nil { setCmdsErr(cmds, err) return true, err } - if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { + if err := cn.WithReader(c.context(ctx), readTimeout, func(rd *proto.Reader) error { // read all replies return c.pipelineReadCmds(ctx, cn, rd, cmds) }); err != nil { @@ -725,14 +875,15 @@ func (c *baseClient) txPipelineProcessCmds( internal.Logger.Printf(ctx, "push: error processing pending notifications before transaction: %v", err) } - if err := cn.WithWriter(c.context(ctx), c.opt.WriteTimeout, func(wr *proto.Writer) error { + // Use transition-aware timeouts only when hitless upgrades are enabled; fall back to the configured timeouts otherwise + readTimeout, writeTimeout := c.opt.ReadTimeout, c.opt.WriteTimeout + if c.hitlessIntegration != nil { + readTimeout, writeTimeout = c.hitlessIntegration.GetConnectionTimeouts(cn, c.opt.ReadTimeout, c.opt.WriteTimeout) + } + if err := cn.WithWriter(c.context(ctx), writeTimeout, func(wr *proto.Writer) error { return writeCmds(wr, cmds) }); err != nil { setCmdsErr(cmds, err) return true, err } - if err := cn.WithReader(c.context(ctx), c.opt.ReadTimeout, func(rd *proto.Reader) error { + if err := cn.WithReader(c.context(ctx), readTimeout, func(rd *proto.Reader) error { statusCmd := cmds[0].(*StatusCmd) // Trim multi and exec. 
trimmedCmds := cmds[1 : len(cmds)-1] @@ -837,6 +988,22 @@ func NewClient(opt *Options) *Client { c.connPool = newConnPool(opt, c.dialHook) + // Initialize hitless upgrades if enabled + if opt.HitlessUpgrades { + if opt.Protocol != 3 { + internal.Logger.Printf(context.Background(), "hitless: RESP3 protocol required for hitless upgrades, but Protocol is %d", opt.Protocol) + } else { + timeoutProvider := newOptionsTimeoutProvider(opt.ReadTimeout, opt.WriteTimeout) + integration, err := initializeHitlessIntegration(&c, opt.HitlessUpgradeConfig, timeoutProvider) + if err != nil { + internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err) + } else { + c.hitlessIntegration = integration + internal.Logger.Printf(context.Background(), "hitless: successfully initialized hitless upgrades for client") + } + } + } + return &c } @@ -857,6 +1024,12 @@ func (c *Client) WithTimeout(timeout time.Duration) *Client { return &clone } +// GetHitlessIntegration returns the hitless integration instance for monitoring and control. +// Returns nil if hitless upgrades are not enabled. +func (c *Client) GetHitlessIntegration() HitlessIntegration { + return c.hitlessIntegration +} + func (c *Client) Conn() *Conn { return newConn(c.opt, pool.NewStickyConnPool(c.connPool), &c.hooksMixin) }
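Usage note (not part of the diff): the sketch below shows how a caller could opt in to hitless upgrades once this change lands, assuming the new Options fields and the GetHitlessIntegration accessor are wired up as proposed above. The address is illustrative only, and the nil check matters because the integration is created only for RESP3 clients with HitlessUpgrades set.

package main

import (
	"fmt"

	"github.com/redis/go-redis/v9"
)

func main() {
	// Hitless upgrades require RESP3; with any other protocol NewClient only
	// logs a warning and leaves the integration nil.
	rdb := redis.NewClient(&redis.Options{
		Addr:            "localhost:6379", // illustrative endpoint
		Protocol:        3,
		HitlessUpgrades: true,
		// HitlessUpgradeConfig is left nil, so the default configuration is applied.
	})
	defer rdb.Close()

	// GetHitlessIntegration returns nil when the feature is disabled or
	// initialization failed, so check before use.
	if hi := rdb.GetHitlessIntegration(); hi != nil {
		fmt.Println("hitless upgrades enabled:", hi.IsEnabled())
	}
}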