From 5649ffb3143936924e5f38bf67b847bb45af8c78 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 18 Aug 2025 22:14:06 +0300 Subject: [PATCH 01/21] feat(hitless): Introduce handlers for hitless upgrades This commit includes all the work on hitless upgrades with the addition of: - Pubsub Pool - Examples - Refactor of push - Refactor of pool (using atomics for most things) - Introducing of hooks in pool --- .gitignore | 3 + adapters.go | 149 +++++ async_handoff_integration_test.go | 348 +++++++++++ commands.go | 18 + example/pubsub/go.mod | 12 + example/pubsub/go.sum | 6 + example/pubsub/main.go | 146 +++++ example_instrumentation_test.go | 6 + hitless/README.md | 72 +++ hitless/config.go | 377 ++++++++++++ hitless/config_test.go | 427 +++++++++++++ hitless/errors.go | 76 +++ hitless/example_hooks.go | 63 ++ hitless/hitless_manager.go | 299 ++++++++++ hitless/hitless_manager_test.go | 260 ++++++++ hitless/hooks.go | 48 ++ hitless/notification_handler.go | 247 ++++++++ hitless/pool_hook.go | 477 +++++++++++++++ hitless/pool_hook_test.go | 959 ++++++++++++++++++++++++++++++ hitless/state.go | 24 + internal/interfaces/interfaces.go | 67 +++ internal/pool/bench_test.go | 7 +- internal/pool/buffer_size_test.go | 8 +- internal/pool/conn.go | 469 ++++++++++++++- internal/pool/export_test.go | 2 +- internal/pool/hooks.go | 114 ++++ internal/pool/hooks_test.go | 213 +++++++ internal/pool/pool.go | 385 ++++++++---- internal/pool/pool_single.go | 8 +- internal/pool/pool_sticky.go | 4 + internal/pool/pool_test.go | 39 +- internal/pool/pubsub.go | 77 +++ internal/redis.go | 3 + internal/util/math.go | 17 + options.go | 81 ++- osscluster.go | 42 +- pool_pubsub_bench_test.go | 375 ++++++++++++ pubsub.go | 40 +- push/handler_context.go | 11 +- push/processor_unit_test.go | 315 ++++++++++ push_notifications.go | 18 - redis.go | 211 ++++++- redis_test.go | 1 - sentinel.go | 52 +- tx.go | 2 +- universal.go | 14 +- 46 files changed, 6345 insertions(+), 247 deletions(-) create mode 100644 adapters.go create mode 100644 async_handoff_integration_test.go create mode 100644 example/pubsub/go.mod create mode 100644 example/pubsub/go.sum create mode 100644 example/pubsub/main.go create mode 100644 hitless/README.md create mode 100644 hitless/config.go create mode 100644 hitless/config_test.go create mode 100644 hitless/errors.go create mode 100644 hitless/example_hooks.go create mode 100644 hitless/hitless_manager.go create mode 100644 hitless/hitless_manager_test.go create mode 100644 hitless/hooks.go create mode 100644 hitless/notification_handler.go create mode 100644 hitless/pool_hook.go create mode 100644 hitless/pool_hook_test.go create mode 100644 hitless/state.go create mode 100644 internal/interfaces/interfaces.go create mode 100644 internal/pool/hooks.go create mode 100644 internal/pool/hooks_test.go create mode 100644 internal/pool/pubsub.go create mode 100644 internal/redis.go create mode 100644 internal/util/math.go create mode 100644 pool_pubsub_bench_test.go create mode 100644 push/processor_unit_test.go diff --git a/.gitignore b/.gitignore index 0d99709e34..5fe0716e29 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,6 @@ coverage.txt **/coverage.txt .vscode tmp/* + +# Hitless upgrade documentation (temporary) +hitless/docs/ diff --git a/adapters.go b/adapters.go new file mode 100644 index 0000000000..6f123e212b --- /dev/null +++ b/adapters.go @@ -0,0 +1,149 @@ +package redis + +import ( + "context" + "errors" + "net" + "time" + + "github.com/redis/go-redis/v9/internal/interfaces" + 
"github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" +) + +// ErrInvalidCommand is returned when an invalid command is passed to ExecuteCommand. +var ErrInvalidCommand = errors.New("invalid command type") + +// ErrInvalidPool is returned when the pool type is not supported. +var ErrInvalidPool = errors.New("invalid pool type") + +// newClientAdapter creates a new client adapter for regular Redis clients. +func newClientAdapter(client *baseClient) interfaces.ClientInterface { + return &clientAdapter{client: client} +} + +// clientAdapter adapts a Redis client to implement interfaces.ClientInterface. +type clientAdapter struct { + client *baseClient +} + +// GetOptions returns the client options. +func (ca *clientAdapter) GetOptions() interfaces.OptionsInterface { + return &optionsAdapter{options: ca.client.opt} +} + +// GetPushProcessor returns the client's push notification processor. +func (ca *clientAdapter) GetPushProcessor() interfaces.NotificationProcessor { + return &pushProcessorAdapter{processor: ca.client.pushProcessor} +} + +// optionsAdapter adapts Redis options to implement interfaces.OptionsInterface. +type optionsAdapter struct { + options *Options +} + +// GetReadTimeout returns the read timeout. +func (oa *optionsAdapter) GetReadTimeout() time.Duration { + return oa.options.ReadTimeout +} + +// GetWriteTimeout returns the write timeout. +func (oa *optionsAdapter) GetWriteTimeout() time.Duration { + return oa.options.WriteTimeout +} + +// GetNetwork returns the network type. +func (oa *optionsAdapter) GetNetwork() string { + return oa.options.Network +} + +// GetAddr returns the connection address. +func (oa *optionsAdapter) GetAddr() string { + return oa.options.Addr +} + +// IsTLSEnabled returns true if TLS is enabled. +func (oa *optionsAdapter) IsTLSEnabled() bool { + return oa.options.TLSConfig != nil +} + +// GetProtocol returns the protocol version. +func (oa *optionsAdapter) GetProtocol() int { + return oa.options.Protocol +} + +// GetPoolSize returns the connection pool size. +func (oa *optionsAdapter) GetPoolSize() int { + return oa.options.PoolSize +} + +// NewDialer returns a new dialer function for the connection. +func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) { + baseDialer := oa.options.NewDialer() + return func(ctx context.Context) (net.Conn, error) { + // Extract network and address from the options + network := oa.options.Network + addr := oa.options.Addr + return baseDialer(ctx, network, addr) + } +} + +// connectionAdapter adapts a Redis connection to interfaces.ConnectionWithRelaxedTimeout +type connectionAdapter struct { + conn *pool.Conn +} + +// Close closes the connection. +func (ca *connectionAdapter) Close() error { + return ca.conn.Close() +} + +// IsUsable returns true if the connection is safe to use for new commands. +func (ca *connectionAdapter) IsUsable() bool { + return ca.conn.IsUsable() +} + +// GetPoolConnection returns the underlying pool connection. +func (ca *connectionAdapter) GetPoolConnection() *pool.Conn { + return ca.conn +} + +// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades. +// These timeouts remain active until explicitly cleared. +func (ca *connectionAdapter) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) { + ca.conn.SetRelaxedTimeout(readTimeout, writeTimeout) +} + +// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline. 
+// After the deadline, timeouts automatically revert to normal values. +func (ca *connectionAdapter) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) { + ca.conn.SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout, deadline) +} + +// ClearRelaxedTimeout clears relaxed timeouts for this connection. +func (ca *connectionAdapter) ClearRelaxedTimeout() { + ca.conn.ClearRelaxedTimeout() +} + +// pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor. +type pushProcessorAdapter struct { + processor push.NotificationProcessor +} + +// RegisterHandler registers a handler for a specific push notification name. +func (ppa *pushProcessorAdapter) RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error { + if pushHandler, ok := handler.(push.NotificationHandler); ok { + return ppa.processor.RegisterHandler(pushNotificationName, pushHandler, protected) + } + return errors.New("handler must implement push.NotificationHandler") +} + +// UnregisterHandler removes a handler for a specific push notification name. +func (ppa *pushProcessorAdapter) UnregisterHandler(pushNotificationName string) error { + return ppa.processor.UnregisterHandler(pushNotificationName) +} + +// GetHandler returns the handler for a specific push notification name. +func (ppa *pushProcessorAdapter) GetHandler(pushNotificationName string) interface{} { + return ppa.processor.GetHandler(pushNotificationName) +} diff --git a/async_handoff_integration_test.go b/async_handoff_integration_test.go new file mode 100644 index 0000000000..9cf553dc77 --- /dev/null +++ b/async_handoff_integration_test.go @@ -0,0 +1,348 @@ +package redis + +import ( + "context" + "net" + "sync" + "testing" + "time" + + "github.com/redis/go-redis/v9/hitless" + "github.com/redis/go-redis/v9/internal/pool" +) + +// mockNetConn implements net.Conn for testing +type mockNetConn struct { + addr string +} + +func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil } +func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil } +func (m *mockNetConn) Close() error { return nil } +func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} } +func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} } +func (m *mockNetConn) SetDeadline(t time.Time) error { return nil } +func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil } + +type mockAddr struct { + addr string +} + +func (m *mockAddr) Network() string { return "tcp" } +func (m *mockAddr) String() string { return m.addr } + +// TestEventDrivenHandoffIntegration tests the complete event-driven handoff flow +func TestEventDrivenHandoffIntegration(t *testing.T) { + t.Run("EventDrivenHandoffWithPoolSkipping", func(t *testing.T) { + // Create a base dialer for testing + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + // Create processor with event-driven handoff support + processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil) + defer processor.Shutdown(context.Background()) + + // Create a test pool with hooks + hookManager := pool.NewPoolHookManager() + hookManager.AddHook(processor) + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &mockNetConn{addr: "original:6379"}, nil + }, + PoolSize: int32(5), + 
PoolTimeout: time.Second, + }) + + // Add the hook to the pool after creation + testPool.AddPoolHook(processor) + defer testPool.Close() + + // Set the pool reference in the processor for connection removal on handoff failure + processor.SetPool(testPool) + + ctx := context.Background() + + // Get a connection and mark it for handoff + conn, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get connection: %v", err) + } + + // Set initialization function with a small delay to ensure handoff is pending + initConnCalled := false + initConnFunc := func(ctx context.Context, cn *pool.Conn) error { + time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending + initConnCalled = true + return nil + } + conn.SetInitConnFunc(initConnFunc) + + // Mark connection for handoff + err = conn.MarkForHandoff("new-endpoint:6379", 12345) + if err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Return connection to pool - this should queue handoff + testPool.Put(ctx, conn) + + // Give the on-demand worker a moment to start processing + time.Sleep(10 * time.Millisecond) + + // Verify handoff was queued + if !processor.IsHandoffPending(conn) { + t.Error("Handoff should be queued in pending map") + } + + // Try to get the same connection - should be skipped due to pending handoff + conn2, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get second connection: %v", err) + } + + // Should get a different connection (the pending one should be skipped) + if conn == conn2 { + t.Error("Should have gotten a different connection while handoff is pending") + } + + // Return the second connection + testPool.Put(ctx, conn2) + + // Wait for handoff to complete + time.Sleep(200 * time.Millisecond) + + // Verify handoff completed (removed from pending map) + if processor.IsHandoffPending(conn) { + t.Error("Handoff should have completed and been removed from pending map") + } + + if !initConnCalled { + t.Error("InitConn should have been called during handoff") + } + + // Now the original connection should be available again + conn3, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get third connection: %v", err) + } + + // Could be the original connection (now handed off) or a new one + testPool.Put(ctx, conn3) + }) + + t.Run("ConcurrentHandoffs", func(t *testing.T) { + // Create a base dialer that simulates slow handoffs + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + time.Sleep(50 * time.Millisecond) // Simulate network delay + return &mockNetConn{addr: addr}, nil + } + + processor := hitless.NewPoolHook(baseDialer, "tcp", nil, nil) + defer processor.Shutdown(context.Background()) + + // Create hooks manager and add processor as hook + hookManager := pool.NewPoolHookManager() + hookManager.AddHook(processor) + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &mockNetConn{addr: "original:6379"}, nil + }, + + PoolSize: int32(10), + PoolTimeout: time.Second, + }) + defer testPool.Close() + + // Add the hook to the pool after creation + testPool.AddPoolHook(processor) + + // Set the pool reference in the processor + processor.SetPool(testPool) + + ctx := context.Background() + var wg sync.WaitGroup + + // Start multiple concurrent handoffs + for i := 0; i < 5; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + + // Get connection + conn, err := testPool.Get(ctx) + if err != nil { + t.Errorf("Failed to get connection %d: %v", id, err) 
+ return + } + + // Set initialization function + initConnFunc := func(ctx context.Context, cn *pool.Conn) error { + return nil + } + conn.SetInitConnFunc(initConnFunc) + + // Mark for handoff + conn.MarkForHandoff("new-endpoint:6379", int64(id)) + + // Return to pool (starts async handoff) + testPool.Put(ctx, conn) + }(i) + } + + wg.Wait() + + // Wait for all handoffs to complete + time.Sleep(300 * time.Millisecond) + + // Verify pool is still functional + conn, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Pool should still be functional after concurrent handoffs: %v", err) + } + testPool.Put(ctx, conn) + }) + + t.Run("HandoffFailureRecovery", func(t *testing.T) { + // Create a failing base dialer + failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, &net.OpError{Op: "dial", Err: &net.DNSError{Name: addr}} + } + + processor := hitless.NewPoolHook(failingDialer, "tcp", nil, nil) + defer processor.Shutdown(context.Background()) + + // Create hooks manager and add processor as hook + hookManager := pool.NewPoolHookManager() + hookManager.AddHook(processor) + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &mockNetConn{addr: "original:6379"}, nil + }, + + PoolSize: int32(3), + PoolTimeout: time.Second, + }) + defer testPool.Close() + + // Add the hook to the pool after creation + testPool.AddPoolHook(processor) + + // Set the pool reference in the processor + processor.SetPool(testPool) + + ctx := context.Background() + + // Get connection and mark for handoff + conn, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get connection: %v", err) + } + + conn.MarkForHandoff("unreachable-endpoint:6379", 12345) + + // Return to pool (starts async handoff that will fail) + testPool.Put(ctx, conn) + + // Wait for handoff to fail + time.Sleep(200 * time.Millisecond) + + // Connection should be removed from pending map after failed handoff + if processor.IsHandoffPending(conn) { + t.Error("Connection should be removed from pending map after failed handoff") + } + + // Pool should still be functional + conn2, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Pool should still be functional: %v", err) + } + + // In event-driven approach, the original connection remains in pool + // even after failed handoff (it's still a valid connection) + // We might get the same connection or a different one + testPool.Put(ctx, conn2) + }) + + t.Run("GracefulShutdown", func(t *testing.T) { + // Create a slow base dialer + slowDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + time.Sleep(100 * time.Millisecond) + return &mockNetConn{addr: addr}, nil + } + + processor := hitless.NewPoolHook(slowDialer, "tcp", nil, nil) + defer processor.Shutdown(context.Background()) + + // Create hooks manager and add processor as hook + hookManager := pool.NewPoolHookManager() + hookManager.AddHook(processor) + + testPool := pool.NewConnPool(&pool.Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &mockNetConn{addr: "original:6379"}, nil + }, + + PoolSize: int32(2), + PoolTimeout: time.Second, + }) + defer testPool.Close() + + // Add the hook to the pool after creation + testPool.AddPoolHook(processor) + + // Set the pool reference in the processor + processor.SetPool(testPool) + + ctx := context.Background() + + // Start a handoff + conn, err := testPool.Get(ctx) + if err != nil { + t.Fatalf("Failed to get connection: %v", err) + } + + if err := 
conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a mock initialization function with delay to ensure handoff is pending + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + time.Sleep(50 * time.Millisecond) // Add delay to keep handoff pending + return nil + }) + + testPool.Put(ctx, conn) + + // Give the on-demand worker a moment to start and begin processing + // The handoff should be pending because the slowDialer takes 100ms + time.Sleep(10 * time.Millisecond) + + // Verify handoff was queued and is being processed + if !processor.IsHandoffPending(conn) { + t.Error("Handoff should be queued in pending map") + } + + // Give the handoff a moment to start processing + time.Sleep(50 * time.Millisecond) + + // Shutdown processor gracefully + // Use a longer timeout to account for slow dialer (100ms) plus processing overhead + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err = processor.Shutdown(shutdownCtx) + if err != nil { + t.Errorf("Graceful shutdown should succeed: %v", err) + } + + // Handoff should have completed (removed from pending map) + if processor.IsHandoffPending(conn) { + t.Error("Handoff should have completed and been removed from pending map after shutdown") + } + }) +} diff --git a/commands.go b/commands.go index c0358001d1..3a1cfdef79 100644 --- a/commands.go +++ b/commands.go @@ -193,6 +193,7 @@ type Cmdable interface { ClientID(ctx context.Context) *IntCmd ClientUnblock(ctx context.Context, id int64) *IntCmd ClientUnblockWithError(ctx context.Context, id int64) *IntCmd + ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd ConfigResetStat(ctx context.Context) *StatusCmd ConfigSet(ctx context.Context, parameter, value string) *StatusCmd @@ -518,6 +519,23 @@ func (c cmdable) ClientInfo(ctx context.Context) *ClientInfoCmd { return cmd } +// ClientMaintNotifications enables or disables maintenance notifications for hitless upgrades. +// When enabled, the client will receive push notifications about Redis maintenance events. +func (c cmdable) ClientMaintNotifications(ctx context.Context, enabled bool, endpointType string) *StatusCmd { + args := []interface{}{"client", "maint_notifications"} + if enabled { + if endpointType == "" { + endpointType = "none" + } + args = append(args, "on", "moving-endpoint-type", endpointType) + } else { + args = append(args, "off") + } + cmd := NewStatusCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + // ------------------------------------------------------------------------------------------------ func (c cmdable) ConfigGet(ctx context.Context, parameter string) *MapStringStringCmd { diff --git a/example/pubsub/go.mod b/example/pubsub/go.mod new file mode 100644 index 0000000000..731a92839d --- /dev/null +++ b/example/pubsub/go.mod @@ -0,0 +1,12 @@ +module github.com/redis/go-redis/example/pubsub + +go 1.18 + +replace github.com/redis/go-redis/v9 => ../.. 
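+// The replace directive above points the example at the local checkout so it builds against the code in this repository.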
+ +require github.com/redis/go-redis/v9 v9.11.0 + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect +) diff --git a/example/pubsub/go.sum b/example/pubsub/go.sum new file mode 100644 index 0000000000..d64ea0303f --- /dev/null +++ b/example/pubsub/go.sum @@ -0,0 +1,6 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= diff --git a/example/pubsub/main.go b/example/pubsub/main.go new file mode 100644 index 0000000000..ddc0604d0e --- /dev/null +++ b/example/pubsub/main.go @@ -0,0 +1,146 @@ +package main + +import ( + "context" + "fmt" + "log" + "sync" + "time" + + "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/hitless" +) + +var ctx = context.Background() + +// This example is not supposed to be run as is. It is just a test to see how pubsub behaves in relation to pool management. +// It was used to find regressions in pool management in hitless mode. +// Please don't use it as a reference for how to use pubsub. +func main() { + wg := &sync.WaitGroup{} + rdb := redis.NewClient(&redis.Options{ + Addr: ":6379", + HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{ + Mode: hitless.MaintNotificationsEnabled, + }, + }) + _ = rdb.FlushDB(ctx).Err() + + go func() { + for { + time.Sleep(2 * time.Second) + fmt.Printf("pool stats: %+v\n", rdb.PoolStats()) + } + }() + err := rdb.Ping(ctx).Err() + if err != nil { + panic(err) + } + if err := rdb.Set(ctx, "publishers", "0", 0).Err(); err != nil { + panic(err) + } + if err := rdb.Set(ctx, "subscribers", "0", 0).Err(); err != nil { + panic(err) + } + if err := rdb.Set(ctx, "published", "0", 0).Err(); err != nil { + panic(err) + } + if err := rdb.Set(ctx, "received", "0", 0).Err(); err != nil { + panic(err) + } + fmt.Println("published", rdb.Get(ctx, "published").Val()) + fmt.Println("received", rdb.Get(ctx, "received").Val()) + subCtx, cancelSubCtx := context.WithCancel(ctx) + pubCtx, cancelPublishers := context.WithCancel(ctx) + for i := 0; i < 10; i++ { + wg.Add(1) + go subscribe(subCtx, rdb, "test", i, wg) + } + time.Sleep(time.Second) + cancelSubCtx() + time.Sleep(time.Second) + subCtx, cancelSubCtx = context.WithCancel(ctx) + for i := 0; i < 10; i++ { + if err := rdb.Incr(ctx, "publishers").Err(); err != nil { + panic(err) + } + wg.Add(1) + go floodThePool(pubCtx, rdb, wg) + } + + for i := 0; i < 500; i++ { + if err := rdb.Incr(ctx, "subscribers").Err(); err != nil { + panic(err) + } + wg.Add(1) + go subscribe(subCtx, rdb, "test2", i, wg) + } + time.Sleep(5 * time.Second) + fmt.Println("canceling publishers") + cancelPublishers() + time.Sleep(10 * time.Second) + fmt.Println("canceling subscribers") + cancelSubCtx() + wg.Wait() + published, err := rdb.Get(ctx, "published").Result() + received, err := rdb.Get(ctx, "received").Result() + publishers, err := rdb.Get(ctx, "publishers").Result() + subscribers, err := rdb.Get(ctx, "subscribers").Result() + fmt.Printf("publishers: %s\n", publishers) + fmt.Printf("published: %s\n", 
published) + fmt.Printf("subscribers: %s\n", subscribers) + fmt.Printf("received: %s\n", received) + publishedInt, err := rdb.Get(ctx, "published").Int() + subscribersInt, err := rdb.Get(ctx, "subscribers").Int() + fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt) + + time.Sleep(2 * time.Second) +} + +func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) { + defer wg.Done() + for { + select { + case <-ctx.Done(): + return + default: + } + err := rdb.Publish(ctx, "test2", "hello").Err() + if err != nil { + // noop + //log.Println("publish error:", err) + } + + err = rdb.Incr(ctx, "published").Err() + if err != nil { + // noop + //log.Println("incr error:", err) + } + time.Sleep(10 * time.Nanosecond) + } +} + +func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberId int, wg *sync.WaitGroup) { + defer wg.Done() + rec := rdb.Subscribe(ctx, topic) + recChan := rec.Channel() + for { + select { + case <-ctx.Done(): + rec.Close() + return + default: + select { + case <-ctx.Done(): + rec.Close() + return + case msg := <-recChan: + err := rdb.Incr(ctx, "received").Err() + if err != nil { + log.Println("incr error:", err) + } + _ = msg // Use the message to avoid unused variable warning + } + } + } +} diff --git a/example_instrumentation_test.go b/example_instrumentation_test.go index 36234ff09e..73248e4c53 100644 --- a/example_instrumentation_test.go +++ b/example_instrumentation_test.go @@ -57,6 +57,8 @@ func Example_instrumentation() { // finished dialing tcp :6379 // starting processing: <[hello 3]> // finished processing: <[hello 3]> + // starting processing: <[client maint_notifications on moving-endpoint-type external-ip]> + // finished processing: <[client maint_notifications on moving-endpoint-type external-ip]> // finished processing: <[ping]> } @@ -78,6 +80,8 @@ func ExamplePipeline_instrumentation() { // finished dialing tcp :6379 // starting processing: <[hello 3]> // finished processing: <[hello 3]> + // starting processing: <[client maint_notifications on moving-endpoint-type external-ip]> + // finished processing: <[client maint_notifications on moving-endpoint-type external-ip]> // pipeline finished processing: [[ping] [ping]] } @@ -99,6 +103,8 @@ func ExampleClient_Watch_instrumentation() { // finished dialing tcp :6379 // starting processing: <[hello 3]> // finished processing: <[hello 3]> + // starting processing: <[client maint_notifications on moving-endpoint-type external-ip]> + // finished processing: <[client maint_notifications on moving-endpoint-type external-ip]> // finished processing: <[watch foo]> // starting processing: <[ping]> // finished processing: <[ping]> diff --git a/hitless/README.md b/hitless/README.md new file mode 100644 index 0000000000..b82b33a3d2 --- /dev/null +++ b/hitless/README.md @@ -0,0 +1,72 @@ +# Hitless Upgrades + +Seamless Redis connection handoffs during topology changes without interrupting operations. 
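+
+Hitless upgrades rely on RESP3 push notifications. When enabled, the client registers for
+maintenance events by sending `CLIENT MAINT_NOTIFICATIONS` during connection initialization.
+As a minimal sketch (assuming an existing `*redis.Client` named `rdb` and a `ctx`), a manual
+call with the same effect looks roughly like this:
+
+```go
+// Roughly what the client sends on each new connection when hitless upgrades are on:
+//   CLIENT MAINT_NOTIFICATIONS ON MOVING-ENDPOINT-TYPE external-ip
+status := rdb.ClientMaintNotifications(ctx, true, "external-ip")
+if err := status.Err(); err != nil {
+	// In MaintNotificationsAuto mode the client disables the feature on error;
+	// in MaintNotificationsEnabled mode the connection is interrupted.
+}
+```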
+ +## Quick Start + +```go +import "github.com/redis/go-redis/v9/hitless" + +opt := &redis.Options{ + Addr: "localhost:6379", + Protocol: 3, // RESP3 required + HitlessUpgrades: &redis.HitlessUpgradeConfig{ + Mode: hitless.MaintNotificationsEnabled, // or MaintNotificationsAuto + }, +} +client := redis.NewClient(opt) +``` + +## Modes + +- **`MaintNotificationsDisabled`**: Hitless upgrades are completely disabled +- **`MaintNotificationsEnabled`**: Hitless upgrades are forcefully enabled (fails if server doesn't support it) +- **`MaintNotificationsAuto`**: Hitless upgrades are enabled if server supports it (default) + +## Configuration + +```go +import "github.com/redis/go-redis/v9/hitless" + +Config: &hitless.Config{ + Mode: hitless.MaintNotificationsAuto, // Notification mode + MaxHandoffRetries: 3, // Retry failed handoffs + HandoffTimeout: 15 * time.Second, // Handoff operation timeout + RelaxedTimeout: 10 * time.Second, // Extended timeout during migrations + PostHandoffRelaxedDuration: 20 * time.Second, // Keep relaxed timeout after handoff + LogLevel: 1, // 0=errors, 1=warnings, 2=info, 3=debug + MaxWorkers: 15, // Concurrent handoff workers + HandoffQueueSize: 50, // Handoff request queue size +} +``` + +### Worker Scaling +- **Auto-calculated**: `min(10, PoolSize/3)` - scales with pool size, capped at 10 +- **Explicit values**: `max(10, set_value)` - enforces minimum 10 workers +- **On-demand**: Workers created when needed, cleaned up when idle + +### Queue Sizing +- **Auto-calculated**: `10 × MaxWorkers`, capped by pool size +- **Always capped**: Queue size never exceeds pool size + +## Metrics Hook Example + +A metrics collection hook is available in `example_hooks.go` that demonstrates how to monitor hitless upgrade operations: + +```go +import "github.com/redis/go-redis/v9/hitless" + +metricsHook := hitless.NewMetricsHook() +// Use with your monitoring system +``` + +The metrics hook tracks: +- Handoff success/failure rates +- Handoff duration +- Queue depth +- Worker utilization +- Connection lifecycle events + +## Requirements + +- **RESP3 Protocol**: Required for push notifications diff --git a/hitless/config.go b/hitless/config.go new file mode 100644 index 0000000000..b35a0d7185 --- /dev/null +++ b/hitless/config.go @@ -0,0 +1,377 @@ +package hitless + +import ( + "net" + "runtime" + "time" + + "github.com/redis/go-redis/v9/internal/util" +) + +// MaintNotificationsMode represents the maintenance notifications mode +type MaintNotificationsMode string + +// Constants for maintenance push notifications modes +const ( + MaintNotificationsDisabled MaintNotificationsMode = "disabled" // Client doesn't send CLIENT MAINT_NOTIFICATIONS ON command + MaintNotificationsEnabled MaintNotificationsMode = "enabled" // Client forcefully sends command, interrupts connection on error + MaintNotificationsAuto MaintNotificationsMode = "auto" // Client tries to send command, disables feature on error +) + +// IsValid returns true if the maintenance notifications mode is valid +func (m MaintNotificationsMode) IsValid() bool { + switch m { + case MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto: + return true + default: + return false + } +} + +// String returns the string representation of the mode +func (m MaintNotificationsMode) String() string { + return string(m) +} + +// EndpointType represents the type of endpoint to request in MOVING notifications +type EndpointType string + +// Constants for endpoint types +const ( + EndpointTypeAuto EndpointType = "auto" // 
Auto-detect based on connection + EndpointTypeInternalIP EndpointType = "internal-ip" // Internal IP address + EndpointTypeInternalFQDN EndpointType = "internal-fqdn" // Internal FQDN + EndpointTypeExternalIP EndpointType = "external-ip" // External IP address + EndpointTypeExternalFQDN EndpointType = "external-fqdn" // External FQDN + EndpointTypeNone EndpointType = "none" // No endpoint (reconnect with current config) +) + +// IsValid returns true if the endpoint type is valid +func (e EndpointType) IsValid() bool { + switch e { + case EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN, + EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone: + return true + default: + return false + } +} + +// String returns the string representation of the endpoint type +func (e EndpointType) String() string { + return string(e) +} + +// Config provides configuration options for hitless upgrades. +type Config struct { + // Mode controls how client maintenance notifications are handled. + // Valid values: MaintNotificationsDisabled, MaintNotificationsEnabled, MaintNotificationsAuto + // Default: MaintNotificationsAuto + Mode MaintNotificationsMode + + // EndpointType specifies the type of endpoint to request in MOVING notifications. + // Valid values: EndpointTypeAuto, EndpointTypeInternalIP, EndpointTypeInternalFQDN, + // EndpointTypeExternalIP, EndpointTypeExternalFQDN, EndpointTypeNone + // Default: EndpointTypeAuto + EndpointType EndpointType + + // RelaxedTimeout is the concrete timeout value to use during + // MIGRATING/FAILING_OVER states to accommodate increased latency. + // This applies to both read and write timeouts. + // Default: 10 seconds + RelaxedTimeout time.Duration + + // HandoffTimeout is the maximum time to wait for connection handoff to complete. + // If handoff takes longer than this, the old connection will be forcibly closed. + // Default: 15 seconds (matches server-side eviction timeout) + HandoffTimeout time.Duration + + // MaxWorkers is the maximum number of worker goroutines for processing handoff requests. + // Workers are created on-demand and automatically cleaned up when idle. + // If zero, defaults to min(10, PoolSize/3) to handle bursts effectively. + // If explicitly set, enforces minimum of 10 workers. + // + // Default: min(10, PoolSize/3), Minimum when set: 10 + MaxWorkers int + + // HandoffQueueSize is the size of the buffered channel used to queue handoff requests. + // If the queue is full, new handoff requests will be rejected. + // Always capped by pool size since you can't handoff more connections than exist. + // + // Default: 10x max workers, capped by pool size, min 2 + HandoffQueueSize int + + // PostHandoffRelaxedDuration is how long to keep relaxed timeouts on the new connection + // after a handoff completes. This provides additional resilience during cluster transitions. + // Default: 2 * RelaxedTimeout + PostHandoffRelaxedDuration time.Duration + + // ScaleDownDelay is the delay before checking if workers should be scaled down. + // This prevents expensive checks on every handoff completion and avoids rapid scaling cycles. + // Default: 2 seconds + ScaleDownDelay time.Duration + + // LogLevel controls the verbosity of hitless upgrade logging. + // 0 = errors only, 1 = warnings, 2 = info, 3 = debug + // Default: 1 (warnings) + LogLevel int + + // MaxHandoffRetries is the maximum number of times to retry a failed handoff. + // After this many retries, the connection will be removed from the pool. 
+ // Default: 3 + MaxHandoffRetries int +} + +func (c *Config) IsEnabled() bool { + return c != nil && c.Mode != MaintNotificationsDisabled +} + +// DefaultConfig returns a Config with sensible defaults. +func DefaultConfig() *Config { + return &Config{ + Mode: MaintNotificationsAuto, // Enable by default for Redis Cloud + EndpointType: EndpointTypeAuto, // Auto-detect based on connection + RelaxedTimeout: 10 * time.Second, + HandoffTimeout: 15 * time.Second, + MaxWorkers: 0, // Auto-calculated based on pool size + HandoffQueueSize: 0, // Auto-calculated based on max workers + PostHandoffRelaxedDuration: 0, // Auto-calculated based on relaxed timeout + ScaleDownDelay: 2 * time.Second, + LogLevel: 1, + + // Connection Handoff Configuration + MaxHandoffRetries: 3, + } +} + +// Validate checks if the configuration is valid. +func (c *Config) Validate() error { + if c.RelaxedTimeout <= 0 { + return ErrInvalidRelaxedTimeout + } + if c.HandoffTimeout <= 0 { + return ErrInvalidHandoffTimeout + } + // Validate worker configuration + // Allow 0 for auto-calculation, but negative values are invalid + if c.MaxWorkers < 0 { + return ErrInvalidHandoffWorkers + } + // HandoffQueueSize validation - allow 0 for auto-calculation + if c.HandoffQueueSize < 0 { + return ErrInvalidHandoffQueueSize + } + if c.PostHandoffRelaxedDuration < 0 { + return ErrInvalidPostHandoffRelaxedDuration + } + if c.LogLevel < 0 || c.LogLevel > 3 { + return ErrInvalidLogLevel + } + + // Validate Mode (maintenance notifications mode) + if !c.Mode.IsValid() { + return ErrInvalidMaintNotifications + } + + // Validate EndpointType + if !c.EndpointType.IsValid() { + return ErrInvalidEndpointType + } + + // Validate configuration fields + if c.MaxHandoffRetries < 1 || c.MaxHandoffRetries > 10 { + return ErrInvalidHandoffRetries + } + + + + return nil +} + +// ApplyDefaults applies default values to any zero-value fields in the configuration. +// This ensures that partially configured structs get sensible defaults for missing fields. +func (c *Config) ApplyDefaults() *Config { + return c.ApplyDefaultsWithPoolSize(0) +} + +// ApplyDefaultsWithPoolSize applies default values to any zero-value fields in the configuration, +// using the provided pool size to calculate worker defaults. +// This ensures that partially configured structs get sensible defaults for missing fields. 
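+// Worker and queue sizing follow the rules documented on the fields above: MaxWorkers defaults to
+// min(10, poolSize/3) when unset and is raised to at least 10 when set explicitly, while
+// HandoffQueueSize defaults to 10x MaxWorkers and is always capped by the pool size (minimum 2).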
+func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config { + if c == nil { + return DefaultConfig().ApplyDefaultsWithPoolSize(poolSize) + } + + defaults := DefaultConfig() + result := &Config{} + + // Apply defaults for enum fields (empty/zero means not set) + if c.Mode == "" { + result.Mode = defaults.Mode + } else { + result.Mode = c.Mode + } + + if c.EndpointType == "" { + result.EndpointType = defaults.EndpointType + } else { + result.EndpointType = c.EndpointType + } + + // Apply defaults for duration fields (zero means not set) + if c.RelaxedTimeout <= 0 { + result.RelaxedTimeout = defaults.RelaxedTimeout + } else { + result.RelaxedTimeout = c.RelaxedTimeout + } + + if c.HandoffTimeout <= 0 { + result.HandoffTimeout = defaults.HandoffTimeout + } else { + result.HandoffTimeout = c.HandoffTimeout + } + + // Apply defaults for integer fields (zero means not set) + if c.HandoffQueueSize <= 0 { + result.HandoffQueueSize = defaults.HandoffQueueSize + } else { + result.HandoffQueueSize = c.HandoffQueueSize + } + + // Copy worker configuration + result.MaxWorkers = c.MaxWorkers + + // Apply worker defaults based on pool size + result.applyWorkerDefaults(poolSize) + + // Apply queue size defaults based on max workers, capped by pool size + if c.HandoffQueueSize <= 0 { + // Queue size: 10x max workers, but never more than pool size + workerBasedSize := result.MaxWorkers * 10 + result.HandoffQueueSize = util.Min(workerBasedSize, poolSize) + } else { + result.HandoffQueueSize = c.HandoffQueueSize + } + + // Always cap queue size by pool size - no point having more queue slots than connections + result.HandoffQueueSize = util.Min(result.HandoffQueueSize, poolSize) + + // Ensure minimum queue size of 2 + if result.HandoffQueueSize < 2 { + result.HandoffQueueSize = 2 + } + + if c.PostHandoffRelaxedDuration <= 0 { + result.PostHandoffRelaxedDuration = result.RelaxedTimeout * 2 + } else { + result.PostHandoffRelaxedDuration = c.PostHandoffRelaxedDuration + } + + if c.ScaleDownDelay <= 0 { + result.ScaleDownDelay = defaults.ScaleDownDelay + } else { + result.ScaleDownDelay = c.ScaleDownDelay + } + + // LogLevel: 0 is a valid value (errors only), so we need to check if it was explicitly set + // We'll use the provided value as-is, since 0 is valid + result.LogLevel = c.LogLevel + + // Apply defaults for configuration fields + if c.MaxHandoffRetries <= 0 { + result.MaxHandoffRetries = defaults.MaxHandoffRetries + } else { + result.MaxHandoffRetries = c.MaxHandoffRetries + } + + + + return result +} + +// Clone creates a deep copy of the configuration. 
+func (c *Config) Clone() *Config { + if c == nil { + return DefaultConfig() + } + + return &Config{ + Mode: c.Mode, + EndpointType: c.EndpointType, + RelaxedTimeout: c.RelaxedTimeout, + HandoffTimeout: c.HandoffTimeout, + MaxWorkers: c.MaxWorkers, + HandoffQueueSize: c.HandoffQueueSize, + PostHandoffRelaxedDuration: c.PostHandoffRelaxedDuration, + ScaleDownDelay: c.ScaleDownDelay, + LogLevel: c.LogLevel, + + // Configuration fields + MaxHandoffRetries: c.MaxHandoffRetries, + } +} + +// applyWorkerDefaults calculates and applies worker defaults based on pool size +func (c *Config) applyWorkerDefaults(poolSize int) { + // Calculate defaults based on pool size + if poolSize <= 0 { + poolSize = 10 * runtime.GOMAXPROCS(0) + } + + if c.MaxWorkers == 0 { + // When not set: min(10, poolSize/3) - don't exceed 10 workers for small pools + c.MaxWorkers = util.Min(10, poolSize/3) + } else { + // When explicitly set: max(10, set_value) - ensure at least 10 workers + c.MaxWorkers = util.Max(10, c.MaxWorkers) + } + + // Ensure minimum of 1 worker (fallback for very small pools) + if c.MaxWorkers < 1 { + c.MaxWorkers = 1 + } +} + +// DetectEndpointType automatically detects the appropriate endpoint type +// based on the connection address and TLS configuration. +func DetectEndpointType(addr string, tlsEnabled bool) EndpointType { + // Parse the address to determine if it's an IP or hostname + isPrivate := isPrivateIP(addr) + + var endpointType EndpointType + + if tlsEnabled { + // TLS requires FQDN for certificate validation + if isPrivate { + endpointType = EndpointTypeInternalFQDN + } else { + endpointType = EndpointTypeExternalFQDN + } + } else { + // No TLS, can use IP addresses + if isPrivate { + endpointType = EndpointTypeInternalIP + } else { + endpointType = EndpointTypeExternalIP + } + } + + return endpointType +} + +// isPrivateIP checks if the given address is in a private IP range. 
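+// Loopback and link-local unicast addresses are treated as private; hostnames return false.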
+func isPrivateIP(addr string) bool { + // Extract host from "host:port" format + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr // Assume no port + } + + ip := net.ParseIP(host) + if ip == nil { + return false // Not an IP address (likely hostname) + } + + // Check for private/loopback ranges + return ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() +} diff --git a/hitless/config_test.go b/hitless/config_test.go new file mode 100644 index 0000000000..7a032bc294 --- /dev/null +++ b/hitless/config_test.go @@ -0,0 +1,427 @@ +package hitless + +import ( + "context" + "net" + "testing" + "time" + + "github.com/redis/go-redis/v9/internal/util" +) + +func TestConfig(t *testing.T) { + t.Run("DefaultConfig", func(t *testing.T) { + config := DefaultConfig() + + // MaxWorkers should be 0 in default config (auto-calculated) + if config.MaxWorkers != 0 { + t.Errorf("Expected MaxWorkers to be 0 (auto-calculated), got %d", config.MaxWorkers) + } + + // HandoffQueueSize should be 0 in default config (auto-calculated) + if config.HandoffQueueSize != 0 { + t.Errorf("Expected HandoffQueueSize to be 0 (auto-calculated), got %d", config.HandoffQueueSize) + } + + if config.RelaxedTimeout != 10*time.Second { + t.Errorf("Expected RelaxedTimeout to be 10s, got %v", config.RelaxedTimeout) + } + + // Test configuration fields have proper defaults + if config.MaxHandoffRetries != 3 { + t.Errorf("Expected MaxHandoffRetries to be 3, got %d", config.MaxHandoffRetries) + } + + if config.HandoffTimeout != 15*time.Second { + t.Errorf("Expected HandoffTimeout to be 15s, got %v", config.HandoffTimeout) + } + + if config.PostHandoffRelaxedDuration != 0 { + t.Errorf("Expected PostHandoffRelaxedDuration to be 0 (auto-calculated), got %v", config.PostHandoffRelaxedDuration) + } + + // Test that defaults are applied correctly + configWithDefaults := config.ApplyDefaultsWithPoolSize(100) + if configWithDefaults.PostHandoffRelaxedDuration != 20*time.Second { + t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout) after applying defaults, got %v", configWithDefaults.PostHandoffRelaxedDuration) + } + }) + + t.Run("ConfigValidation", func(t *testing.T) { + // Valid config with applied defaults + config := DefaultConfig().ApplyDefaults() + if err := config.Validate(); err != nil { + t.Errorf("Default config with applied defaults should be valid: %v", err) + } + + // Invalid worker configuration (negative MaxWorkers) + config = &Config{ + RelaxedTimeout: 30 * time.Second, + HandoffTimeout: 15 * time.Second, + MaxWorkers: -1, // This should be invalid + HandoffQueueSize: 100, + PostHandoffRelaxedDuration: 10 * time.Second, + LogLevel: 1, + MaxHandoffRetries: 3, // Add required field + } + if err := config.Validate(); err != ErrInvalidHandoffWorkers { + t.Errorf("Expected ErrInvalidHandoffWorkers, got %v", err) + } + + // Invalid HandoffQueueSize + config = DefaultConfig().ApplyDefaults() + config.HandoffQueueSize = -1 + if err := config.Validate(); err != ErrInvalidHandoffQueueSize { + t.Errorf("Expected ErrInvalidHandoffQueueSize, got %v", err) + } + + // Invalid PostHandoffRelaxedDuration + config = DefaultConfig().ApplyDefaults() + config.PostHandoffRelaxedDuration = -1 * time.Second + if err := config.Validate(); err != ErrInvalidPostHandoffRelaxedDuration { + t.Errorf("Expected ErrInvalidPostHandoffRelaxedDuration, got %v", err) + } + }) + + t.Run("ConfigClone", func(t *testing.T) { + original := DefaultConfig() + original.MaxWorkers = 20 + original.HandoffQueueSize = 200 + 
+ cloned := original.Clone() + + if cloned.MaxWorkers != 20 { + t.Errorf("Expected cloned MaxWorkers to be 20, got %d", cloned.MaxWorkers) + } + + if cloned.HandoffQueueSize != 200 { + t.Errorf("Expected cloned HandoffQueueSize to be 200, got %d", cloned.HandoffQueueSize) + } + + // Modify original to ensure clone is independent + original.MaxWorkers = 2 + if cloned.MaxWorkers != 20 { + t.Error("Clone should be independent of original") + } + }) +} + +func TestApplyDefaults(t *testing.T) { + t.Run("NilConfig", func(t *testing.T) { + var config *Config + result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing + + // With nil config, should get default config with auto-calculated workers + if result.MaxWorkers <= 0 { + t.Errorf("Expected MaxWorkers to be > 0 after applying defaults, got %d", result.MaxWorkers) + } + + // HandoffQueueSize should be auto-calculated (10 * MaxWorkers, capped by pool size) + workerBasedSize := result.MaxWorkers * 10 + poolSize := 100 // Default pool size used in ApplyDefaults + expectedQueueSize := util.Min(workerBasedSize, poolSize) + if result.HandoffQueueSize != expectedQueueSize { + t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d", + expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize) + } + }) + + t.Run("PartialConfig", func(t *testing.T) { + config := &Config{ + MaxWorkers: 12, // Set this field explicitly + // Leave other fields as zero values + } + + result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing + + // Should keep the explicitly set values + if result.MaxWorkers != 12 { + t.Errorf("Expected MaxWorkers to be 12 (explicitly set), got %d", result.MaxWorkers) + } + + // Should apply default for unset fields (auto-calculated queue size, capped by pool size) + workerBasedSize := result.MaxWorkers * 10 + poolSize := 100 // Default pool size used in ApplyDefaults + expectedQueueSize := util.Min(workerBasedSize, poolSize) + if result.HandoffQueueSize != expectedQueueSize { + t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d", + expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize) + } + + // Test explicit queue size capping by pool size + configWithLargeQueue := &Config{ + MaxWorkers: 5, + HandoffQueueSize: 1000, // Much larger than pool size + } + + resultCapped := configWithLargeQueue.ApplyDefaultsWithPoolSize(20) // Small pool size + if resultCapped.HandoffQueueSize != 20 { + t.Errorf("Expected HandoffQueueSize to be capped by pool size (20), got %d", resultCapped.HandoffQueueSize) + } + + if result.RelaxedTimeout != 10*time.Second { + t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout) + } + + if result.HandoffTimeout != 15*time.Second { + t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", result.HandoffTimeout) + } + }) + + t.Run("ZeroValues", func(t *testing.T) { + config := &Config{ + MaxWorkers: 0, // Zero value should get auto-calculated defaults + HandoffQueueSize: 0, // Zero value should get default + RelaxedTimeout: 0, // Zero value should get default + LogLevel: 0, // Zero is valid for LogLevel (errors only) + } + + result := config.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing + + // Zero values should get auto-calculated defaults + if result.MaxWorkers <= 0 { + t.Errorf("Expected MaxWorkers to be > 0 (auto-calculated), got %d", result.MaxWorkers) + } + + // HandoffQueueSize should be 
auto-calculated (10 * MaxWorkers, capped by pool size) + workerBasedSize := result.MaxWorkers * 10 + poolSize := 100 // Default pool size used in ApplyDefaults + expectedQueueSize := util.Min(workerBasedSize, poolSize) + if result.HandoffQueueSize != expectedQueueSize { + t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d", + expectedQueueSize, workerBasedSize, poolSize, result.HandoffQueueSize) + } + + if result.RelaxedTimeout != 10*time.Second { + t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", result.RelaxedTimeout) + } + + // LogLevel 0 should be preserved (it's a valid value) + if result.LogLevel != 0 { + t.Errorf("Expected LogLevel to be 0 (preserved), got %d", result.LogLevel) + } + }) +} + +func TestProcessorWithConfig(t *testing.T) { + t.Run("ProcessorUsesConfigValues", func(t *testing.T) { + config := &Config{ + MaxWorkers: 5, + HandoffQueueSize: 50, + RelaxedTimeout: 10 * time.Second, + HandoffTimeout: 5 * time.Second, + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + // The processor should be created successfully with custom config + if processor == nil { + t.Error("Processor should be created with custom config") + } + }) + + t.Run("ProcessorWithPartialConfig", func(t *testing.T) { + config := &Config{ + MaxWorkers: 7, // Only set worker field + // Other fields will get defaults + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + // Should work with partial config (defaults applied) + if processor == nil { + t.Error("Processor should be created with partial config") + } + }) + + t.Run("ProcessorWithNilConfig", func(t *testing.T) { + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + processor := NewPoolHook(baseDialer, "tcp", nil, nil) + defer processor.Shutdown(context.Background()) + + // Should use default config when nil is passed + if processor == nil { + t.Error("Processor should be created with nil config (using defaults)") + } + }) +} + +func TestIntegrationWithApplyDefaults(t *testing.T) { + t.Run("ProcessorWithPartialConfigAppliesDefaults", func(t *testing.T) { + // Create a partial config with only some fields set + partialConfig := &Config{ + MaxWorkers: 15, // Custom value (>= 10 to test preservation) + LogLevel: 2, // Custom value + // Other fields left as zero values - should get defaults + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + // Create processor - should apply defaults to missing fields + processor := NewPoolHook(baseDialer, "tcp", partialConfig, nil) + defer processor.Shutdown(context.Background()) + + // Processor should be created successfully + if processor == nil { + t.Error("Processor should be created with partial config") + } + + // Test that the ApplyDefaults method worked correctly by creating the same config + // and applying defaults manually + expectedConfig := partialConfig.ApplyDefaultsWithPoolSize(100) // Use explicit pool size for testing + + // Should preserve custom values (when >= 10) + if expectedConfig.MaxWorkers != 15 { + 
t.Errorf("Expected MaxWorkers to be 15, got %d", expectedConfig.MaxWorkers) + } + + if expectedConfig.LogLevel != 2 { + t.Errorf("Expected LogLevel to be 2, got %d", expectedConfig.LogLevel) + } + + // Should apply defaults for missing fields (auto-calculated queue size, capped by pool size) + workerBasedSize := expectedConfig.MaxWorkers * 10 + poolSize := 100 // Default pool size used in ApplyDefaults + expectedQueueSize := util.Min(workerBasedSize, poolSize) + if expectedConfig.HandoffQueueSize != expectedQueueSize { + t.Errorf("Expected HandoffQueueSize to be %d (util.Min(10*MaxWorkers=%d, poolSize=%d)), got %d", + expectedQueueSize, workerBasedSize, poolSize, expectedConfig.HandoffQueueSize) + } + + // Test that queue size is always capped by pool size + if expectedConfig.HandoffQueueSize > poolSize { + t.Errorf("HandoffQueueSize (%d) should never exceed pool size (%d)", + expectedConfig.HandoffQueueSize, poolSize) + } + + if expectedConfig.RelaxedTimeout != 10*time.Second { + t.Errorf("Expected RelaxedTimeout to be 10s (default), got %v", expectedConfig.RelaxedTimeout) + } + + if expectedConfig.HandoffTimeout != 15*time.Second { + t.Errorf("Expected HandoffTimeout to be 15s (default), got %v", expectedConfig.HandoffTimeout) + } + + if expectedConfig.PostHandoffRelaxedDuration != 20*time.Second { + t.Errorf("Expected PostHandoffRelaxedDuration to be 20s (2x RelaxedTimeout), got %v", expectedConfig.PostHandoffRelaxedDuration) + } + }) +} + +func TestEnhancedConfigValidation(t *testing.T) { + t.Run("ValidateFields", func(t *testing.T) { + config := DefaultConfig() + config.ApplyDefaultsWithPoolSize(100) // Apply defaults with pool size 100 + + // Should pass validation with default values + if err := config.Validate(); err != nil { + t.Errorf("Default config should be valid, got error: %v", err) + } + + // Test invalid MaxHandoffRetries + config.MaxHandoffRetries = 0 + if err := config.Validate(); err == nil { + t.Error("Expected validation error for MaxHandoffRetries = 0") + } + config.MaxHandoffRetries = 11 + if err := config.Validate(); err == nil { + t.Error("Expected validation error for MaxHandoffRetries = 11") + } + config.MaxHandoffRetries = 3 // Reset to valid value + + // Should pass validation again + if err := config.Validate(); err != nil { + t.Errorf("Config should be valid after reset, got error: %v", err) + } + }) +} + +func TestConfigClone(t *testing.T) { + original := DefaultConfig() + original.MaxHandoffRetries = 7 + original.HandoffTimeout = 8 * time.Second + + cloned := original.Clone() + + // Test that values are copied + if cloned.MaxHandoffRetries != 7 { + t.Errorf("Expected cloned MaxHandoffRetries to be 7, got %d", cloned.MaxHandoffRetries) + } + if cloned.HandoffTimeout != 8*time.Second { + t.Errorf("Expected cloned HandoffTimeout to be 8s, got %v", cloned.HandoffTimeout) + } + + // Test that modifying clone doesn't affect original + cloned.MaxHandoffRetries = 10 + if original.MaxHandoffRetries != 7 { + t.Errorf("Modifying clone should not affect original, original MaxHandoffRetries changed to %d", original.MaxHandoffRetries) + } +} + +func TestMaxWorkersLogic(t *testing.T) { + t.Run("AutoCalculatedMaxWorkers", func(t *testing.T) { + testCases := []struct { + poolSize int + expectedWorkers int + description string + }{ + {6, 2, "Small pool: min(10, 6/3) = min(10, 2) = 2"}, + {15, 5, "Medium pool: min(10, 15/3) = min(10, 5) = 5"}, + {30, 10, "Large pool: min(10, 30/3) = min(10, 10) = 10"}, + {60, 10, "Very large pool: min(10, 60/3) = min(10, 20) = 10"}, + {120, 
10, "Huge pool: min(10, 120/3) = min(10, 40) = 10"}, + } + + for _, tc := range testCases { + config := &Config{} // MaxWorkers = 0 (not set) + result := config.ApplyDefaultsWithPoolSize(tc.poolSize) + + if result.MaxWorkers != tc.expectedWorkers { + t.Errorf("PoolSize=%d: expected MaxWorkers=%d, got %d (%s)", + tc.poolSize, tc.expectedWorkers, result.MaxWorkers, tc.description) + } + } + }) + + t.Run("ExplicitlySetMaxWorkers", func(t *testing.T) { + testCases := []struct { + setValue int + expectedWorkers int + description string + }{ + {1, 10, "Set 1: max(10, 1) = 10 (enforced minimum)"}, + {5, 10, "Set 5: max(10, 5) = 10 (enforced minimum)"}, + {8, 10, "Set 8: max(10, 8) = 10 (enforced minimum)"}, + {10, 10, "Set 10: max(10, 10) = 10 (exact minimum)"}, + {15, 15, "Set 15: max(10, 15) = 15 (respects user choice)"}, + {20, 20, "Set 20: max(10, 20) = 20 (respects user choice)"}, + } + + for _, tc := range testCases { + config := &Config{ + MaxWorkers: tc.setValue, // Explicitly set + } + result := config.ApplyDefaultsWithPoolSize(100) // Pool size doesn't affect explicit values + + if result.MaxWorkers != tc.expectedWorkers { + t.Errorf("Set MaxWorkers=%d: expected %d, got %d (%s)", + tc.setValue, tc.expectedWorkers, result.MaxWorkers, tc.description) + } + } + }) +} diff --git a/hitless/errors.go b/hitless/errors.go new file mode 100644 index 0000000000..5beb250aaa --- /dev/null +++ b/hitless/errors.go @@ -0,0 +1,76 @@ +package hitless + +import ( + "errors" + "fmt" +) + +// Configuration errors +var ( + ErrInvalidRelaxedTimeout = errors.New("hitless: relaxed timeout must be greater than 0") + ErrInvalidHandoffTimeout = errors.New("hitless: handoff timeout must be greater than 0") + ErrInvalidHandoffWorkers = errors.New("hitless: MaxWorkers must be greater than or equal to 0") + ErrInvalidHandoffQueueSize = errors.New("hitless: handoff queue size must be greater than 0") + ErrInvalidPostHandoffRelaxedDuration = errors.New("hitless: post-handoff relaxed duration must be greater than or equal to 0") + ErrInvalidLogLevel = errors.New("hitless: log level must be between 0 and 3") + ErrInvalidEndpointType = errors.New("hitless: invalid endpoint type") + ErrInvalidMaintNotifications = errors.New("hitless: invalid maintenance notifications setting (must be 'disabled', 'enabled', or 'auto')") + ErrMaxHandoffRetriesReached = errors.New("hitless: max handoff retries reached") + + // Configuration validation errors + ErrInvalidHandoffRetries = errors.New("hitless: MaxHandoffRetries must be between 1 and 10") + ErrInvalidConnectionValidationTimeout = errors.New("hitless: ConnectionValidationTimeout must be greater than 0 and less than 30 seconds") + ErrInvalidConnectionHealthCheckInterval = errors.New("hitless: ConnectionHealthCheckInterval must be between 0 and 1 hour") + ErrInvalidOperationCleanupInterval = errors.New("hitless: OperationCleanupInterval must be greater than 0 and less than 1 hour") + ErrInvalidMaxActiveOperations = errors.New("hitless: MaxActiveOperations must be between 100 and 100000") + ErrInvalidNotificationBufferSize = errors.New("hitless: NotificationBufferSize must be between 10 and 10000") + ErrInvalidNotificationTimeout = errors.New("hitless: NotificationTimeout must be greater than 0 and less than 30 seconds") +) + +// Integration errors +var ( + ErrInvalidClient = errors.New("hitless: invalid client type") +) + +// Handoff errors +var ( + ErrHandoffInProgress = errors.New("hitless: handoff already in progress") + ErrNoHandoffInProgress = errors.New("hitless: no handoff in 
progress") + ErrConnectionFailed = errors.New("hitless: failed to establish new connection") +) + +// Dead error variables removed - unused in simplified architecture + +// Notification errors +var ( + ErrInvalidNotification = errors.New("hitless: invalid notification format") +) + +// Dead error variables removed - unused in simplified architecture + +// HandoffError represents an error that occurred during connection handoff. +type HandoffError struct { + Operation string + Endpoint string + Cause error +} + +func (e *HandoffError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("hitless: handoff %s failed for endpoint %s: %v", e.Operation, e.Endpoint, e.Cause) + } + return fmt.Sprintf("hitless: handoff %s failed for endpoint %s", e.Operation, e.Endpoint) +} + +func (e *HandoffError) Unwrap() error { + return e.Cause +} + +// NewHandoffError creates a new HandoffError. +func NewHandoffError(operation, endpoint string, cause error) *HandoffError { + return &HandoffError{ + Operation: operation, + Endpoint: endpoint, + Cause: cause, + } +} diff --git a/hitless/example_hooks.go b/hitless/example_hooks.go new file mode 100644 index 0000000000..f03ea3ed59 --- /dev/null +++ b/hitless/example_hooks.go @@ -0,0 +1,63 @@ +package hitless + +import ( + "context" + "time" +) + +// contextKey is a custom type for context keys to avoid collisions +type contextKey string + +const ( + startTimeKey contextKey = "notif_hitless_start_time" +) + +// MetricsHook collects metrics about notification processing. +type MetricsHook struct { + NotificationCounts map[string]int64 + ProcessingTimes map[string]time.Duration + ErrorCounts map[string]int64 +} + +// NewMetricsHook creates a new metrics collection hook. +func NewMetricsHook() *MetricsHook { + return &MetricsHook{ + NotificationCounts: make(map[string]int64), + ProcessingTimes: make(map[string]time.Duration), + ErrorCounts: make(map[string]int64), + } +} + +// PreHook records the start time for processing metrics. +func (mh *MetricsHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) { + mh.NotificationCounts[notificationType]++ + + // Store start time in context for duration calculation + startTime := time.Now() + _ = context.WithValue(ctx, startTimeKey, startTime) // Context not used further + + return notification, true +} + +// PostHook records processing completion and any errors. +func (mh *MetricsHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) { + // Calculate processing duration + if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok { + duration := time.Since(startTime) + mh.ProcessingTimes[notificationType] = duration + } + + // Record errors + if result != nil { + mh.ErrorCounts[notificationType]++ + } +} + +// GetMetrics returns a summary of collected metrics. 
+func (mh *MetricsHook) GetMetrics() map[string]interface{} { + return map[string]interface{}{ + "notification_counts": mh.NotificationCounts, + "processing_times": mh.ProcessingTimes, + "error_counts": mh.ErrorCounts, + } +} diff --git a/hitless/hitless_manager.go b/hitless/hitless_manager.go new file mode 100644 index 0000000000..26c379a5b7 --- /dev/null +++ b/hitless/hitless_manager.go @@ -0,0 +1,299 @@ +package hitless + +import ( + "context" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/interfaces" + "github.com/redis/go-redis/v9/internal/pool" +) + + + +// Push notification type constants for hitless upgrades +const ( + NotificationMoving = "MOVING" + NotificationMigrating = "MIGRATING" + NotificationMigrated = "MIGRATED" + NotificationFailingOver = "FAILING_OVER" + NotificationFailedOver = "FAILED_OVER" +) + +// hitlessNotificationTypes contains all notification types that hitless upgrades handles +var hitlessNotificationTypes = []string{ + NotificationMoving, + NotificationMigrating, + NotificationMigrated, + NotificationFailingOver, + NotificationFailedOver, +} + +// NotificationHook is called before and after notification processing +// PreHook can modify the notification and return false to skip processing +// PostHook is called after successful processing +type NotificationHook interface { + PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) + PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) +} + +// MovingOperationKey provides a unique key for tracking MOVING operations +// that combines sequence ID with connection identifier to handle duplicate +// sequence IDs across multiple connections to the same node. +type MovingOperationKey struct { + SeqID int64 // Sequence ID from MOVING notification + ConnID uint64 // Unique connection identifier +} + +// String returns a string representation of the key for debugging +func (k MovingOperationKey) String() string { + return fmt.Sprintf("seq:%d-conn:%d", k.SeqID, k.ConnID) +} + +// HitlessManager provides a simplified hitless upgrade functionality with hooks and atomic state. +type HitlessManager struct { + client interfaces.ClientInterface + config *Config + options interfaces.OptionsInterface + pool pool.Pooler + + // MOVING operation tracking - using sync.Map for better concurrent performance + activeMovingOps sync.Map // map[MovingOperationKey]*MovingOperation + + // Atomic state tracking - no locks needed for state queries + activeOperationCount atomic.Int64 // Number of active operations + closed atomic.Bool // Manager closed state + + // Notification hooks for extensibility + hooks []NotificationHook + hooksMu sync.RWMutex // Protects hooks slice + poolHooksRef *PoolHook +} + +// MovingOperation tracks an active MOVING operation. +type MovingOperation struct { + SeqID int64 + NewEndpoint string + StartTime time.Time + Deadline time.Time +} + +// NewHitlessManager creates a new simplified hitless manager. 
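// Sketch (not part of this patch): two MOVING notifications that carry the
// same sequence ID but arrive on different connections produce distinct keys,
// which is exactly why ConnID is part of MovingOperationKey.
func exampleMovingOperationKeys() {
	keyA := MovingOperationKey{SeqID: 42, ConnID: 1}
	keyB := MovingOperationKey{SeqID: 42, ConnID: 2}
	fmt.Println(keyA.String(), keyB.String(), keyA == keyB) // seq:42-conn:1 seq:42-conn:2 false
}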
+func NewHitlessManager(client interfaces.ClientInterface, pool pool.Pooler, config *Config) (*HitlessManager, error) { + if client == nil { + return nil, ErrInvalidClient + } + + hm := &HitlessManager{ + client: client, + pool: pool, + options: client.GetOptions(), + config: config.Clone(), + hooks: make([]NotificationHook, 0), + } + + // Set up push notification handling + if err := hm.setupPushNotifications(); err != nil { + return nil, err + } + + return hm, nil +} + +// GetPoolHook creates a pool hook with a custom dialer. +func (hm *HitlessManager) InitPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) { + poolHook := hm.createPoolHook(baseDialer) + hm.pool.AddPoolHook(poolHook) +} + +// setupPushNotifications sets up push notification handling by registering with the client's processor. +func (hm *HitlessManager) setupPushNotifications() error { + processor := hm.client.GetPushProcessor() + if processor == nil { + return ErrInvalidClient // Client doesn't support push notifications + } + + // Create our notification handler + handler := &NotificationHandler{manager: hm} + + // Register handlers for all hitless upgrade notifications with the client's processor + for _, notificationType := range hitlessNotificationTypes { + if err := processor.RegisterHandler(notificationType, handler, true); err != nil { + return fmt.Errorf("failed to register handler for %s: %w", notificationType, err) + } + } + + return nil +} + +// TrackMovingOperationWithConnID starts a new MOVING operation with a specific connection ID. +func (hm *HitlessManager) TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error { + // Create composite key + key := MovingOperationKey{ + SeqID: seqID, + ConnID: connID, + } + + // Create MOVING operation record + movingOp := &MovingOperation{ + SeqID: seqID, + NewEndpoint: newEndpoint, + StartTime: time.Now(), + Deadline: deadline, + } + + // Use LoadOrStore for atomic check-and-set operation + if _, loaded := hm.activeMovingOps.LoadOrStore(key, movingOp); loaded { + // Duplicate MOVING notification, ignore + internal.Logger.Printf(ctx, "Duplicate MOVING operation ignored: %s", key.String()) + return nil + } + + // Increment active operation count atomically + hm.activeOperationCount.Add(1) + + return nil +} + +// UntrackOperationWithConnID completes a MOVING operation with a specific connection ID. +func (hm *HitlessManager) UntrackOperationWithConnID(seqID int64, connID uint64) { + // Create composite key + key := MovingOperationKey{ + SeqID: seqID, + ConnID: connID, + } + + // Remove from active operations atomically + if _, loaded := hm.activeMovingOps.LoadAndDelete(key); loaded { + // Decrement active operation count only if operation existed + hm.activeOperationCount.Add(-1) + } +} + +// GetActiveMovingOperations returns active operations with composite keys. 
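// Sketch (not part of this patch) of the tracking lifecycle that the
// notification handler and pool hook drive: an operation is tracked when a
// MOVING notification arrives and untracked once the handoff finishes.
// The endpoint, sequence ID and connection ID below are placeholders.
func exampleTrackingLifecycle(ctx context.Context, hm *HitlessManager) {
	deadline := time.Now().Add(30 * time.Second)
	_ = hm.TrackMovingOperationWithConnID(ctx, "10.0.0.2:6379", deadline, 42, 7)

	// ... the connection handoff would run here ...

	hm.UntrackOperationWithConnID(42, 7)
}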
+func (hm *HitlessManager) GetActiveMovingOperations() map[MovingOperationKey]*MovingOperation { + result := make(map[MovingOperationKey]*MovingOperation) + + // Iterate over sync.Map to build result + hm.activeMovingOps.Range(func(key, value interface{}) bool { + k := key.(MovingOperationKey) + op := value.(*MovingOperation) + + // Create a copy to avoid sharing references + result[k] = &MovingOperation{ + SeqID: op.SeqID, + NewEndpoint: op.NewEndpoint, + StartTime: op.StartTime, + Deadline: op.Deadline, + } + return true // Continue iteration + }) + + return result +} + +// IsHandoffInProgress returns true if any handoff is in progress. +// Uses atomic counter for lock-free operation. +func (hm *HitlessManager) IsHandoffInProgress() bool { + return hm.activeOperationCount.Load() > 0 +} + +// GetActiveOperationCount returns the number of active operations. +// Uses atomic counter for lock-free operation. +func (hm *HitlessManager) GetActiveOperationCount() int64 { + return hm.activeOperationCount.Load() +} + +// Close closes the hitless manager. +func (hm *HitlessManager) Close() error { + // Use atomic operation for thread-safe close check + if !hm.closed.CompareAndSwap(false, true) { + return nil // Already closed + } + + // Shutdown the pool hook if it exists + if hm.poolHooksRef != nil { + // Use a timeout to prevent hanging indefinitely + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + err := hm.poolHooksRef.Shutdown(shutdownCtx) + if err != nil { + // was not able to close pool hook, keep closed state false + hm.closed.Store(false) + return err + } + // Remove the pool hook from the pool + if hm.pool != nil { + hm.pool.RemovePoolHook(hm.poolHooksRef) + } + } + + // Clear all active operations + hm.activeMovingOps.Range(func(key, value interface{}) bool { + hm.activeMovingOps.Delete(key) + return true + }) + + // Reset counter + hm.activeOperationCount.Store(0) + + return nil +} + +// GetState returns current state using atomic counter for lock-free operation. +func (hm *HitlessManager) GetState() State { + if hm.activeOperationCount.Load() > 0 { + return StateMoving + } + return StateIdle +} + +// processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing. +func (hm *HitlessManager) processPreHooks(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) { + hm.hooksMu.RLock() + defer hm.hooksMu.RUnlock() + + currentNotification := notification + + for _, hook := range hm.hooks { + modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationType, currentNotification) + if !shouldContinue { + return modifiedNotification, false + } + currentNotification = modifiedNotification + } + + return currentNotification, true +} + +// processPostHooks calls all post-hooks with the processing result. +func (hm *HitlessManager) processPostHooks(ctx context.Context, notificationType string, notification []interface{}, result error) { + hm.hooksMu.RLock() + defer hm.hooksMu.RUnlock() + + for _, hook := range hm.hooks { + hook.PostHook(ctx, notificationType, notification, result) + } +} + +// createPoolHook creates a pool hook with this manager already set. 
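// Sketch (not part of this patch): observing upgrade activity from outside.
// State and the operation count come from atomic counters, and
// GetActiveMovingOperations returns copies, so this is safe to call while
// notifications are being processed.
func exampleOperationsSnapshot(hm *HitlessManager) {
	if hm.IsHandoffInProgress() {
		for key, op := range hm.GetActiveMovingOperations() {
			fmt.Printf("%s -> %s (deadline %v)\n", key, op.NewEndpoint, op.Deadline)
		}
	}
	fmt.Println("state:", hm.GetState(), "active operations:", hm.GetActiveOperationCount())
}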
+func (hm *HitlessManager) createPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error)) *PoolHook { + if hm.poolHooksRef != nil { + return hm.poolHooksRef + } + // Get pool size from client options for better worker defaults + poolSize := 0 + if hm.options != nil { + poolSize = hm.options.GetPoolSize() + } + + hm.poolHooksRef = NewPoolHookWithPoolSize(baseDialer, hm.options.GetNetwork(), hm.config, hm, poolSize) + hm.poolHooksRef.SetPool(hm.pool) + + return hm.poolHooksRef +} diff --git a/hitless/hitless_manager_test.go b/hitless/hitless_manager_test.go new file mode 100644 index 0000000000..b1f55bf35a --- /dev/null +++ b/hitless/hitless_manager_test.go @@ -0,0 +1,260 @@ +package hitless + +import ( + "context" + "net" + "testing" + "time" + + "github.com/redis/go-redis/v9/internal/interfaces" +) + +// MockClient implements interfaces.ClientInterface for testing +type MockClient struct { + options interfaces.OptionsInterface +} + +func (mc *MockClient) GetOptions() interfaces.OptionsInterface { + return mc.options +} + +func (mc *MockClient) GetPushProcessor() interfaces.NotificationProcessor { + return &MockPushProcessor{} +} + +// MockPushProcessor implements interfaces.NotificationProcessor for testing +type MockPushProcessor struct{} + +func (mpp *MockPushProcessor) RegisterHandler(notificationType string, handler interface{}, protected bool) error { + return nil +} + +func (mpp *MockPushProcessor) UnregisterHandler(pushNotificationName string) error { + return nil +} + +func (mpp *MockPushProcessor) GetHandler(pushNotificationName string) interface{} { + return nil +} + +// MockOptions implements interfaces.OptionsInterface for testing +type MockOptions struct{} + +func (mo *MockOptions) GetReadTimeout() time.Duration { + return 5 * time.Second +} + +func (mo *MockOptions) GetWriteTimeout() time.Duration { + return 5 * time.Second +} + +func (mo *MockOptions) GetAddr() string { + return "localhost:6379" +} + +func (mo *MockOptions) IsTLSEnabled() bool { + return false +} + +func (mo *MockOptions) GetProtocol() int { + return 3 // RESP3 +} + +func (mo *MockOptions) GetPoolSize() int { + return 10 +} + +func (mo *MockOptions) GetNetwork() string { + return "tcp" +} + +func (mo *MockOptions) NewDialer() func(context.Context) (net.Conn, error) { + return func(ctx context.Context) (net.Conn, error) { + return nil, nil + } +} + +func TestHitlessManagerRefactoring(t *testing.T) { + t.Run("AtomicStateTracking", func(t *testing.T) { + config := DefaultConfig() + client := &MockClient{options: &MockOptions{}} + + manager, err := NewHitlessManager(client, nil, config) + if err != nil { + t.Fatalf("Failed to create hitless manager: %v", err) + } + defer manager.Close() + + // Test initial state + if manager.IsHandoffInProgress() { + t.Error("Expected no handoff in progress initially") + } + + if manager.GetActiveOperationCount() != 0 { + t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount()) + } + + if manager.GetState() != StateIdle { + t.Errorf("Expected StateIdle, got %v", manager.GetState()) + } + + // Add an operation + ctx := context.Background() + deadline := time.Now().Add(30 * time.Second) + err = manager.TrackMovingOperationWithConnID(ctx, "new-endpoint:6379", deadline, 12345, 1) + if err != nil { + t.Fatalf("Failed to track operation: %v", err) + } + + // Test state after adding operation + if !manager.IsHandoffInProgress() { + t.Error("Expected handoff in progress after adding operation") + } + + if manager.GetActiveOperationCount() 
!= 1 { + t.Errorf("Expected 1 active operation, got %d", manager.GetActiveOperationCount()) + } + + if manager.GetState() != StateMoving { + t.Errorf("Expected StateMoving, got %v", manager.GetState()) + } + + // Remove the operation + manager.UntrackOperationWithConnID(12345, 1) + + // Test state after removing operation + if manager.IsHandoffInProgress() { + t.Error("Expected no handoff in progress after removing operation") + } + + if manager.GetActiveOperationCount() != 0 { + t.Errorf("Expected 0 active operations, got %d", manager.GetActiveOperationCount()) + } + + if manager.GetState() != StateIdle { + t.Errorf("Expected StateIdle, got %v", manager.GetState()) + } + }) + + t.Run("SyncMapPerformance", func(t *testing.T) { + config := DefaultConfig() + client := &MockClient{options: &MockOptions{}} + + manager, err := NewHitlessManager(client, nil, config) + if err != nil { + t.Fatalf("Failed to create hitless manager: %v", err) + } + defer manager.Close() + + ctx := context.Background() + deadline := time.Now().Add(30 * time.Second) + + // Test concurrent operations + const numOps = 100 + for i := 0; i < numOps; i++ { + err := manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, int64(i), uint64(i)) + if err != nil { + t.Fatalf("Failed to track operation %d: %v", i, err) + } + } + + if manager.GetActiveOperationCount() != numOps { + t.Errorf("Expected %d active operations, got %d", numOps, manager.GetActiveOperationCount()) + } + + // Test GetActiveMovingOperations + operations := manager.GetActiveMovingOperations() + if len(operations) != numOps { + t.Errorf("Expected %d operations in map, got %d", numOps, len(operations)) + } + + // Remove all operations + for i := 0; i < numOps; i++ { + manager.UntrackOperationWithConnID(int64(i), uint64(i)) + } + + if manager.GetActiveOperationCount() != 0 { + t.Errorf("Expected 0 active operations after cleanup, got %d", manager.GetActiveOperationCount()) + } + }) + + t.Run("DuplicateOperationHandling", func(t *testing.T) { + config := DefaultConfig() + client := &MockClient{options: &MockOptions{}} + + manager, err := NewHitlessManager(client, nil, config) + if err != nil { + t.Fatalf("Failed to create hitless manager: %v", err) + } + defer manager.Close() + + ctx := context.Background() + deadline := time.Now().Add(30 * time.Second) + + // Add operation + err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1) + if err != nil { + t.Fatalf("Failed to track operation: %v", err) + } + + // Try to add duplicate operation + err = manager.TrackMovingOperationWithConnID(ctx, "endpoint:6379", deadline, 12345, 1) + if err != nil { + t.Fatalf("Duplicate operation should not return error: %v", err) + } + + // Should still have only 1 operation + if manager.GetActiveOperationCount() != 1 { + t.Errorf("Expected 1 active operation after duplicate, got %d", manager.GetActiveOperationCount()) + } + }) + + t.Run("NotificationTypeConstants", func(t *testing.T) { + // Test that constants are properly defined + expectedTypes := []string{ + NotificationMoving, + NotificationMigrating, + NotificationMigrated, + NotificationFailingOver, + NotificationFailedOver, + } + + if len(hitlessNotificationTypes) != len(expectedTypes) { + t.Errorf("Expected %d notification types, got %d", len(expectedTypes), len(hitlessNotificationTypes)) + } + + // Test that all expected types are present + typeMap := make(map[string]bool) + for _, t := range hitlessNotificationTypes { + typeMap[t] = true + } + + for _, expected := range 
expectedTypes { + if !typeMap[expected] { + t.Errorf("Expected notification type %s not found in hitlessNotificationTypes", expected) + } + } + + // Test that hitlessNotificationTypes contains all expected constants + expectedConstants := []string{ + NotificationMoving, + NotificationMigrating, + NotificationMigrated, + NotificationFailingOver, + NotificationFailedOver, + } + + for _, expected := range expectedConstants { + found := false + for _, actual := range hitlessNotificationTypes { + if actual == expected { + found = true + break + } + } + if !found { + t.Errorf("Expected constant %s not found in hitlessNotificationTypes", expected) + } + } + }) +} diff --git a/hitless/hooks.go b/hitless/hooks.go new file mode 100644 index 0000000000..7e84e032d2 --- /dev/null +++ b/hitless/hooks.go @@ -0,0 +1,48 @@ +package hitless + +import ( + "context" + + "github.com/redis/go-redis/v9/internal" +) + +// LoggingHook is an example hook implementation that logs all notifications. +type LoggingHook struct { + LogLevel int +} + +// PreHook logs the notification before processing and allows modification. +func (lh *LoggingHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) { + if lh.LogLevel >= 2 { // Info level + internal.Logger.Printf(ctx, "hitless: processing %s notification: %v", notificationType, notification) + } + return notification, true // Continue processing with unmodified notification +} + +// PostHook logs the result after processing. +func (lh *LoggingHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) { + if result != nil && lh.LogLevel >= 1 { // Warning level + internal.Logger.Printf(ctx, "hitless: %s notification processing failed: %v", notificationType, result) + } else if lh.LogLevel >= 3 { // Debug level + internal.Logger.Printf(ctx, "hitless: %s notification processed successfully", notificationType) + } +} + +// FilterHook is an example hook that can filter out certain notifications. +type FilterHook struct { + BlockedTypes map[string]bool +} + +// PreHook filters notifications based on type. +func (fh *FilterHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) { + if fh.BlockedTypes[notificationType] { + internal.Logger.Printf(ctx, "hitless: filtering out %s notification", notificationType) + return notification, false // Skip processing + } + return notification, true +} + +// PostHook does nothing for filter hook. +func (fh *FilterHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) { + // No post-processing needed for filter hook +} diff --git a/hitless/notification_handler.go b/hitless/notification_handler.go new file mode 100644 index 0000000000..933e0ea68e --- /dev/null +++ b/hitless/notification_handler.go @@ -0,0 +1,247 @@ +package hitless + +import ( + "context" + "fmt" + "strconv" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/interfaces" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" +) + +// NotificationHandler handles push notifications for the simplified manager. +type NotificationHandler struct { + manager *HitlessManager +} + +// HandlePushNotification processes push notifications with hook support. 
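// Sketch (not part of this patch): FilterHook and LoggingHook exercised
// directly through the hook contract. How hooks get registered on the manager
// is not shown in this excerpt, so the chain is driven by hand here.
func exampleHookChain(ctx context.Context, notification []interface{}) {
	filter := &FilterHook{BlockedTypes: map[string]bool{NotificationFailingOver: true}}
	logging := &LoggingHook{LogLevel: 2}

	// FAILING_OVER is blocked by the filter, MIGRATING passes through.
	for _, notificationType := range []string{NotificationFailingOver, NotificationMigrating} {
		if n, ok := filter.PreHook(ctx, notificationType, notification); ok {
			// ... the notification would be handled here ...
			logging.PostHook(ctx, notificationType, n, nil)
		}
	}
}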
+func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) == 0 { + return ErrInvalidNotification + } + + notificationType, ok := notification[0].(string) + if !ok { + return ErrInvalidNotification + } + + // Process pre-hooks - they can modify the notification or skip processing + modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, notificationType, notification) + if !shouldContinue { + return nil // Hooks decided to skip processing + } + + var err error + switch notificationType { + case NotificationMoving: + err = snh.handleMoving(ctx, handlerCtx, modifiedNotification) + case NotificationMigrating: + err = snh.handleMigrating(ctx, handlerCtx, modifiedNotification) + case NotificationMigrated: + err = snh.handleMigrated(ctx, handlerCtx, modifiedNotification) + case NotificationFailingOver: + err = snh.handleFailingOver(ctx, handlerCtx, modifiedNotification) + case NotificationFailedOver: + err = snh.handleFailedOver(ctx, handlerCtx, modifiedNotification) + default: + // Ignore other notification types (e.g., pub/sub messages) + err = nil + } + + // Process post-hooks with the result + snh.manager.processPostHooks(ctx, notificationType, modifiedNotification, err) + + return err +} + +// handleMoving processes MOVING notifications. +// ["MOVING", seqNum, timeS, endpoint] - per-connection handoff +func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + if len(notification) < 3 { + return ErrInvalidNotification + } + seqIDStr, ok := notification[1].(string) + if !ok { + return ErrInvalidNotification + } + + seqID, err := strconv.ParseInt(seqIDStr, 10, 64) + if err != nil { + return ErrInvalidNotification + } + + // Extract timeS + timeSStr, ok := notification[2].(string) + if !ok { + return ErrInvalidNotification + } + + timeS, err := strconv.ParseInt(timeSStr, 10, 64) + if err != nil { + return ErrInvalidNotification + } + + newEndpoint := "" + if len(notification) > 3 { + // Extract new endpoint + newEndpoint, ok = notification[3].(string) + if !ok { + return ErrInvalidNotification + } + } + + // Get the connection that received this notification + conn := handlerCtx.Conn + if conn == nil { + return ErrInvalidNotification + } + + // Type assert to get the underlying pool connection + var poolConn *pool.Conn + if connAdapter, ok := conn.(interface{ GetPoolConn() *pool.Conn }); ok { + poolConn = connAdapter.GetPoolConn() + } else if pc, ok := conn.(*pool.Conn); ok { + poolConn = pc + } else { + return ErrInvalidNotification + } + + deadline := time.Now().Add(time.Duration(timeS) * time.Second) + // If newEndpoint is empty, we should schedule a handoff to the current endpoint in timeS/2 seconds + if newEndpoint == "" || newEndpoint == internal.RedisNull { + // same as current endpoint + newEndpoint = snh.manager.options.GetAddr() + // delay the handoff for timeS/2 seconds to the same endpoint + // do this in a goroutine to avoid blocking the notification handler + go func() { + time.Sleep(time.Duration(timeS/2) * time.Second) + if poolConn == nil || poolConn.IsClosed() { + return + } + if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil { + // Log error but don't fail the goroutine + internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err) + } + }() + return nil + } + + return 
snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline) +} + +func (snh *NotificationHandler) markConnForHandoff(conn *pool.Conn, newEndpoint string, seqID int64, deadline time.Time) error { + if err := conn.MarkForHandoff(newEndpoint, seqID); err != nil { + // Connection is already marked for handoff, which is acceptable + // This can happen if multiple MOVING notifications are received for the same connection + return nil + } + // Optionally track in hitless manager for monitoring/debugging + if snh.manager != nil { + connID := conn.GetID() + + // Track the operation (ignore errors since this is optional) + _ = snh.manager.TrackMovingOperationWithConnID(context.Background(), newEndpoint, deadline, seqID, connID) + } else { + return fmt.Errorf("hitless: manager not initialized") + } + return nil +} + +// handleMigrating processes MIGRATING notifications. +func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + // MIGRATING notifications indicate that a connection is about to be migrated + // Apply relaxed timeouts to the specific connection that received this notification + if len(notification) < 2 { + return ErrInvalidNotification + } + + // Get the connection from handler context and type assert to connectionAdapter + if handlerCtx.Conn == nil { + return ErrInvalidNotification + } + + // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout + connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) + if !ok { + return ErrInvalidNotification + } + + // Apply relaxed timeout to this specific connection + connAdapter.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout) + return nil +} + +// handleMigrated processes MIGRATED notifications. +func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + // MIGRATED notifications indicate that a connection migration has completed + // Restore normal timeouts for the specific connection that received this notification + if len(notification) < 2 { + return ErrInvalidNotification + } + + // Get the connection from handler context and type assert to connectionAdapter + if handlerCtx.Conn == nil { + return ErrInvalidNotification + } + + // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout + connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) + if !ok { + return ErrInvalidNotification + } + + // Clear relaxed timeout for this specific connection + connAdapter.ClearRelaxedTimeout() + return nil +} + +// handleFailingOver processes FAILING_OVER notifications. 
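// Shape of a MOVING push as handleMoving above parses it (sketch only, with
// placeholder values): every field arrives as a string, and the endpoint may
// be missing or a Redis null, in which case the handoff is scheduled back to
// the current endpoint after timeS/2 seconds.
var exampleMovingNotification = []interface{}{
	"MOVING",        // notification type
	"12345",         // sequence ID, parsed with strconv.ParseInt
	"30",            // timeS: seconds used to compute the handoff deadline
	"10.0.0.2:6379", // new endpoint (optional)
}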
+func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + // FAILING_OVER notifications indicate that a connection is about to failover + // Apply relaxed timeouts to the specific connection that received this notification + if len(notification) < 2 { + return ErrInvalidNotification + } + + // Get the connection from handler context and type assert to connectionAdapter + if handlerCtx.Conn == nil { + return ErrInvalidNotification + } + + // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout + connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) + if !ok { + return ErrInvalidNotification + } + + // Apply relaxed timeout to this specific connection + connAdapter.SetRelaxedTimeout(snh.manager.config.RelaxedTimeout, snh.manager.config.RelaxedTimeout) + return nil +} + +// handleFailedOver processes FAILED_OVER notifications. +func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { + // FAILED_OVER notifications indicate that a connection failover has completed + // Restore normal timeouts for the specific connection that received this notification + if len(notification) < 2 { + return ErrInvalidNotification + } + + // Get the connection from handler context and type assert to connectionAdapter + if handlerCtx.Conn == nil { + return ErrInvalidNotification + } + + // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout + connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) + if !ok { + return ErrInvalidNotification + } + + // Clear relaxed timeout for this specific connection + connAdapter.ClearRelaxedTimeout() + return nil +} diff --git a/hitless/pool_hook.go b/hitless/pool_hook.go new file mode 100644 index 0000000000..eb3eaf905d --- /dev/null +++ b/hitless/pool_hook.go @@ -0,0 +1,477 @@ +package hitless + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pool" +) + +// HitlessManagerInterface defines the interface for completing handoff operations +type HitlessManagerInterface interface { + TrackMovingOperationWithConnID(ctx context.Context, newEndpoint string, deadline time.Time, seqID int64, connID uint64) error + UntrackOperationWithConnID(seqID int64, connID uint64) +} + +// HandoffRequest represents a request to handoff a connection to a new endpoint +type HandoffRequest struct { + Conn *pool.Conn + ConnID uint64 // Unique connection identifier + Endpoint string + SeqID int64 + Pool pool.Pooler // Pool to remove connection from on failure +} + +// PoolHook implements pool.PoolHook for Redis-specific connection handling +// with hitless upgrade support. 
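// Sketch (not part of this patch) of wiring the hook below to a pool by hand.
// HitlessManager.InitPoolHook normally does this; NewPoolHook and SetPool are
// defined further down in this file, and the dialer here is just a plain
// net.Dialer stand-in.
func examplePoolHookWiring(p pool.Pooler, cfg *Config, hm *HitlessManager) {
	dial := func(ctx context.Context, network, addr string) (net.Conn, error) {
		var d net.Dialer
		return d.DialContext(ctx, network, addr)
	}

	hook := NewPoolHook(dial, "tcp", cfg, hm)
	hook.SetPool(p)
	p.AddPoolHook(hook)
}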
+type PoolHook struct { + // Base dialer for creating connections to new endpoints during handoffs + // args are network and address + baseDialer func(context.Context, string, string) (net.Conn, error) + + // Network type (e.g., "tcp", "unix") + network string + + // Event-driven handoff support + handoffQueue chan HandoffRequest // Queue for handoff requests + shutdown chan struct{} // Shutdown signal + shutdownOnce sync.Once // Ensure clean shutdown + workerWg sync.WaitGroup // Track worker goroutines + + // On-demand worker management + maxWorkers int + activeWorkers int32 // Atomic counter for active workers + workerTimeout time.Duration // How long workers wait for work before exiting + + // Simple state tracking + pending sync.Map // map[uint64]int64 (connID -> seqID) + + // Configuration for the hitless upgrade + config *Config + + // Hitless manager for operation completion tracking + hitlessManager HitlessManagerInterface + + // Pool interface for removing connections on handoff failure + pool pool.Pooler +} + +// NewPoolHook creates a new pool hook +func NewPoolHook(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface) *PoolHook { + return NewPoolHookWithPoolSize(baseDialer, network, config, hitlessManager, 0) +} + +// NewPoolHookWithPoolSize creates a new pool hook with pool size for better worker defaults +func NewPoolHookWithPoolSize(baseDialer func(context.Context, string, string) (net.Conn, error), network string, config *Config, hitlessManager HitlessManagerInterface, poolSize int) *PoolHook { + // Apply defaults to any missing configuration fields, using pool size for worker calculations + config = config.ApplyDefaultsWithPoolSize(poolSize) + + ph := &PoolHook{ + // baseDialer is used to create connections to new endpoints during handoffs + baseDialer: baseDialer, + network: network, + // handoffQueue is a buffered channel for queuing handoff requests + handoffQueue: make(chan HandoffRequest, config.HandoffQueueSize), + // shutdown is a channel for signaling shutdown + shutdown: make(chan struct{}), + maxWorkers: config.MaxWorkers, + activeWorkers: 0, // Start with no workers - create on demand + workerTimeout: 30 * time.Second, // Workers exit after 30s of inactivity + config: config, + // Hitless manager for operation completion tracking + hitlessManager: hitlessManager, + } + + // No upfront worker creation - workers are created on demand + + return ph +} + +// SetPool sets the pool interface for removing connections on handoff failure +func (ph *PoolHook) SetPool(pooler pool.Pooler) { + ph.pool = pooler +} + +// GetCurrentWorkers returns the current number of active workers (for testing) +func (ph *PoolHook) GetCurrentWorkers() int { + return int(atomic.LoadInt32(&ph.activeWorkers)) +} + +// GetScaleLevel returns 1 if workers are active, 0 if none (for testing compatibility) +func (ph *PoolHook) GetScaleLevel() int { + if atomic.LoadInt32(&ph.activeWorkers) > 0 { + return 1 + } + return 0 +} + +// IsHandoffPending returns true if the given connection has a pending handoff +func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool { + _, pending := ph.pending.Load(conn.GetID()) + return pending +} + +// OnGet is called when a connection is retrieved from the pool +func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, isNewConn bool) error { + // NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is + // in a handoff state at 
the moment. + + // Check if connection is usable (not in a handoff state) + // Should not happen since the pool will not return a connection that is not usable. + if !conn.IsUsable() { + return ErrConnectionMarkedForHandoff + } + + // Check if connection is marked for handoff, which means it will be queued for handoff on put. + if conn.ShouldHandoff() { + return ErrConnectionMarkedForHandoff + } + + return nil +} + +// OnPut is called when a connection is returned to the pool +func (ph *PoolHook) OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) { + // first check if we should handoff for faster rejection + if conn.ShouldHandoff() { + // check pending handoff to not queue the same connection twice + _, hasPendingHandoff := ph.pending.Load(conn.GetID()) + if !hasPendingHandoff { + // Check for empty endpoint first (synchronous check) + if conn.GetHandoffEndpoint() == "" { + conn.ClearHandoffState() + } else { + if err := ph.queueHandoff(conn); err != nil { + // Failed to queue handoff, remove the connection + internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err) + return false, true, nil // Don't pool, remove connection, no error to caller + } + + // Check if handoff was already processed by a worker before we can mark it as queued + if !conn.ShouldHandoff() { + // Handoff was already processed - this is normal and the connection should be pooled + return true, false, nil + } + + if err := conn.MarkQueuedForHandoff(); err != nil { + // If marking fails, check if handoff was processed in the meantime + if !conn.ShouldHandoff() { + // Handoff was processed - this is normal, pool the connection + return true, false, nil + } + // Other error - remove the connection + return false, true, nil + } + return true, false, nil + } + } + } + // Default: pool the connection + return true, false, nil +} + +// ensureWorkerAvailable ensures at least one worker is available to process requests +// Creates a new worker if needed and under the max limit +func (ph *PoolHook) ensureWorkerAvailable() { + select { + case <-ph.shutdown: + return + default: + // Check if we need a new worker + currentWorkers := atomic.LoadInt32(&ph.activeWorkers) + if currentWorkers < int32(ph.maxWorkers) { + // Try to create a new worker (atomic increment to prevent race) + if atomic.CompareAndSwapInt32(&ph.activeWorkers, currentWorkers, currentWorkers+1) { + ph.workerWg.Add(1) + go ph.onDemandWorker() + } + } + } +} + +// onDemandWorker processes handoff requests and exits when idle +func (ph *PoolHook) onDemandWorker() { + defer func() { + // Decrement active worker count when exiting + atomic.AddInt32(&ph.activeWorkers, -1) + ph.workerWg.Done() + }() + + for { + select { + case request := <-ph.handoffQueue: + // Check for shutdown before processing + select { + case <-ph.shutdown: + // Clean up the request before exiting + ph.pending.Delete(request.ConnID) + return + default: + // Process the request + ph.processHandoffRequest(request) + } + + case <-time.After(ph.workerTimeout): + // Worker has been idle for too long, exit to save resources + if ph.config != nil && ph.config.LogLevel >= 3 { // Debug level + internal.Logger.Printf(context.Background(), + "hitless: worker exiting due to inactivity timeout (%v)", ph.workerTimeout) + } + return + + case <-ph.shutdown: + return + } + } +} + +// processHandoffRequest processes a single handoff request +func (ph *PoolHook) processHandoffRequest(request HandoffRequest) { + // Remove from pending map + defer 
ph.pending.Delete(request.Conn.GetID()) + + // Create a context with handoff timeout from config + handoffTimeout := 30 * time.Second // Default fallback + if ph.config != nil && ph.config.HandoffTimeout > 0 { + handoffTimeout = ph.config.HandoffTimeout + } + ctx, cancel := context.WithTimeout(context.Background(), handoffTimeout) + defer cancel() + + // Create a context that also respects the shutdown signal + shutdownCtx, shutdownCancel := context.WithCancel(ctx) + defer shutdownCancel() + + // Monitor shutdown signal in a separate goroutine + go func() { + select { + case <-ph.shutdown: + shutdownCancel() + case <-shutdownCtx.Done(): + } + }() + + // Perform the handoff with cancellable context + err := ph.performConnectionHandoffWithPool(shutdownCtx, request.Conn, request.Pool) + + // If handoff failed, restore the handoff state for potential retry + if err != nil { + request.Conn.RestoreHandoffState() + internal.Logger.Printf(context.Background(), "Handoff failed for connection WILL RETRY: %v", err) + } + + // No need for scale down scheduling with on-demand workers + // Workers automatically exit when idle +} + +// queueHandoff queues a handoff request for processing +// if err is returned, connection will be removed from pool +func (ph *PoolHook) queueHandoff(conn *pool.Conn) error { + // Create handoff request + request := HandoffRequest{ + Conn: conn, + ConnID: conn.GetID(), + Endpoint: conn.GetHandoffEndpoint(), + SeqID: conn.GetMovingSeqID(), + Pool: ph.pool, // Include pool for connection removal on failure + } + + select { + // priority to shutdown + case <-ph.shutdown: + return errors.New("shutdown") + default: + select { + case <-ph.shutdown: + return errors.New("shutdown") + case ph.handoffQueue <- request: + // Store in pending map + ph.pending.Store(request.ConnID, request.SeqID) + // Ensure we have a worker to process this request + ph.ensureWorkerAvailable() + return nil + default: + // Queue is full - log and attempt scaling + queueLen := len(ph.handoffQueue) + queueCap := cap(ph.handoffQueue) + if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level + internal.Logger.Printf(context.Background(), + "hitless: handoff queue is full (%d/%d), attempting timeout queuing and scaling workers", + queueLen, queueCap) + } + } + } + + // Ensure we have workers available to handle the load + ph.ensureWorkerAvailable() + return errors.New("queue full") +} + +// performConnectionHandoffWithPool performs the actual connection handoff with pool for connection removal on failure +// if err is returned, connection will be removed from pool +func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn *pool.Conn, pooler pool.Pooler) error { + // Clear handoff state after successful handoff + seqID := conn.GetMovingSeqID() + connID := conn.GetID() + + // Notify hitless manager of completion if available + if ph.hitlessManager != nil { + defer ph.hitlessManager.UntrackOperationWithConnID(seqID, connID) + } + + newEndpoint := conn.GetHandoffEndpoint() + if newEndpoint == "" { + // TODO(hitless): Handle by performing the handoff to the current endpoint in N seconds, + // Where N is the time in the moving notification... 
+ // For now, clear the handoff state and return + conn.ClearHandoffState() + return nil + } + + retries := conn.IncrementAndGetHandoffRetries(1) + maxRetries := 3 // Default fallback + if ph.config != nil { + maxRetries = ph.config.MaxHandoffRetries + } + + if retries > maxRetries { + if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level + internal.Logger.Printf(ctx, + "hitless: reached max retries (%d) for handoff of connection %d to %s", + maxRetries, conn.GetID(), conn.GetHandoffEndpoint()) + } + err := ErrMaxHandoffRetriesReached + if pooler != nil { + go pooler.Remove(ctx, conn, err) + if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level + internal.Logger.Printf(ctx, + "hitless: removed connection %d from pool due to max handoff retries reached", + conn.GetID()) + } + } else { + go conn.Close() + if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level + internal.Logger.Printf(ctx, + "hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v", + conn.GetID(), err) + } + } + return err + } + + // Create endpoint-specific dialer + endpointDialer := ph.createEndpointDialer(newEndpoint) + + // Create new connection to the new endpoint + newNetConn, err := endpointDialer(ctx) + if err != nil { + // TODO(hitless): retry + // This is the only case where we should retry the handoff request + // Should we do anything else other than return the error? + return err + } + + // Get the old connection + oldConn := conn.GetNetConn() + + // Replace the connection and execute initialization + err = conn.SetNetConnAndInitConn(ctx, newNetConn) + if err != nil { + // Remove the connection from the pool since it's in a bad state + if pooler != nil { + // Use pool.Pooler interface directly - no adapter needed + go pooler.Remove(ctx, conn, err) + if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level + internal.Logger.Printf(ctx, + "hitless: removed connection %d from pool due to handoff initialization failure: %v", + conn.GetID(), err) + } + } else { + go conn.Close() + if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level + internal.Logger.Printf(ctx, + "hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v", + conn.GetID(), err) + } + } + + // Keep the handoff state for retry + return err + } + defer func() { + if oldConn != nil { + oldConn.Close() + } + }() + + conn.ClearHandoffState() + + // Apply relaxed timeout to the new connection for the configured post-handoff duration + // This gives the new connection more time to handle operations during cluster transition + if ph.config != nil && ph.config.PostHandoffRelaxedDuration > 0 { + relaxedTimeout := ph.config.RelaxedTimeout + postHandoffDuration := ph.config.PostHandoffRelaxedDuration + + // Set relaxed timeout with deadline - no background goroutine needed + deadline := time.Now().Add(postHandoffDuration) + conn.SetRelaxedTimeoutWithDeadline(relaxedTimeout, relaxedTimeout, deadline) + + if ph.config.LogLevel >= 2 { // Info level + internal.Logger.Printf(context.Background(), + "hitless: applied post-handoff relaxed timeout (%v) until %v for connection %d", + relaxedTimeout, deadline.Format("15:04:05.000"), connID) + } + } + + return nil +} + +// createEndpointDialer creates a dialer function that connects to a specific endpoint +func (ph *PoolHook) createEndpointDialer(endpoint string) func(context.Context) (net.Conn, error) { + return func(ctx context.Context) (net.Conn, error) { + // Parse endpoint to 
extract host and port + host, port, err := net.SplitHostPort(endpoint) + if err != nil { + // If no port specified, assume default Redis port + host = endpoint + if port == "" { + port = "6379" + } + } + + // Use the base dialer to connect to the new endpoint + return ph.baseDialer(ctx, ph.network, net.JoinHostPort(host, port)) + } +} + +// Shutdown gracefully shuts down the processor, waiting for workers to complete +func (ph *PoolHook) Shutdown(ctx context.Context) error { + ph.shutdownOnce.Do(func() { + close(ph.shutdown) + + // No timers to clean up with on-demand workers + }) + + // Wait for workers to complete + done := make(chan struct{}) + go func() { + ph.workerWg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff +// and should not be used until the handoff is complete +var ErrConnectionMarkedForHandoff = errors.New("connection marked for handoff") diff --git a/hitless/pool_hook_test.go b/hitless/pool_hook_test.go new file mode 100644 index 0000000000..6dbb7a0472 --- /dev/null +++ b/hitless/pool_hook_test.go @@ -0,0 +1,959 @@ +package hitless + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/redis/go-redis/v9/internal/pool" +) + +// mockNetConn implements net.Conn for testing +type mockNetConn struct { + addr string + shouldFailInit bool +} + +func (m *mockNetConn) Read(b []byte) (n int, err error) { return 0, nil } +func (m *mockNetConn) Write(b []byte) (n int, err error) { return len(b), nil } +func (m *mockNetConn) Close() error { return nil } +func (m *mockNetConn) LocalAddr() net.Addr { return &mockAddr{m.addr} } +func (m *mockNetConn) RemoteAddr() net.Addr { return &mockAddr{m.addr} } +func (m *mockNetConn) SetDeadline(t time.Time) error { return nil } +func (m *mockNetConn) SetReadDeadline(t time.Time) error { return nil } +func (m *mockNetConn) SetWriteDeadline(t time.Time) error { return nil } + +type mockAddr struct { + addr string +} + +func (m *mockAddr) Network() string { return "tcp" } +func (m *mockAddr) String() string { return m.addr } + +// createMockPoolConnection creates a mock pool connection for testing +func createMockPoolConnection() *pool.Conn { + mockNetConn := &mockNetConn{addr: "test:6379"} + conn := pool.NewConn(mockNetConn) + conn.SetUsable(true) // Make connection usable for testing + return conn +} + +// mockPool implements pool.Pooler for testing +type mockPool struct { + removedConnections map[uint64]bool + mu sync.Mutex +} + +func (mp *mockPool) NewConn(ctx context.Context) (*pool.Conn, error) { + return nil, errors.New("not implemented") +} + +func (mp *mockPool) CloseConn(conn *pool.Conn) error { + return nil +} + +func (mp *mockPool) Get(ctx context.Context) (*pool.Conn, error) { + return nil, errors.New("not implemented") +} + +func (mp *mockPool) Put(ctx context.Context, conn *pool.Conn) { + // Not implemented for testing +} + +func (mp *mockPool) Remove(ctx context.Context, conn *pool.Conn, reason error) { + mp.mu.Lock() + defer mp.mu.Unlock() + + // Use pool.Conn directly - no adapter needed + mp.removedConnections[conn.GetID()] = true +} + +// WasRemoved safely checks if a connection was removed from the pool +func (mp *mockPool) WasRemoved(connID uint64) bool { + mp.mu.Lock() + defer mp.mu.Unlock() + return mp.removedConnections[connID] +} + +func (mp *mockPool) Len() int { + return 0 +} + +func (mp *mockPool) IdleLen() int { 
+ return 0 +} + +func (mp *mockPool) Stats() *pool.Stats { + return &pool.Stats{} +} + +func (mp *mockPool) AddPoolHook(hook pool.PoolHook) { + // Mock implementation - do nothing +} + +func (mp *mockPool) RemovePoolHook(hook pool.PoolHook) { + // Mock implementation - do nothing +} + +func (mp *mockPool) Close() error { + return nil +} + +// TestConnectionHook tests the Redis connection processor functionality +func TestConnectionHook(t *testing.T) { + // Create a base dialer for testing + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + t.Run("SuccessfulEventDrivenHandoff", func(t *testing.T) { + config := &Config{ + Mode: MaintNotificationsAuto, + EndpointType: EndpointTypeAuto, + MaxWorkers: 1, // Use only 1 worker to ensure synchronization + HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue + MaxHandoffRetries: 3, + LogLevel: 2, + } + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Verify connection is marked for handoff + if !conn.ShouldHandoff() { + t.Fatal("Connection should be marked for handoff") + } + // Set a mock initialization function with synchronization + initConnCalled := make(chan bool, 1) + proceedWithInit := make(chan bool, 1) + initConnFunc := func(ctx context.Context, cn *pool.Conn) error { + select { + case initConnCalled <- true: + default: + } + // Wait for test to proceed + <-proceedWithInit + return nil + } + conn.SetInitConnFunc(initConnFunc) + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("OnPut should not error: %v", err) + } + + // Should pool the connection immediately (handoff queued) + if !shouldPool { + t.Error("Connection should be pooled immediately with event-driven handoff") + } + if shouldRemove { + t.Error("Connection should not be removed when queuing handoff") + } + + // Wait for initialization to be called (indicates handoff started) + select { + case <-initConnCalled: + // Good, initialization was called + case <-time.After(1 * time.Second): + t.Fatal("Timeout waiting for initialization function to be called") + } + + // Connection should be in pending map while initialization is blocked + if _, pending := processor.pending.Load(conn.GetID()); !pending { + t.Error("Connection should be in pending handoffs map") + } + + // Allow initialization to proceed + proceedWithInit <- true + + // Wait for handoff to complete with proper timeout and polling + timeout := time.After(2 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + handoffCompleted := false + for !handoffCompleted { + select { + case <-timeout: + t.Fatal("Timeout waiting for handoff to complete") + case <-ticker.C: + if _, pending := processor.pending.Load(conn); !pending { + handoffCompleted = true + } + } + } + + // Verify handoff completed (removed from pending map) + if _, pending := processor.pending.Load(conn); pending { + t.Error("Connection should be removed from pending map after handoff") + } + + // Verify connection is usable again + if !conn.IsUsable() { + t.Error("Connection should be usable after successful handoff") + } + + // Verify handoff state is cleared + if conn.ShouldHandoff() { + t.Error("Connection should not be marked 
for handoff after completion") + } + }) + + t.Run("HandoffNotNeeded", func(t *testing.T) { + processor := NewPoolHook(baseDialer, "tcp", nil, nil) + conn := createMockPoolConnection() + // Don't mark for handoff + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("OnPut should not error when handoff not needed: %v", err) + } + + // Should pool the connection normally + if !shouldPool { + t.Error("Connection should be pooled when no handoff needed") + } + if shouldRemove { + t.Error("Connection should not be removed when no handoff needed") + } + }) + + t.Run("EmptyEndpoint", func(t *testing.T) { + processor := NewPoolHook(baseDialer, "tcp", nil, nil) + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("", 12345); err != nil { // Empty endpoint + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("OnPut should not error with empty endpoint: %v", err) + } + + // Should pool the connection (empty endpoint clears state) + if !shouldPool { + t.Error("Connection should be pooled after clearing empty endpoint") + } + if shouldRemove { + t.Error("Connection should not be removed after clearing empty endpoint") + } + + // State should be cleared + if conn.ShouldHandoff() { + t.Error("Connection should not be marked for handoff after clearing empty endpoint") + } + }) + + t.Run("EventDrivenHandoffDialerError", func(t *testing.T) { + // Create a failing base dialer + failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("dial failed") + } + + config := &Config{ + Mode: MaintNotificationsAuto, + EndpointType: EndpointTypeAuto, + MaxWorkers: 2, + HandoffQueueSize: 10, + MaxHandoffRetries: 3, + HandoffTimeout: 1 * time.Second, // Shorter timeout for faster test + LogLevel: 2, + } + processor := NewPoolHook(failingDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("OnPut should not return error to caller: %v", err) + } + + // Should pool the connection initially (handoff queued) + if !shouldPool { + t.Error("Connection should be pooled initially with event-driven handoff") + } + if shouldRemove { + t.Error("Connection should not be removed when queuing handoff") + } + + // Wait for handoff to complete and fail with proper timeout and polling + // Use longer timeout to account for handoff timeout + processing time + timeout := time.After(5 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + // wait for handoff to start + time.Sleep(100 * time.Millisecond) + handoffCompleted := false + for !handoffCompleted { + select { + case <-timeout: + t.Fatal("Timeout waiting for failed handoff to complete") + case <-ticker.C: + if _, pending := processor.pending.Load(conn.GetID()); !pending { + handoffCompleted = true + } + } + } + + // Connection should be removed from pending map after failed handoff + if _, pending := processor.pending.Load(conn.GetID()); pending { + t.Error("Connection should be removed from pending map after failed handoff") + } + + // Handoff state should still 
be set (since handoff failed) + if !conn.ShouldHandoff() { + t.Error("Connection should still be marked for handoff after failed handoff") + } + }) + + t.Run("BufferedDataRESP2", func(t *testing.T) { + processor := NewPoolHook(baseDialer, "tcp", nil, nil) + conn := createMockPoolConnection() + + // For this test, we'll just verify the logic works for connections without buffered data + // The actual buffered data detection is handled by the pool's connection health check + // which is outside the scope of the Redis connection processor + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("OnPut should not error: %v", err) + } + + // Should pool the connection normally (no buffered data in mock) + if !shouldPool { + t.Error("Connection should be pooled when no buffered data") + } + if shouldRemove { + t.Error("Connection should not be removed when no buffered data") + } + }) + + t.Run("OnGet", func(t *testing.T) { + processor := NewPoolHook(baseDialer, "tcp", nil, nil) + conn := createMockPoolConnection() + + ctx := context.Background() + err := processor.OnGet(ctx, conn, false) + if err != nil { + t.Errorf("OnGet should not error for normal connection: %v", err) + } + }) + + t.Run("OnGetWithPendingHandoff", func(t *testing.T) { + config := &Config{ + Mode: MaintNotificationsAuto, + EndpointType: EndpointTypeAuto, + MaxWorkers: 2, + HandoffQueueSize: 10, + MaxHandoffRetries: 3, // Explicit queue size to avoid 0-size queue + LogLevel: 2, + } + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + + // Simulate a pending handoff by marking for handoff and queuing + conn.MarkForHandoff("new-endpoint:6379", 12345) + processor.pending.Store(conn.GetID(), int64(12345)) // Store connID -> seqID + conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) + + ctx := context.Background() + err := processor.OnGet(ctx, conn, false) + if err != ErrConnectionMarkedForHandoff { + t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) + } + + // Clean up + processor.pending.Delete(conn) + }) + + t.Run("EventDrivenStateManagement", func(t *testing.T) { + processor := NewPoolHook(baseDialer, "tcp", nil, nil) + defer processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + + // Test initial state - no pending handoffs + if _, pending := processor.pending.Load(conn); pending { + t.Error("New connection should not have pending handoffs") + } + + // Test adding to pending map + conn.MarkForHandoff("new-endpoint:6379", 12345) + processor.pending.Store(conn.GetID(), int64(12345)) // Store connID -> seqID + conn.MarkQueuedForHandoff() // Mark as queued (sets usable=false) + + if _, pending := processor.pending.Load(conn.GetID()); !pending { + t.Error("Connection should be in pending map") + } + + // Test OnGet with pending handoff + ctx := context.Background() + err := processor.OnGet(ctx, conn, false) + if err != ErrConnectionMarkedForHandoff { + t.Error("Should return ErrConnectionMarkedForHandoff for pending connection") + } + + // Test removing from pending map and clearing handoff state + processor.pending.Delete(conn) + if _, pending := processor.pending.Load(conn); pending { + t.Error("Connection should be removed from pending map") + } + + // Clear handoff state to simulate completed handoff + conn.ClearHandoffState() + conn.SetUsable(true) // Make connection usable again + + // Test OnGet without pending 
handoff + err = processor.OnGet(ctx, conn, false) + if err != nil { + t.Errorf("Should not return error for non-pending connection: %v", err) + } + }) + + t.Run("EventDrivenQueueOptimization", func(t *testing.T) { + // Create processor with small queue to test optimization features + config := &Config{ + MaxWorkers: 3, + HandoffQueueSize: 2, + MaxHandoffRetries: 3, // Small queue to trigger optimizations + LogLevel: 3, // Debug level to see optimization logs + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + // Add small delay to simulate network latency + time.Sleep(10 * time.Millisecond) + return &mockNetConn{addr: addr}, nil + } + + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + // Create multiple connections that need handoff to fill the queue + connections := make([]*pool.Conn, 5) + for i := 0; i < 5; i++ { + connections[i] = createMockPoolConnection() + if err := connections[i].MarkForHandoff("new-endpoint:6379", int64(i)); err != nil { + t.Fatalf("Failed to mark connection %d for handoff: %v", i, err) + } + // Set a mock initialization function + connections[i].SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + } + + ctx := context.Background() + successCount := 0 + + // Process connections - should trigger scaling and timeout logic + for _, conn := range connections { + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Logf("OnPut returned error (expected with timeout): %v", err) + } + + if shouldPool && !shouldRemove { + successCount++ + } + } + + // With timeout and scaling, most handoffs should eventually succeed + if successCount == 0 { + t.Error("Should have queued some handoffs with timeout and scaling") + } + + t.Logf("Successfully queued %d handoffs with optimization features", successCount) + + // Give time for workers to process and scaling to occur + time.Sleep(100 * time.Millisecond) + }) + + t.Run("WorkerScalingBehavior", func(t *testing.T) { + // Create processor with small queue to test scaling behavior + config := &Config{ + MaxWorkers: 15, // Set to >= 10 to test explicit value preservation + HandoffQueueSize: 1, + MaxHandoffRetries: 3, // Very small queue to force scaling + LogLevel: 2, // Info level to see scaling logs + } + + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + // Verify initial worker count (should be 0 with on-demand workers) + if processor.GetCurrentWorkers() != 0 { + t.Errorf("Expected 0 initial workers with on-demand system, got %d", processor.GetCurrentWorkers()) + } + if processor.GetScaleLevel() != 0 { + t.Errorf("Processor should be at scale level 0 initially, got %d", processor.GetScaleLevel()) + } + if processor.maxWorkers != 15 { + t.Errorf("Expected maxWorkers=15, got %d", processor.maxWorkers) + } + + // The on-demand worker behavior creates workers only when needed + // This test just verifies the basic configuration is correct + t.Logf("On-demand worker configuration verified - Max: %d, Current: %d", + processor.maxWorkers, processor.GetCurrentWorkers()) + }) + + t.Run("PassiveTimeoutRestoration", func(t *testing.T) { + // Create processor with fast post-handoff duration for testing + config := &Config{ + MaxWorkers: 2, + HandoffQueueSize: 10, + PostHandoffRelaxedDuration: 100 * time.Millisecond, // Fast expiration for testing + RelaxedTimeout: 5 * time.Second, + LogLevel: 2, + } + + processor := 
NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + ctx := context.Background() + + // Create a connection and trigger handoff + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a mock initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + + // Process the connection to trigger handoff + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("Handoff should succeed: %v", err) + } + if !shouldPool || shouldRemove { + t.Error("Connection should be pooled after handoff") + } + + // Wait for handoff to complete with proper timeout and polling + timeout := time.After(1 * time.Second) + ticker := time.NewTicker(5 * time.Millisecond) + defer ticker.Stop() + + handoffCompleted := false + for !handoffCompleted { + select { + case <-timeout: + t.Fatal("Timeout waiting for handoff to complete") + case <-ticker.C: + if _, pending := processor.pending.Load(conn); !pending { + handoffCompleted = true + } + } + } + + // Verify relaxed timeout is set with deadline + if !conn.HasRelaxedTimeout() { + t.Error("Connection should have relaxed timeout after handoff") + } + + // Test that timeout is still active before deadline + // We'll use HasRelaxedTimeout which internally checks the deadline + if !conn.HasRelaxedTimeout() { + t.Error("Connection should still have active relaxed timeout before deadline") + } + + // Wait for deadline to pass + time.Sleep(150 * time.Millisecond) // 100ms deadline + buffer + + // Test that timeout is automatically restored after deadline + // HasRelaxedTimeout should return false after deadline passes + if conn.HasRelaxedTimeout() { + t.Error("Connection should not have active relaxed timeout after deadline") + } + + // Additional verification: calling HasRelaxedTimeout again should still return false + // and should have cleared the internal timeout values + if conn.HasRelaxedTimeout() { + t.Error("Connection should not have relaxed timeout after deadline (second check)") + } + + t.Logf("Passive timeout restoration test completed successfully") + }) + + t.Run("UsableFlagBehavior", func(t *testing.T) { + config := &Config{ + MaxWorkers: 2, + HandoffQueueSize: 10, + MaxHandoffRetries: 3, + LogLevel: 2, + } + + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + ctx := context.Background() + + // Create a new connection without setting it usable + mockNetConn := &mockNetConn{addr: "test:6379"} + conn := pool.NewConn(mockNetConn) + + // Initially, connection should not be usable (not initialized) + if conn.IsUsable() { + t.Error("New connection should not be usable before initialization") + } + + // Simulate initialization by setting usable to true + conn.SetUsable(true) + if !conn.IsUsable() { + t.Error("Connection should be usable after initialization") + } + + // OnGet should succeed for usable connection + err := processor.OnGet(ctx, conn, false) + if err != nil { + t.Errorf("OnGet should succeed for usable connection: %v", err) + } + + // Mark connection for handoff + if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a mock initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + + // Connection should 
still be usable until queued, but marked for handoff + if !conn.IsUsable() { + t.Error("Connection should still be usable after being marked for handoff (until queued)") + } + if !conn.ShouldHandoff() { + t.Error("Connection should be marked for handoff") + } + + // OnGet should fail for connection marked for handoff + err = processor.OnGet(ctx, conn, false) + if err == nil { + t.Error("OnGet should fail for connection marked for handoff") + } + if err != ErrConnectionMarkedForHandoff { + t.Errorf("Expected ErrConnectionMarkedForHandoff, got %v", err) + } + + // Process the connection to trigger handoff + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("OnPut should succeed: %v", err) + } + if !shouldPool || shouldRemove { + t.Error("Connection should be pooled after handoff") + } + + // Wait for handoff to complete + time.Sleep(50 * time.Millisecond) + + // After handoff completion, connection should be usable again + if !conn.IsUsable() { + t.Error("Connection should be usable after handoff completion") + } + + // OnGet should succeed again + err = processor.OnGet(ctx, conn, false) + if err != nil { + t.Errorf("OnGet should succeed after handoff completion: %v", err) + } + + t.Logf("Usable flag behavior test completed successfully") + }) + + t.Run("StaticQueueBehavior", func(t *testing.T) { + config := &Config{ + MaxWorkers: 3, + HandoffQueueSize: 50, + MaxHandoffRetries: 3, // Explicit static queue size + LogLevel: 2, + } + + processor := NewPoolHookWithPoolSize(baseDialer, "tcp", config, nil, 100) // Pool size: 100 + defer processor.Shutdown(context.Background()) + + // Verify queue capacity matches configured size + queueCapacity := cap(processor.handoffQueue) + if queueCapacity != 50 { + t.Errorf("Expected queue capacity 50, got %d", queueCapacity) + } + + // Test that queue size is static regardless of pool size + // (No dynamic resizing should occur) + + ctx := context.Background() + + // Fill part of the queue + for i := 0; i < 10; i++ { + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", int64(i+1)); err != nil { + t.Fatalf("Failed to mark connection %d for handoff: %v", i, err) + } + // Set a mock initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("Failed to queue handoff %d: %v", i, err) + } + + if !shouldPool || shouldRemove { + t.Errorf("Connection %d should be pooled after handoff (shouldPool=%v, shouldRemove=%v)", + i, shouldPool, shouldRemove) + } + } + + // Verify queue capacity remains static (the main purpose of this test) + finalCapacity := cap(processor.handoffQueue) + + if finalCapacity != 50 { + t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity) + } + + // Note: We don't check queue size here because workers process items quickly + // The important thing is that the capacity remains static regardless of pool size + }) + + t.Run("ConnectionRemovalOnHandoffFailure", func(t *testing.T) { + // Create a failing dialer that will cause handoff initialization to fail + failingDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + // Return a connection that will fail during initialization + return &mockNetConn{addr: addr, shouldFailInit: true}, nil + } + + config := &Config{ + MaxWorkers: 2, + HandoffQueueSize: 10, + MaxHandoffRetries: 3, + LogLevel: 2, + } + + processor := 
NewPoolHook(failingDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + // Create a mock pool that tracks removals + mockPool := &mockPool{removedConnections: make(map[uint64]bool)} + processor.SetPool(mockPool) + + ctx := context.Background() + + // Create a connection and mark it for handoff + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a failing initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return fmt.Errorf("initialization failed") + }) + + // Process the connection - handoff should fail and connection should be removed + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + if err != nil { + t.Errorf("OnPut should not error: %v", err) + } + if !shouldPool || shouldRemove { + t.Error("Connection should be pooled after failed handoff attempt") + } + + // Wait for handoff to be attempted and fail + time.Sleep(100 * time.Millisecond) + + // Verify that the connection was removed from the pool + if !mockPool.WasRemoved(conn.GetID()) { + t.Errorf("Connection %d should have been removed from pool after handoff failure", conn.GetID()) + } + + t.Logf("Connection removal on handoff failure test completed successfully") + }) + + t.Run("PostHandoffRelaxedTimeout", func(t *testing.T) { + // Create config with short post-handoff duration for testing + config := &Config{ + MaxWorkers: 2, + HandoffQueueSize: 10, + RelaxedTimeout: 5 * time.Second, + PostHandoffRelaxedDuration: 100 * time.Millisecond, // Short for testing + } + + baseDialer := func(ctx context.Context, network, addr string) (net.Conn, error) { + return &mockNetConn{addr: addr}, nil + } + + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a mock initialization function + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + return nil + }) + + ctx := context.Background() + shouldPool, shouldRemove, err := processor.OnPut(ctx, conn) + + if err != nil { + t.Fatalf("OnPut failed: %v", err) + } + + if !shouldPool { + t.Error("Connection should be pooled after successful handoff") + } + + if shouldRemove { + t.Error("Connection should not be removed after successful handoff") + } + + // Wait for the handoff to complete (it happens asynchronously) + timeout := time.After(1 * time.Second) + ticker := time.NewTicker(5 * time.Millisecond) + defer ticker.Stop() + + handoffCompleted := false + for !handoffCompleted { + select { + case <-timeout: + t.Fatal("Timeout waiting for handoff to complete") + case <-ticker.C: + if _, pending := processor.pending.Load(conn); !pending { + handoffCompleted = true + } + } + } + + // Verify that relaxed timeout was applied to the new connection + if !conn.HasRelaxedTimeout() { + t.Error("New connection should have relaxed timeout applied after handoff") + } + + // Wait for the post-handoff duration to expire + time.Sleep(150 * time.Millisecond) // Slightly longer than PostHandoffRelaxedDuration + + // Verify that relaxed timeout was automatically cleared + if conn.HasRelaxedTimeout() { + t.Error("Relaxed timeout should be automatically cleared after post-handoff duration") + } + }) + + t.Run("MarkForHandoff returns error when 
already marked", func(t *testing.T) { + conn := createMockPoolConnection() + + // First mark should succeed + if err := conn.MarkForHandoff("new-endpoint:6379", 1); err != nil { + t.Fatalf("First MarkForHandoff should succeed: %v", err) + } + + // Second mark should fail + if err := conn.MarkForHandoff("another-endpoint:6379", 2); err == nil { + t.Fatal("Second MarkForHandoff should return error") + } else if err.Error() != "connection is already marked for handoff" { + t.Fatalf("Expected specific error message, got: %v", err) + } + + // Verify original handoff data is preserved + if !conn.ShouldHandoff() { + t.Fatal("Connection should still be marked for handoff") + } + if conn.GetHandoffEndpoint() != "new-endpoint:6379" { + t.Fatalf("Expected original endpoint, got: %s", conn.GetHandoffEndpoint()) + } + if conn.GetMovingSeqID() != 1 { + t.Fatalf("Expected original sequence ID, got: %d", conn.GetMovingSeqID()) + } + }) + + t.Run("HandoffTimeoutConfiguration", func(t *testing.T) { + // Test that HandoffTimeout from config is actually used + customTimeout := 2 * time.Second + config := &Config{ + MaxWorkers: 2, + HandoffQueueSize: 10, + HandoffTimeout: customTimeout, // Custom timeout + MaxHandoffRetries: 1, // Single retry to speed up test + LogLevel: 2, + } + + processor := NewPoolHook(baseDialer, "tcp", config, nil) + defer processor.Shutdown(context.Background()) + + // Create a connection that will test the timeout + conn := createMockPoolConnection() + if err := conn.MarkForHandoff("test-endpoint:6379", 123); err != nil { + t.Fatalf("Failed to mark connection for handoff: %v", err) + } + + // Set a dialer that will check the context timeout + var timeoutVerified int32 // Use atomic for thread safety + conn.SetInitConnFunc(func(ctx context.Context, cn *pool.Conn) error { + // Check that the context has the expected timeout + deadline, ok := ctx.Deadline() + if !ok { + t.Error("Context should have a deadline") + return errors.New("no deadline") + } + + // The deadline should be approximately customTimeout from now + expectedDeadline := time.Now().Add(customTimeout) + timeDiff := deadline.Sub(expectedDeadline) + if timeDiff < -500*time.Millisecond || timeDiff > 500*time.Millisecond { + t.Errorf("Context deadline not as expected. Expected around %v, got %v (diff: %v)", + expectedDeadline, deadline, timeDiff) + } else { + atomic.StoreInt32(&timeoutVerified, 1) + } + + return nil // Successful handoff + }) + + // Trigger handoff + shouldPool, shouldRemove, err := processor.OnPut(context.Background(), conn) + if err != nil { + t.Errorf("OnPut should not return error: %v", err) + } + + // Connection should be queued for handoff + if !shouldPool || shouldRemove { + t.Errorf("Connection should be pooled for handoff processing") + } + + // Wait for handoff to complete + time.Sleep(500 * time.Millisecond) + + if atomic.LoadInt32(&timeoutVerified) == 0 { + t.Error("HandoffTimeout was not properly applied to context") + } + + t.Logf("HandoffTimeout configuration test completed successfully") + }) +} diff --git a/hitless/state.go b/hitless/state.go new file mode 100644 index 0000000000..109d939fc0 --- /dev/null +++ b/hitless/state.go @@ -0,0 +1,24 @@ +package hitless + +// State represents the current state of a hitless upgrade operation. +type State int + +const ( + // StateIdle indicates no upgrade is in progress + StateIdle State = iota + + // StateHandoff indicates a connection handoff is in progress + StateMoving +) + +// String returns a string representation of the state. 
+func (s State) String() string { + switch s { + case StateIdle: + return "idle" + case StateMoving: + return "moving" + default: + return "unknown" + } +} diff --git a/internal/interfaces/interfaces.go b/internal/interfaces/interfaces.go new file mode 100644 index 0000000000..3b0596b89e --- /dev/null +++ b/internal/interfaces/interfaces.go @@ -0,0 +1,67 @@ +// Package interfaces provides shared interfaces used by both the main redis package +// and the hitless upgrade package to avoid circular dependencies. +package interfaces + +import ( + "context" + "net" + "time" +) + +// Forward declaration to avoid circular imports +type NotificationProcessor interface { + RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error + UnregisterHandler(pushNotificationName string) error + GetHandler(pushNotificationName string) interface{} +} + +// ClientInterface defines the interface that clients must implement for hitless upgrades. +type ClientInterface interface { + // GetOptions returns the client options. + GetOptions() OptionsInterface + + // GetPushProcessor returns the client's push notification processor. + GetPushProcessor() NotificationProcessor +} + +// OptionsInterface defines the interface for client options. +type OptionsInterface interface { + // GetReadTimeout returns the read timeout. + GetReadTimeout() time.Duration + + // GetWriteTimeout returns the write timeout. + GetWriteTimeout() time.Duration + + // GetNetwork returns the network type. + GetNetwork() string + + // GetAddr returns the connection address. + GetAddr() string + + // IsTLSEnabled returns true if TLS is enabled. + IsTLSEnabled() bool + + // GetProtocol returns the protocol version. + GetProtocol() int + + // GetPoolSize returns the connection pool size. + GetPoolSize() int + + // NewDialer returns a new dialer function for the connection. + NewDialer() func(context.Context) (net.Conn, error) +} + +// ConnectionWithRelaxedTimeout defines the interface for connections that support relaxed timeout adjustment. +// This is used by the hitless upgrade system for per-connection timeout management. +type ConnectionWithRelaxedTimeout interface { + // SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades. + // These timeouts remain active until explicitly cleared. + SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) + + // SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline. + // After the deadline, timeouts automatically revert to normal values. + SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) + + // ClearRelaxedTimeout clears relaxed timeouts for this connection. 
+ ClearRelaxedTimeout() +} diff --git a/internal/pool/bench_test.go b/internal/pool/bench_test.go index 72308e1242..fc37b82121 100644 --- a/internal/pool/bench_test.go +++ b/internal/pool/bench_test.go @@ -2,6 +2,7 @@ package pool_test import ( "context" + "errors" "fmt" "testing" "time" @@ -31,7 +32,7 @@ func BenchmarkPoolGetPut(b *testing.B) { b.Run(bm.String(), func(b *testing.B) { connPool := pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: bm.poolSize, + PoolSize: int32(bm.poolSize), PoolTimeout: time.Second, DialTimeout: 1 * time.Second, ConnMaxIdleTime: time.Hour, @@ -75,7 +76,7 @@ func BenchmarkPoolGetRemove(b *testing.B) { b.Run(bm.String(), func(b *testing.B) { connPool := pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: bm.poolSize, + PoolSize: int32(bm.poolSize), PoolTimeout: time.Second, DialTimeout: 1 * time.Second, ConnMaxIdleTime: time.Hour, @@ -89,7 +90,7 @@ func BenchmarkPoolGetRemove(b *testing.B) { if err != nil { b.Fatal(err) } - connPool.Remove(ctx, cn, nil) + connPool.Remove(ctx, cn, errors.New("Bench test remove")) } }) }) diff --git a/internal/pool/buffer_size_test.go b/internal/pool/buffer_size_test.go index 7f4bd37ee4..71223d7081 100644 --- a/internal/pool/buffer_size_test.go +++ b/internal/pool/buffer_size_test.go @@ -26,7 +26,7 @@ var _ = Describe("Buffer Size Configuration", func() { It("should use default buffer sizes when not specified", func() { connPool = pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: 1, + PoolSize: int32(1), PoolTimeout: 1000, }) @@ -48,7 +48,7 @@ var _ = Describe("Buffer Size Configuration", func() { connPool = pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: 1, + PoolSize: int32(1), PoolTimeout: 1000, ReadBufferSize: customReadSize, WriteBufferSize: customWriteSize, @@ -69,7 +69,7 @@ var _ = Describe("Buffer Size Configuration", func() { It("should handle zero buffer sizes by using defaults", func() { connPool = pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: 1, + PoolSize: int32(1), PoolTimeout: 1000, ReadBufferSize: 0, // Should use default WriteBufferSize: 0, // Should use default @@ -105,7 +105,7 @@ var _ = Describe("Buffer Size Configuration", func() { // without setting ReadBufferSize and WriteBufferSize connPool = pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: 1, + PoolSize: int32(1), PoolTimeout: 1000, // ReadBufferSize and WriteBufferSize are not set (will be 0) }) diff --git a/internal/pool/conn.go b/internal/pool/conn.go index 8fcdfa6768..aa2da01a7f 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -3,7 +3,10 @@ package pool import ( "bufio" "context" + "errors" + "fmt" "net" + "sync" "sync/atomic" "time" @@ -12,17 +15,64 @@ import ( var noDeadline = time.Time{} +// Global atomic counter for connection IDs +var connIDCounter uint64 + +// atomicNetConn is a wrapper to ensure consistent typing in atomic.Value +type atomicNetConn struct { + conn net.Conn +} + +// generateConnID generates a fast unique identifier for a connection with zero allocations +func generateConnID() uint64 { + return atomic.AddUint64(&connIDCounter, 1) +} + type Conn struct { - usedAt int64 // atomic - netConn net.Conn + usedAt int64 // atomic + + // Lock-free netConn access using atomic.Value + // Contains *atomicNetConn wrapper, accessed atomically for better performance + netConnAtomic atomic.Value // stores *atomicNetConn rd *proto.Reader bw *bufio.Writer wr *proto.Writer - Inited bool + // Lightweight mutex to protect reader operations 
during handoff + // Only used for the brief period during SetNetConn and HasBufferedData/PeekReplyTypeSafe + readerMu sync.RWMutex + + Inited atomic.Bool pooled bool + closed atomic.Bool createdAt time.Time + expiresAt time.Time + + // Hitless upgrade support: relaxed timeouts during migrations/failovers + // Using atomic operations for lock-free access to avoid mutex contention + relaxedReadTimeoutNs atomic.Int64 // time.Duration as nanoseconds + relaxedWriteTimeoutNs atomic.Int64 // time.Duration as nanoseconds + relaxedDeadlineNs atomic.Int64 // time.Time as nanoseconds since epoch + + // Counter to track multiple relaxed timeout setters if we have nested calls + // will be decremented when ClearRelaxedTimeout is called or deadline is reached + // if counter reaches 0, we clear the relaxed timeouts + relaxedCounter atomic.Int32 + + // Connection initialization function for reconnections + initConnFunc func(context.Context, *Conn) error + + // Connection identifier for unique tracking across handoffs + id uint64 // Unique numeric identifier for this connection + + // Handoff state - using atomic operations for lock-free access + usableAtomic atomic.Bool // Connection usability state + shouldHandoffAtomic atomic.Bool // Whether connection should be handed off + movingSeqIDAtomic atomic.Int64 // Sequence ID from MOVING notification + handoffRetriesAtomic atomic.Uint32 // Retry counter for handoff attempts + // newEndpointAtomic needs special handling as it's a string + newEndpointAtomic atomic.Value // stores string onClose func() error } @@ -33,8 +83,8 @@ func NewConn(netConn net.Conn) *Conn { func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Conn { cn := &Conn{ - netConn: netConn, createdAt: time.Now(), + id: generateConnID(), // Generate unique ID for this connection } // Use specified buffer sizes, or fall back to 32KiB defaults if 0 @@ -50,6 +100,16 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con cn.bw = bufio.NewWriterSize(netConn, proto.DefaultBufferSize) } + // Store netConn atomically for lock-free access using wrapper + cn.netConnAtomic.Store(&atomicNetConn{conn: netConn}) + + // Initialize atomic handoff state + cn.usableAtomic.Store(false) // false initially, set to true after initialization + cn.shouldHandoffAtomic.Store(false) // false initially + cn.movingSeqIDAtomic.Store(0) // 0 initially + cn.handoffRetriesAtomic.Store(0) // 0 initially + cn.newEndpointAtomic.Store("") // empty string initially + cn.wr = proto.NewWriter(cn.bw) cn.SetUsedAt(time.Now()) return cn @@ -64,23 +124,368 @@ func (cn *Conn) SetUsedAt(tm time.Time) { atomic.StoreInt64(&cn.usedAt, tm.Unix()) } +// getNetConn returns the current network connection using atomic load (lock-free). +// This is the fast path for accessing netConn without mutex overhead. +func (cn *Conn) getNetConn() net.Conn { + if v := cn.netConnAtomic.Load(); v != nil { + if wrapper, ok := v.(*atomicNetConn); ok { + return wrapper.conn + } + } + return nil +} + +// setNetConn stores the network connection atomically (lock-free). +// This is used for the fast path of connection replacement. +func (cn *Conn) setNetConn(netConn net.Conn) { + cn.netConnAtomic.Store(&atomicNetConn{conn: netConn}) +} + +// Lock-free helper methods for handoff state management + +// isUsable returns true if the connection is safe to use (lock-free). +func (cn *Conn) isUsable() bool { + return cn.usableAtomic.Load() +} + +// setUsable sets the usable flag atomically (lock-free). 
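// Illustrative sketch (not part of this patch): why netConn is wrapped in *atomicNetConn
// before being stored in atomic.Value. atomic.Value panics if successive Store calls use
// different concrete types, and a pool may hold a *net.TCPConn on one handoff and a
// *tls.Conn on the next; keeping a single wrapper type makes every Store use the same
// concrete type. Minimal standalone demo (connBox stands in for atomicNetConn):
//
//	package main
//
//	import (
//		"crypto/tls"
//		"net"
//		"sync/atomic"
//	)
//
//	type connBox struct{ conn net.Conn }
//
//	func main() {
//		var v atomic.Value
//		v.Store(&connBox{conn: &net.TCPConn{}}) // concrete type is always *connBox
//		v.Store(&connBox{conn: &tls.Conn{}})    // still *connBox, no panic
//		_ = v.Load().(*connBox).conn            // lock-free read of the current net.Conn
//	}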
+func (cn *Conn) setUsable(usable bool) { + cn.usableAtomic.Store(usable) +} + +// shouldHandoff returns true if connection needs handoff (lock-free). +func (cn *Conn) shouldHandoff() bool { + return cn.shouldHandoffAtomic.Load() +} + +// setShouldHandoff sets the handoff flag atomically (lock-free). +func (cn *Conn) setShouldHandoff(should bool) { + cn.shouldHandoffAtomic.Store(should) +} + +// getMovingSeqID returns the sequence ID atomically (lock-free). +func (cn *Conn) getMovingSeqID() int64 { + return cn.movingSeqIDAtomic.Load() +} + +// setMovingSeqID sets the sequence ID atomically (lock-free). +func (cn *Conn) setMovingSeqID(seqID int64) { + cn.movingSeqIDAtomic.Store(seqID) +} + +// getNewEndpoint returns the new endpoint atomically (lock-free). +func (cn *Conn) getNewEndpoint() string { + if endpoint := cn.newEndpointAtomic.Load(); endpoint != nil { + return endpoint.(string) + } + return "" +} + +// setNewEndpoint sets the new endpoint atomically (lock-free). +func (cn *Conn) setNewEndpoint(endpoint string) { + cn.newEndpointAtomic.Store(endpoint) +} + +// setHandoffRetries sets the retry count atomically (lock-free). +func (cn *Conn) setHandoffRetries(retries int) { + cn.handoffRetriesAtomic.Store(uint32(retries)) +} + +// incrementHandoffRetries atomically increments and returns the new retry count (lock-free). +func (cn *Conn) incrementHandoffRetries(delta int) int { + return int(cn.handoffRetriesAtomic.Add(uint32(delta))) +} + +// IsUsable returns true if the connection is safe to use for new commands (lock-free). +func (cn *Conn) IsUsable() bool { + return cn.isUsable() +} + +func (cn *Conn) IsInited() bool { + return cn.Inited.Load() +} + +// SetUsable sets the usable flag for the connection (lock-free). +func (cn *Conn) SetUsable(usable bool) { + cn.setUsable(usable) +} + +// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades. +// These timeouts will be used for all subsequent commands until the deadline expires. +// Uses atomic operations for lock-free access. +func (cn *Conn) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) { + cn.relaxedCounter.Add(1) + cn.relaxedReadTimeoutNs.Store(int64(readTimeout)) + cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout)) +} + +// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline. +// After the deadline, timeouts automatically revert to normal values. +// Uses atomic operations for lock-free access. +func (cn *Conn) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) { + cn.relaxedCounter.Add(1) + cn.relaxedReadTimeoutNs.Store(int64(readTimeout)) + cn.relaxedWriteTimeoutNs.Store(int64(writeTimeout)) + cn.relaxedDeadlineNs.Store(deadline.UnixNano()) +} + +// ClearRelaxedTimeout removes relaxed timeouts, returning to normal timeout behavior. +// Uses atomic operations for lock-free access. +func (cn *Conn) ClearRelaxedTimeout() { + // Atomically decrement counter and check if we should clear + newCount := cn.relaxedCounter.Add(-1) + if newCount <= 0 { + // Use compare-and-swap to ensure only one goroutine clears + if cn.relaxedCounter.CompareAndSwap(newCount, 0) { + cn.clearRelaxedTimeout() + } + } +} + +func (cn *Conn) clearRelaxedTimeout() { + cn.relaxedReadTimeoutNs.Store(0) + cn.relaxedWriteTimeoutNs.Store(0) + cn.relaxedDeadlineNs.Store(0) + cn.relaxedCounter.Store(0) +} + +// HasRelaxedTimeout returns true if relaxed timeouts are currently active on this connection. 
+// This checks both the timeout values and the deadline (if set). +// Uses atomic operations for lock-free access. +func (cn *Conn) HasRelaxedTimeout() bool { + // Fast path: no relaxed timeouts are set + if cn.relaxedCounter.Load() <= 0 { + return false + } + + readTimeoutNs := cn.relaxedReadTimeoutNs.Load() + writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load() + + // If no relaxed timeouts are set, return false + if readTimeoutNs <= 0 && writeTimeoutNs <= 0 { + return false + } + + deadlineNs := cn.relaxedDeadlineNs.Load() + // If no deadline is set, relaxed timeouts are active + if deadlineNs == 0 { + return true + } + + // If deadline is set, check if it's still in the future + return time.Now().UnixNano() < deadlineNs +} + +// getEffectiveReadTimeout returns the timeout to use for read operations. +// If relaxed timeout is set and not expired, it takes precedence over the provided timeout. +// This method automatically clears expired relaxed timeouts using atomic operations. +func (cn *Conn) getEffectiveReadTimeout(normalTimeout time.Duration) time.Duration { + readTimeoutNs := cn.relaxedReadTimeoutNs.Load() + + // Fast path: no relaxed timeout set + if readTimeoutNs <= 0 { + return normalTimeout + } + + deadlineNs := cn.relaxedDeadlineNs.Load() + // If no deadline is set, use relaxed timeout + if deadlineNs == 0 { + return time.Duration(readTimeoutNs) + } + + nowNs := time.Now().UnixNano() + // Check if deadline has passed + if nowNs < deadlineNs { + // Deadline is in the future, use relaxed timeout + return time.Duration(readTimeoutNs) + } else { + // Deadline has passed, clear relaxed timeouts atomically and use normal timeout + cn.relaxedCounter.Add(-1) + if cn.relaxedCounter.Load() <= 0 { + cn.clearRelaxedTimeout() + } + return normalTimeout + } +} + +// getEffectiveWriteTimeout returns the timeout to use for write operations. +// If relaxed timeout is set and not expired, it takes precedence over the provided timeout. +// This method automatically clears expired relaxed timeouts using atomic operations. +func (cn *Conn) getEffectiveWriteTimeout(normalTimeout time.Duration) time.Duration { + writeTimeoutNs := cn.relaxedWriteTimeoutNs.Load() + + // Fast path: no relaxed timeout set + if writeTimeoutNs <= 0 { + return normalTimeout + } + + deadlineNs := cn.relaxedDeadlineNs.Load() + // If no deadline is set, use relaxed timeout + if deadlineNs == 0 { + return time.Duration(writeTimeoutNs) + } + + nowNs := time.Now().UnixNano() + // Check if deadline has passed + if nowNs < deadlineNs { + // Deadline is in the future, use relaxed timeout + return time.Duration(writeTimeoutNs) + } else { + // Deadline has passed, clear relaxed timeouts atomically and use normal timeout + cn.relaxedCounter.Add(-1) + if cn.relaxedCounter.Load() <= 0 { + cn.clearRelaxedTimeout() + } + return normalTimeout + } +} + func (cn *Conn) SetOnClose(fn func() error) { cn.onClose = fn } +// SetInitConnFunc sets the connection initialization function to be called on reconnections. +func (cn *Conn) SetInitConnFunc(fn func(context.Context, *Conn) error) { + cn.initConnFunc = fn +} + +// ExecuteInitConn runs the stored connection initialization function if available. 
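// Illustrative sketch (not part of this patch): caller-visible behavior of the
// deadline-based relaxed timeouts above. The relaxed values win until the deadline
// passes, after which HasRelaxedTimeout reports false and getEffective{Read,Write}Timeout
// fall back to the normal timeouts. The helper name below is made up; it only uses
// methods introduced in this patch.
//
//	func relaxedTimeoutDemo(cn *Conn) {
//		deadline := time.Now().Add(200 * time.Millisecond)
//		cn.SetRelaxedTimeoutWithDeadline(10*time.Second, 10*time.Second, deadline)
//
//		fmt.Println(cn.HasRelaxedTimeout()) // true: deadline still in the future
//
//		time.Sleep(250 * time.Millisecond)
//		fmt.Println(cn.HasRelaxedTimeout()) // false: deadline passed, normal timeouts apply
//	}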
+func (cn *Conn) ExecuteInitConn(ctx context.Context) error { + if cn.initConnFunc != nil { + return cn.initConnFunc(ctx, cn) + } + return fmt.Errorf("redis: no initConnFunc set for connection %d", cn.GetID()) +} + func (cn *Conn) SetNetConn(netConn net.Conn) { - cn.netConn = netConn + // Store the new connection atomically first (lock-free) + cn.setNetConn(netConn) + // Clear relaxed timeouts when connection is replaced + cn.clearRelaxedTimeout() + + // Protect reader reset operations to avoid data races + // Use write lock since we're modifying the reader state + cn.readerMu.Lock() cn.rd.Reset(netConn) + cn.readerMu.Unlock() + cn.bw.Reset(netConn) } +// GetNetConn safely returns the current network connection using atomic load (lock-free). +// This method is used by the pool for health checks and provides better performance. +func (cn *Conn) GetNetConn() net.Conn { + return cn.getNetConn() +} + +// SetNetConnAndInitConn replaces the underlying connection and executes the initialization. +func (cn *Conn) SetNetConnAndInitConn(ctx context.Context, netConn net.Conn) error { + // New connection is not initialized yet + cn.Inited.Store(false) + // Replace the underlying connection + cn.SetNetConn(netConn) + return cn.ExecuteInitConn(ctx) +} + +// MarkForHandoff marks the connection for handoff due to MOVING notification (lock-free). +// Returns an error if the connection is already marked for handoff. +func (cn *Conn) MarkForHandoff(newEndpoint string, seqID int64) error { + // Use single atomic CAS operation for state transition + if !cn.shouldHandoffAtomic.CompareAndSwap(false, true) { + return errors.New("connection is already marked for handoff") + } + + cn.setNewEndpoint(newEndpoint) + cn.setMovingSeqID(seqID) + return nil +} + +func (cn *Conn) MarkQueuedForHandoff() error { + // Use single atomic CAS operation for state transition + if !cn.shouldHandoffAtomic.CompareAndSwap(true, false) { + return errors.New("connection was not marked for handoff") + } + cn.setUsable(false) + return nil +} + +// RestoreHandoffState restores the handoff state after a failed handoff (lock-free). +func (cn *Conn) RestoreHandoffState() { + // Restore shouldHandoff flag for retry + cn.shouldHandoffAtomic.Store(true) + // Keep usable=false to prevent the connection from being used until handoff succeeds + cn.setUsable(false) +} + +// ShouldHandoff returns true if the connection needs to be handed off (lock-free). +func (cn *Conn) ShouldHandoff() bool { + return cn.shouldHandoff() +} + +// GetHandoffEndpoint returns the new endpoint for handoff (lock-free). +func (cn *Conn) GetHandoffEndpoint() string { + return cn.getNewEndpoint() +} + +// GetMovingSeqID returns the sequence ID from the MOVING notification (lock-free). +func (cn *Conn) GetMovingSeqID() int64 { + return cn.getMovingSeqID() +} + +// GetID returns the unique identifier for this connection. +func (cn *Conn) GetID() uint64 { + return cn.id +} + +// ClearHandoffState clears the handoff state after successful handoff (lock-free). +func (cn *Conn) ClearHandoffState() { + // clear handoff state + cn.setShouldHandoff(false) + cn.setNewEndpoint("") + cn.setMovingSeqID(0) + cn.setHandoffRetries(0) + cn.setUsable(true) // Connection is safe to use again after handoff completes +} + +// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free). +func (cn *Conn) IncrementAndGetHandoffRetries(n int) int { + return cn.incrementHandoffRetries(n) +} + +// HasBufferedData safely checks if the connection has buffered data. 
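// Illustrative sketch (not part of this patch): roughly the order in which the pool hook
// and its handoff worker are expected to drive the state transitions defined above.
// Endpoint, sequence ID, and the helper name are example values; error handling is trimmed.
//
//	func handoffLifecycle(ctx context.Context, cn *Conn, newConn net.Conn) error {
//		if err := cn.MarkForHandoff("new-endpoint:6379", 42); err != nil {
//			return err // already marked by an earlier MOVING notification
//		}
//		if err := cn.MarkQueuedForHandoff(); err != nil { // usable=false from here on
//			return err
//		}
//		if err := cn.SetNetConnAndInitConn(ctx, newConn); err != nil {
//			cn.RestoreHandoffState() // keep it unusable and let the worker retry
//			return err
//		}
//		cn.ClearHandoffState() // usable=true again, handoff bookkeeping reset
//		return nil
//	}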
+// This method is used to avoid data races when checking for push notifications. +func (cn *Conn) HasBufferedData() bool { + // Use read lock for concurrent access to reader state + cn.readerMu.RLock() + defer cn.readerMu.RUnlock() + return cn.rd.Buffered() > 0 +} + +// PeekReplyTypeSafe safely peeks at the reply type. +// This method is used to avoid data races when checking for push notifications. +func (cn *Conn) PeekReplyTypeSafe() (byte, error) { + // Use read lock for concurrent access to reader state + cn.readerMu.RLock() + defer cn.readerMu.RUnlock() + + if cn.rd.Buffered() <= 0 { + return 0, fmt.Errorf("redis: can't peek reply type, no data available") + } + return cn.rd.PeekReplyType() +} + func (cn *Conn) Write(b []byte) (int, error) { - return cn.netConn.Write(b) + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return netConn.Write(b) + } + return 0, net.ErrClosed } func (cn *Conn) RemoteAddr() net.Addr { - if cn.netConn != nil { - return cn.netConn.RemoteAddr() + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return netConn.RemoteAddr() } return nil } @@ -89,7 +494,16 @@ func (cn *Conn) WithReader( ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error, ) error { if timeout >= 0 { - if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil { + // Use relaxed timeout if set, otherwise use provided timeout + effectiveTimeout := cn.getEffectiveReadTimeout(timeout) + + // Get the connection directly from atomic storage + netConn := cn.getNetConn() + if netConn == nil { + return fmt.Errorf("redis: connection not available") + } + + if err := netConn.SetReadDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil { return err } } @@ -100,13 +514,26 @@ func (cn *Conn) WithWriter( ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error, ) error { if timeout >= 0 { - if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil { - return err + // Use relaxed timeout if set, otherwise use provided timeout + effectiveTimeout := cn.getEffectiveWriteTimeout(timeout) + + // Always set write deadline, even if getNetConn() returns nil + // This prevents write operations from hanging indefinitely + if netConn := cn.getNetConn(); netConn != nil { + if err := netConn.SetWriteDeadline(cn.deadline(ctx, effectiveTimeout)); err != nil { + return err + } + } else { + // If getNetConn() returns nil, we still need to respect the timeout + // Return an error to prevent indefinite blocking + return fmt.Errorf("redis: connection not available for write operation") } } if cn.bw.Buffered() > 0 { - cn.bw.Reset(cn.netConn) + if netConn := cn.getNetConn(); netConn != nil { + cn.bw.Reset(netConn) + } } if err := fn(cn.wr); err != nil { @@ -116,19 +543,33 @@ func (cn *Conn) WithWriter( return cn.bw.Flush() } +func (cn *Conn) IsClosed() bool { + return cn.closed.Load() +} + func (cn *Conn) Close() error { + cn.closed.Store(true) if cn.onClose != nil { // ignore error _ = cn.onClose() } - return cn.netConn.Close() + + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return netConn.Close() + } + return nil } // MaybeHasData tries to peek at the next byte in the socket without consuming it // This is used to check if there are push notifications available // Important: This will work on Linux, but not on Windows func (cn *Conn) MaybeHasData() bool { - return 
maybeHasData(cn.netConn) + // Lock-free netConn access for better performance + if netConn := cn.getNetConn(); netConn != nil { + return maybeHasData(netConn) + } + return false } func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { diff --git a/internal/pool/export_test.go b/internal/pool/export_test.go index 40e387c9a0..20456b8100 100644 --- a/internal/pool/export_test.go +++ b/internal/pool/export_test.go @@ -10,7 +10,7 @@ func (cn *Conn) SetCreatedAt(tm time.Time) { } func (cn *Conn) NetConn() net.Conn { - return cn.netConn + return cn.getNetConn() } func (p *ConnPool) CheckMinIdleConns() { diff --git a/internal/pool/hooks.go b/internal/pool/hooks.go new file mode 100644 index 0000000000..adbcfbbf94 --- /dev/null +++ b/internal/pool/hooks.go @@ -0,0 +1,114 @@ +package pool + +import ( + "context" + "sync" +) + +// PoolHook defines the interface for connection lifecycle hooks. +type PoolHook interface { + // OnGet is called when a connection is retrieved from the pool. + // It can modify the connection or return an error to prevent its use. + // It has isNewConn flag to indicate if this is a new connection (rather than idle from the pool) + // The flag can be used for gathering metrics on pool hit/miss ratio. + OnGet(ctx context.Context, conn *Conn, isNewConn bool) error + + // OnPut is called when a connection is returned to the pool. + // It returns whether the connection should be pooled and whether it should be removed. + OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) +} + +// PoolHookManager manages multiple pool hooks. +type PoolHookManager struct { + hooks []PoolHook + hooksMu sync.RWMutex +} + +// NewPoolHookManager creates a new pool hook manager. +func NewPoolHookManager() *PoolHookManager { + return &PoolHookManager{ + hooks: make([]PoolHook, 0), + } +} + +// AddHook adds a pool hook to the manager. +// Hooks are called in the order they were added. +func (phm *PoolHookManager) AddHook(hook PoolHook) { + phm.hooksMu.Lock() + defer phm.hooksMu.Unlock() + phm.hooks = append(phm.hooks, hook) +} + +// RemoveHook removes a pool hook from the manager. +func (phm *PoolHookManager) RemoveHook(hook PoolHook) { + phm.hooksMu.Lock() + defer phm.hooksMu.Unlock() + + for i, h := range phm.hooks { + if h == hook { + // Remove hook by swapping with last element and truncating + phm.hooks[i] = phm.hooks[len(phm.hooks)-1] + phm.hooks = phm.hooks[:len(phm.hooks)-1] + break + } + } +} + +// ProcessOnGet calls all OnGet hooks in order. +// If any hook returns an error, processing stops and the error is returned. +func (phm *PoolHookManager) ProcessOnGet(ctx context.Context, conn *Conn, isNewConn bool) error { + phm.hooksMu.RLock() + defer phm.hooksMu.RUnlock() + + for _, hook := range phm.hooks { + if err := hook.OnGet(ctx, conn, isNewConn); err != nil { + return err + } + } + return nil +} + +// ProcessOnPut calls all OnPut hooks in order. +// The first hook that returns shouldRemove=true or shouldPool=false will stop processing. 
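// Illustrative sketch (not part of this patch): a minimal custom PoolHook, e.g. for the
// pool hit/miss ratio mentioned in the OnGet documentation above. hitMissHook and its
// fields are made-up names; the PoolHook interface and ConnPool.AddPoolHook registration
// are the ones added by this patch.
//
//	type hitMissHook struct {
//		hits, misses atomic.Int64
//	}
//
//	func (h *hitMissHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error {
//		if isNewConn {
//			h.misses.Add(1) // no idle connection was available; a new one was dialed
//		} else {
//			h.hits.Add(1)
//		}
//		return nil // a non-nil error would make the pool discard the connection
//	}
//
//	func (h *hitMissHook) OnPut(ctx context.Context, conn *Conn) (bool, bool, error) {
//		return true, false, nil // shouldPool=true, shouldRemove=false: pool as usual
//	}
//
//	// Registration on a pool: p.AddPoolHook(&hitMissHook{})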
+func (phm *PoolHookManager) ProcessOnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) { + phm.hooksMu.RLock() + defer phm.hooksMu.RUnlock() + + shouldPool = true // Default to pooling the connection + + for _, hook := range phm.hooks { + hookShouldPool, hookShouldRemove, hookErr := hook.OnPut(ctx, conn) + + if hookErr != nil { + return false, true, hookErr + } + + // If any hook says to remove or not pool, respect that decision + if hookShouldRemove { + return false, true, nil + } + + if !hookShouldPool { + shouldPool = false + } + } + + return shouldPool, false, nil +} + +// GetHookCount returns the number of registered hooks (for testing). +func (phm *PoolHookManager) GetHookCount() int { + phm.hooksMu.RLock() + defer phm.hooksMu.RUnlock() + return len(phm.hooks) +} + +// GetHooks returns a copy of all registered hooks. +func (phm *PoolHookManager) GetHooks() []PoolHook { + phm.hooksMu.RLock() + defer phm.hooksMu.RUnlock() + + hooks := make([]PoolHook, len(phm.hooks)) + copy(hooks, phm.hooks) + return hooks +} diff --git a/internal/pool/hooks_test.go b/internal/pool/hooks_test.go new file mode 100644 index 0000000000..e6100115ce --- /dev/null +++ b/internal/pool/hooks_test.go @@ -0,0 +1,213 @@ +package pool + +import ( + "context" + "errors" + "net" + "testing" + "time" +) + +// TestHook for testing hook functionality +type TestHook struct { + OnGetCalled int + OnPutCalled int + GetError error + PutError error + ShouldPool bool + ShouldRemove bool +} + +func (th *TestHook) OnGet(ctx context.Context, conn *Conn, isNewConn bool) error { + th.OnGetCalled++ + return th.GetError +} + +func (th *TestHook) OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) { + th.OnPutCalled++ + return th.ShouldPool, th.ShouldRemove, th.PutError +} + +func TestPoolHookManager(t *testing.T) { + manager := NewPoolHookManager() + + // Test initial state + if manager.GetHookCount() != 0 { + t.Errorf("Expected 0 hooks initially, got %d", manager.GetHookCount()) + } + + // Add hooks + hook1 := &TestHook{ShouldPool: true} + hook2 := &TestHook{ShouldPool: true} + + manager.AddHook(hook1) + manager.AddHook(hook2) + + if manager.GetHookCount() != 2 { + t.Errorf("Expected 2 hooks after adding, got %d", manager.GetHookCount()) + } + + // Test ProcessOnGet + ctx := context.Background() + conn := &Conn{} // Mock connection + + err := manager.ProcessOnGet(ctx, conn, false) + if err != nil { + t.Errorf("ProcessOnGet should not error: %v", err) + } + + if hook1.OnGetCalled != 1 { + t.Errorf("Expected hook1.OnGetCalled to be 1, got %d", hook1.OnGetCalled) + } + + if hook2.OnGetCalled != 1 { + t.Errorf("Expected hook2.OnGetCalled to be 1, got %d", hook2.OnGetCalled) + } + + // Test ProcessOnPut + shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn) + if err != nil { + t.Errorf("ProcessOnPut should not error: %v", err) + } + + if !shouldPool { + t.Error("Expected shouldPool to be true") + } + + if shouldRemove { + t.Error("Expected shouldRemove to be false") + } + + if hook1.OnPutCalled != 1 { + t.Errorf("Expected hook1.OnPutCalled to be 1, got %d", hook1.OnPutCalled) + } + + if hook2.OnPutCalled != 1 { + t.Errorf("Expected hook2.OnPutCalled to be 1, got %d", hook2.OnPutCalled) + } + + // Remove a hook + manager.RemoveHook(hook1) + + if manager.GetHookCount() != 1 { + t.Errorf("Expected 1 hook after removing, got %d", manager.GetHookCount()) + } +} + +func TestHookErrorHandling(t *testing.T) { + manager := NewPoolHookManager() + + // Hook that 
returns error on Get + errorHook := &TestHook{ + GetError: errors.New("test error"), + ShouldPool: true, + } + + normalHook := &TestHook{ShouldPool: true} + + manager.AddHook(errorHook) + manager.AddHook(normalHook) + + ctx := context.Background() + conn := &Conn{} + + // Test that error stops processing + err := manager.ProcessOnGet(ctx, conn, false) + if err == nil { + t.Error("Expected error from ProcessOnGet") + } + + if errorHook.OnGetCalled != 1 { + t.Errorf("Expected errorHook.OnGetCalled to be 1, got %d", errorHook.OnGetCalled) + } + + // normalHook should not be called due to error + if normalHook.OnGetCalled != 0 { + t.Errorf("Expected normalHook.OnGetCalled to be 0, got %d", normalHook.OnGetCalled) + } +} + +func TestHookShouldRemove(t *testing.T) { + manager := NewPoolHookManager() + + // Hook that says to remove connection + removeHook := &TestHook{ + ShouldPool: false, + ShouldRemove: true, + } + + normalHook := &TestHook{ShouldPool: true} + + manager.AddHook(removeHook) + manager.AddHook(normalHook) + + ctx := context.Background() + conn := &Conn{} + + shouldPool, shouldRemove, err := manager.ProcessOnPut(ctx, conn) + if err != nil { + t.Errorf("ProcessOnPut should not error: %v", err) + } + + if shouldPool { + t.Error("Expected shouldPool to be false") + } + + if !shouldRemove { + t.Error("Expected shouldRemove to be true") + } + + if removeHook.OnPutCalled != 1 { + t.Errorf("Expected removeHook.OnPutCalled to be 1, got %d", removeHook.OnPutCalled) + } + + // normalHook should not be called due to early return + if normalHook.OnPutCalled != 0 { + t.Errorf("Expected normalHook.OnPutCalled to be 0, got %d", normalHook.OnPutCalled) + } +} + +func TestPoolWithHooks(t *testing.T) { + // Create a pool with hooks + hookManager := NewPoolHookManager() + testHook := &TestHook{ShouldPool: true} + hookManager.AddHook(testHook) + + opt := &Options{ + Dialer: func(ctx context.Context) (net.Conn, error) { + return &net.TCPConn{}, nil // Mock connection + }, + PoolSize: 1, + DialTimeout: time.Second, + } + + pool := NewConnPool(opt) + defer pool.Close() + + // Add hook to pool after creation + pool.AddPoolHook(testHook) + + // Verify hooks are initialized + if pool.hookManager == nil { + t.Error("Expected hookManager to be initialized") + } + + if pool.hookManager.GetHookCount() != 1 { + t.Errorf("Expected 1 hook in pool, got %d", pool.hookManager.GetHookCount()) + } + + // Test adding hook to pool + additionalHook := &TestHook{ShouldPool: true} + pool.AddPoolHook(additionalHook) + + if pool.hookManager.GetHookCount() != 2 { + t.Errorf("Expected 2 hooks after adding, got %d", pool.hookManager.GetHookCount()) + } + + // Test removing hook from pool + pool.RemovePoolHook(additionalHook) + + if pool.hookManager.GetHookCount() != 1 { + t.Errorf("Expected 1 hook after removing, got %d", pool.hookManager.GetHookCount()) + } +} diff --git a/internal/pool/pool.go b/internal/pool/pool.go index fa0306c3b9..43c6a81907 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -3,6 +3,7 @@ package pool import ( "context" "errors" + "log" "net" "sync" "sync/atomic" @@ -22,6 +23,12 @@ var ( // ErrPoolTimeout timed out waiting to get a connection from the connection pool. 
ErrPoolTimeout = errors.New("redis: connection pool timeout") + + popAttempts = 10 + getAttempts = 3 + minTime = time.Unix(-2208988800, 0) // Jan 1, 1900 + maxTime = minTime.Add(1<<63 - 1) + noExpiration = maxTime ) var timers = sync.Pool{ @@ -38,11 +45,14 @@ type Stats struct { Misses uint32 // number of times free connection was NOT found in the pool Timeouts uint32 // number of times a wait timeout occurred WaitCount uint32 // number of times a connection was waited + Unusable uint32 // number of times a connection was found to be unusable WaitDurationNs int64 // total time spent for waiting a connection in nanoseconds TotalConns uint32 // number of total connections in the pool IdleConns uint32 // number of idle connections in the pool StaleConns uint32 // number of stale connections removed from the pool + + PubSubStats PubSubStats } type Pooler interface { @@ -57,29 +67,27 @@ type Pooler interface { IdleLen() int Stats() *Stats + AddPoolHook(hook PoolHook) + RemovePoolHook(hook PoolHook) + Close() error } type Options struct { - Dialer func(context.Context) (net.Conn, error) - - PoolFIFO bool - PoolSize int - DialTimeout time.Duration - PoolTimeout time.Duration - MinIdleConns int - MaxIdleConns int - MaxActiveConns int - ConnMaxIdleTime time.Duration - ConnMaxLifetime time.Duration - - - // Protocol version for optimization (3 = RESP3 with push notifications, 2 = RESP2 without) - Protocol int - + Dialer func(context.Context) (net.Conn, error) ReadBufferSize int WriteBufferSize int + PoolFIFO bool + PoolSize int32 + DialTimeout time.Duration + PoolTimeout time.Duration + MinIdleConns int32 + MaxIdleConns int32 + MaxActiveConns int32 + ConnMaxIdleTime time.Duration + ConnMaxLifetime time.Duration + PushNotificationsEnabled bool } type lastDialErrorWrap struct { @@ -95,16 +103,21 @@ type ConnPool struct { queue chan struct{} connsMu sync.Mutex - conns []*Conn + conns map[uint64]*Conn idleConns []*Conn - poolSize int - idleConnsLen int + poolSize atomic.Int32 + idleConnsLen atomic.Int32 + idleCheckInProgress atomic.Bool stats Stats waitDurationNs atomic.Int64 _closed uint32 // atomic + + // Pool hooks manager for flexible connection processing + hookManagerMu sync.RWMutex + hookManager *PoolHookManager } var _ Pooler = (*ConnPool)(nil) @@ -114,34 +127,69 @@ func NewConnPool(opt *Options) *ConnPool { cfg: opt, queue: make(chan struct{}, opt.PoolSize), - conns: make([]*Conn, 0, opt.PoolSize), + conns: make(map[uint64]*Conn), idleConns: make([]*Conn, 0, opt.PoolSize), } - p.connsMu.Lock() - p.checkMinIdleConns() - p.connsMu.Unlock() + // Only create MinIdleConns if explicitly requested (> 0) + // This avoids creating connections during pool initialization for tests + if opt.MinIdleConns > 0 { + p.connsMu.Lock() + p.checkMinIdleConns() + p.connsMu.Unlock() + } return p } +// initializeHooks sets up the pool hooks system. +func (p *ConnPool) initializeHooks() { + p.hookManager = NewPoolHookManager() +} + +// AddPoolHook adds a pool hook to the pool. +func (p *ConnPool) AddPoolHook(hook PoolHook) { + p.hookManagerMu.Lock() + defer p.hookManagerMu.Unlock() + + if p.hookManager == nil { + p.initializeHooks() + } + p.hookManager.AddHook(hook) +} + +// RemovePoolHook removes a pool hook from the pool. 
+func (p *ConnPool) RemovePoolHook(hook PoolHook) { + p.hookManagerMu.Lock() + defer p.hookManagerMu.Unlock() + + if p.hookManager != nil { + p.hookManager.RemoveHook(hook) + } +} + func (p *ConnPool) checkMinIdleConns() { + if !p.idleCheckInProgress.CompareAndSwap(false, true) { + return + } + defer p.idleCheckInProgress.Store(false) + if p.cfg.MinIdleConns == 0 { return } - for p.poolSize < p.cfg.PoolSize && p.idleConnsLen < p.cfg.MinIdleConns { + + // Only create idle connections if we haven't reached the total pool size limit + // MinIdleConns should be a subset of PoolSize, not additional connections + for p.poolSize.Load() < p.cfg.PoolSize && p.idleConnsLen.Load() < p.cfg.MinIdleConns { select { case p.queue <- struct{}{}: - p.poolSize++ - p.idleConnsLen++ - + p.poolSize.Add(1) + p.idleConnsLen.Add(1) go func() { defer func() { if err := recover(); err != nil { - p.connsMu.Lock() - p.poolSize-- - p.idleConnsLen-- - p.connsMu.Unlock() + p.poolSize.Add(-1) + p.idleConnsLen.Add(-1) p.freeTurn() internal.Logger.Printf(context.Background(), "addIdleConn panic: %+v", err) @@ -150,12 +198,9 @@ func (p *ConnPool) checkMinIdleConns() { err := p.addIdleConn() if err != nil && err != ErrClosed { - p.connsMu.Lock() - p.poolSize-- - p.idleConnsLen-- - p.connsMu.Unlock() + p.poolSize.Add(-1) + p.idleConnsLen.Add(-1) } - p.freeTurn() }() default: @@ -172,6 +217,9 @@ func (p *ConnPool) addIdleConn() error { if err != nil { return err } + // Mark connection as usable after successful creation + // This is essential for normal pool operations + cn.SetUsable(true) p.connsMu.Lock() defer p.connsMu.Unlock() @@ -182,11 +230,15 @@ func (p *ConnPool) addIdleConn() error { return ErrClosed } - p.conns = append(p.conns, cn) + p.conns[cn.GetID()] = cn p.idleConns = append(p.idleConns, cn) return nil } +// NewConn creates a new connection and returns it to the user. +// This will still obey MaxActiveConns but will not include it in the pool and won't increase the pool size. +// +// NOTE: If you directly get a connection from the pool, it won't be pooled and won't support hitless upgrades. func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) { return p.newConn(ctx, false) } @@ -196,33 +248,42 @@ func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) { return nil, ErrClosed } - p.connsMu.Lock() - if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns { - p.connsMu.Unlock() + if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() >= int32(p.cfg.MaxActiveConns) { return nil, ErrPoolExhausted } - p.connsMu.Unlock() cn, err := p.dialConn(ctx, pooled) if err != nil { return nil, err } + // Mark connection as usable after successful creation + // This is essential for normal pool operations + cn.SetUsable(true) + + if p.cfg.MaxActiveConns > 0 && p.poolSize.Load() > int32(p.cfg.MaxActiveConns) { + _ = cn.Close() + return nil, ErrPoolExhausted + } p.connsMu.Lock() defer p.connsMu.Unlock() - - if p.cfg.MaxActiveConns > 0 && p.poolSize >= p.cfg.MaxActiveConns { + if p.closed() { _ = cn.Close() - return nil, ErrPoolExhausted + return nil, ErrClosed + } + // Check if pool was closed while we were waiting for the lock + if p.conns == nil { + p.conns = make(map[uint64]*Conn) } + p.conns[cn.GetID()] = cn - p.conns = append(p.conns, cn) if pooled { // If pool is full remove the cn on next Put. 
- if p.poolSize >= p.cfg.PoolSize { + currentPoolSize := p.poolSize.Load() + if currentPoolSize >= int32(p.cfg.PoolSize) { cn.pooled = false } else { - p.poolSize++ + p.poolSize.Add(1) } } @@ -249,6 +310,12 @@ func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) { cn := NewConnWithBufferSize(netConn, p.cfg.ReadBufferSize, p.cfg.WriteBufferSize) cn.pooled = pooled + if p.cfg.ConnMaxLifetime > 0 { + cn.expiresAt = time.Now().Add(p.cfg.ConnMaxLifetime) + } else { + cn.expiresAt = noExpiration + } + return cn, nil } @@ -289,6 +356,14 @@ func (p *ConnPool) getLastDialError() error { // Get returns existed connection from the pool or creates a new one. func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { + return p.getConn(ctx) +} + +// getConn returns a connection from the pool. +func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { + var cn *Conn + var err error + if p.closed() { return nil, ErrClosed } @@ -297,9 +372,17 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { return nil, err } + now := time.Now() + attempts := 0 for { + if attempts >= getAttempts { + log.Printf("redis: connection pool: failed to get an connection accepted by hook after %d attempts", attempts) + break + } + attempts++ + p.connsMu.Lock() - cn, err := p.popIdle() + cn, err = p.popIdle() p.connsMu.Unlock() if err != nil { @@ -311,11 +394,25 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { break } - if !p.isHealthyConn(cn) { + if !p.isHealthyConn(cn, now) { _ = p.CloseConn(cn) continue } + // Process connection using the hooks system + p.hookManagerMu.RLock() + hookManager := p.hookManager + p.hookManagerMu.RUnlock() + + if hookManager != nil { + if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil { + log.Printf("redis: connection pool: failed to process idle connection by hook: %v", err) + // Failed to process connection, discard it + _ = p.CloseConn(cn) + continue + } + } + atomic.AddUint32(&p.stats.Hits, 1) return cn, nil } @@ -328,6 +425,20 @@ func (p *ConnPool) Get(ctx context.Context) (*Conn, error) { return nil, err } + // Process connection using the hooks system + p.hookManagerMu.RLock() + hookManager := p.hookManager + p.hookManagerMu.RUnlock() + + if hookManager != nil { + if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil { + // Failed to process connection, discard it + log.Printf("redis: connection pool: failed to process new connection by hook: %v", err) + _ = p.CloseConn(newcn) + return nil, err + } + } + return newcn, nil } @@ -356,7 +467,7 @@ func (p *ConnPool) waitTurn(ctx context.Context) error { } return ctx.Err() case p.queue <- struct{}{}: - p.waitDurationNs.Add(time.Since(start).Nanoseconds()) + p.waitDurationNs.Add(time.Now().UnixNano() - start.UnixNano()) atomic.AddUint32(&p.stats.WaitCount, 1) if !timer.Stop() { <-timer.C @@ -376,68 +487,128 @@ func (p *ConnPool) popIdle() (*Conn, error) { if p.closed() { return nil, ErrClosed } + n := len(p.idleConns) if n == 0 { return nil, nil } var cn *Conn - if p.cfg.PoolFIFO { - cn = p.idleConns[0] - copy(p.idleConns, p.idleConns[1:]) - p.idleConns = p.idleConns[:n-1] - } else { - idx := n - 1 - cn = p.idleConns[idx] - p.idleConns = p.idleConns[:idx] + attempts := 0 + + for attempts < popAttempts { + if len(p.idleConns) == 0 { + return nil, nil + } + + if p.cfg.PoolFIFO { + cn = p.idleConns[0] + copy(p.idleConns, p.idleConns[1:]) + p.idleConns = p.idleConns[:len(p.idleConns)-1] + } else { + idx := len(p.idleConns) - 1 + cn = p.idleConns[idx] + p.idleConns = 
p.idleConns[:idx] + } + attempts++ + + if cn.IsUsable() { + p.idleConnsLen.Add(-1) + break + } + + // Connection is not usable, put it back in the pool + if p.cfg.PoolFIFO { + // FIFO: put at end (will be picked up last since we pop from front) + p.idleConns = append(p.idleConns, cn) + } else { + // LIFO: put at beginning (will be picked up last since we pop from end) + p.idleConns = append([]*Conn{cn}, p.idleConns...) + } } - p.idleConnsLen-- + + // If we exhausted all attempts without finding a usable connection, return nil + if attempts >= popAttempts { + log.Printf("redis: connection pool: failed to get an usable connection after %d attempts", popAttempts) + return nil, nil + } + p.checkMinIdleConns() return cn, nil } func (p *ConnPool) Put(ctx context.Context, cn *Conn) { + // Process connection using the hooks system + shouldPool := true shouldRemove := false - if cn.rd.Buffered() > 0 { - // Check if this might be push notification data - if p.cfg.Protocol == 3 { - // we know that there is something in the buffer, so peek at the next reply type without - // the potential to block and check if it's a push notification - if replyType, err := cn.rd.PeekReplyType(); err != nil || replyType != proto.RespPush { - shouldRemove = true - } - } else { - // not a push notification since protocol 2 doesn't support them - shouldRemove = true + var err error + + if cn.HasBufferedData() { + // Peek at the reply type to check if it's a push notification + if replyType, err := cn.PeekReplyTypeSafe(); err != nil || replyType != proto.RespPush { + // Not a push notification or error peeking, remove connection + internal.Logger.Printf(ctx, "Conn has unread data (not push notification), removing it") + p.Remove(ctx, cn, err) } + // It's a push notification, allow pooling (client will handle it) + } + + p.hookManagerMu.RLock() + hookManager := p.hookManager + p.hookManagerMu.RUnlock() - if shouldRemove { - // For non-RESP3 or data that is not a push notification, buffered data is unexpected - internal.Logger.Printf(ctx, "Conn has unread data, closing it") - p.Remove(ctx, cn, BadConnError{}) + if hookManager != nil { + shouldPool, shouldRemove, err = hookManager.ProcessOnPut(ctx, cn) + if err != nil { + internal.Logger.Printf(ctx, "Connection hook error: %v", err) + p.Remove(ctx, cn, err) return } } + // If hooks say to remove the connection, do so + if shouldRemove { + p.Remove(ctx, cn, errors.New("hook requested removal")) + return + } + + // If processor says not to pool the connection, remove it + if !shouldPool { + p.Remove(ctx, cn, errors.New("hook requested no pooling")) + return + } + if !cn.pooled { - p.Remove(ctx, cn, nil) + p.Remove(ctx, cn, errors.New("connection not pooled")) return } var shouldCloseConn bool - p.connsMu.Lock() - - if p.cfg.MaxIdleConns == 0 || p.idleConnsLen < p.cfg.MaxIdleConns { - p.idleConns = append(p.idleConns, cn) - p.idleConnsLen++ + if p.cfg.MaxIdleConns == 0 || p.idleConnsLen.Load() < p.cfg.MaxIdleConns { + // unusable conns are expected to become usable at some point (background process is reconnecting them) + // put them at the opposite end of the queue + if !cn.IsUsable() { + if p.cfg.PoolFIFO { + p.connsMu.Lock() + p.idleConns = append(p.idleConns, cn) + p.connsMu.Unlock() + } else { + p.connsMu.Lock() + p.idleConns = append([]*Conn{cn}, p.idleConns...) 
+ p.connsMu.Unlock() + } + } else { + p.connsMu.Lock() + p.idleConns = append(p.idleConns, cn) + p.connsMu.Unlock() + } + p.idleConnsLen.Add(1) } else { - p.removeConn(cn) + p.removeConnWithLock(cn) shouldCloseConn = true } - p.connsMu.Unlock() - p.freeTurn() if shouldCloseConn { @@ -449,6 +620,9 @@ func (p *ConnPool) Remove(_ context.Context, cn *Conn, reason error) { p.removeConnWithLock(cn) p.freeTurn() _ = p.closeConn(cn) + + // Check if we need to create new idle connections to maintain MinIdleConns + p.checkMinIdleConns() } func (p *ConnPool) CloseConn(cn *Conn) error { @@ -463,17 +637,13 @@ func (p *ConnPool) removeConnWithLock(cn *Conn) { } func (p *ConnPool) removeConn(cn *Conn) { - for i, c := range p.conns { - if c == cn { - p.conns = append(p.conns[:i], p.conns[i+1:]...) - if cn.pooled { - p.poolSize-- - p.checkMinIdleConns() - } - break - } - } + delete(p.conns, cn.GetID()) atomic.AddUint32(&p.stats.StaleConns, 1) + + // Decrement pool size counter when removing a connection + if cn.pooled { + p.poolSize.Add(-1) + } } func (p *ConnPool) closeConn(cn *Conn) error { @@ -491,9 +661,9 @@ func (p *ConnPool) Len() int { // IdleLen returns number of idle connections. func (p *ConnPool) IdleLen() int { p.connsMu.Lock() - n := p.idleConnsLen + n := p.idleConnsLen.Load() p.connsMu.Unlock() - return n + return int(n) } func (p *ConnPool) Stats() *Stats { @@ -502,6 +672,7 @@ func (p *ConnPool) Stats() *Stats { Misses: atomic.LoadUint32(&p.stats.Misses), Timeouts: atomic.LoadUint32(&p.stats.Timeouts), WaitCount: atomic.LoadUint32(&p.stats.WaitCount), + Unusable: atomic.LoadUint32(&p.stats.Unusable), WaitDurationNs: p.waitDurationNs.Load(), TotalConns: uint32(p.Len()), @@ -542,30 +713,32 @@ func (p *ConnPool) Close() error { } } p.conns = nil - p.poolSize = 0 + p.poolSize.Store(0) p.idleConns = nil - p.idleConnsLen = 0 + p.idleConnsLen.Store(0) p.connsMu.Unlock() return firstErr } -func (p *ConnPool) isHealthyConn(cn *Conn) bool { - now := time.Now() - - if p.cfg.ConnMaxLifetime > 0 && now.Sub(cn.createdAt) >= p.cfg.ConnMaxLifetime { +func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { + // slight optimization, check expiresAt first. 
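The health check below relies on the expiry deadline that dialConn now precomputes from ConnMaxLifetime (or the noExpiration sentinel defined at the top of this file), so the hot path is a single time comparison instead of recomputing now.Sub(createdAt). A self-contained sketch of that bookkeeping, using simplified types rather than the pool's own:

package main

import (
    "fmt"
    "time"
)

// noExpiration matches the sentinel used in this file:
// minTime (Jan 1, 1900) plus the maximum time.Duration.
var noExpiration = time.Unix(-2208988800, 0).Add(1<<63 - 1)

type conn struct{ expiresAt time.Time }

// newConn computes the expiry deadline once, at dial time.
func newConn(connMaxLifetime time.Duration) *conn {
    cn := &conn{}
    if connMaxLifetime > 0 {
        cn.expiresAt = time.Now().Add(connMaxLifetime)
    } else {
        cn.expiresAt = noExpiration
    }
    return cn
}

// healthy is the cheap hot-path check: expired connections are rejected.
func (cn *conn) healthy(now time.Time) bool {
    return !cn.expiresAt.Before(now)
}

func main() {
    cn := newConn(30 * time.Minute)
    fmt.Println(cn.healthy(time.Now())) // true
}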
+ if cn.expiresAt.Before(now) { return false } + if p.cfg.ConnMaxIdleTime > 0 && now.Sub(cn.UsedAt()) >= p.cfg.ConnMaxIdleTime { return false } - // Check connection health, but be aware of push notifications - if err := connCheck(cn.netConn); err != nil { + cn.SetUsedAt(now) + // Check basic connection health + // Use GetNetConn() to safely access netConn and avoid data races + if err := connCheck(cn.getNetConn()); err != nil { // If there's unexpected data, it might be push notifications (RESP3) // However, push notification processing is now handled by the client // before WithReader to ensure proper context is available to handlers - if err == errUnexpectedRead && p.cfg.Protocol == 3 { + if p.cfg.PushNotificationsEnabled && err == errUnexpectedRead { // we know that there is something in the buffer, so peek at the next reply type without // the potential to block if replyType, err := cn.rd.PeekReplyType(); err == nil && replyType == proto.RespPush { @@ -579,7 +752,7 @@ func (p *ConnPool) isHealthyConn(cn *Conn) bool { return false } } - - cn.SetUsedAt(now) return true } + + diff --git a/internal/pool/pool_single.go b/internal/pool/pool_single.go index 5a3fde191b..136d6f2dd8 100644 --- a/internal/pool/pool_single.go +++ b/internal/pool/pool_single.go @@ -1,6 +1,8 @@ package pool -import "context" +import ( + "context" +) type SingleConnPool struct { pool Pooler @@ -56,3 +58,7 @@ func (p *SingleConnPool) IdleLen() int { func (p *SingleConnPool) Stats() *Stats { return &Stats{} } + +func (p *SingleConnPool) AddPoolHook(hook PoolHook) {} + +func (p *SingleConnPool) RemovePoolHook(hook PoolHook) {} diff --git a/internal/pool/pool_sticky.go b/internal/pool/pool_sticky.go index 3adb99bc82..dc4266a4fc 100644 --- a/internal/pool/pool_sticky.go +++ b/internal/pool/pool_sticky.go @@ -199,3 +199,7 @@ func (p *StickyConnPool) IdleLen() int { func (p *StickyConnPool) Stats() *Stats { return &Stats{} } + +func (p *StickyConnPool) AddPoolHook(hook PoolHook) {} + +func (p *StickyConnPool) RemovePoolHook(hook PoolHook) {} diff --git a/internal/pool/pool_test.go b/internal/pool/pool_test.go index 736323d9dd..01cda618d1 100644 --- a/internal/pool/pool_test.go +++ b/internal/pool/pool_test.go @@ -2,6 +2,7 @@ package pool_test import ( "context" + "errors" "net" "sync" "testing" @@ -20,7 +21,7 @@ var _ = Describe("ConnPool", func() { BeforeEach(func() { connPool = pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: 10, + PoolSize: int32(10), PoolTimeout: time.Hour, DialTimeout: 1 * time.Second, ConnMaxIdleTime: time.Millisecond, @@ -45,11 +46,11 @@ var _ = Describe("ConnPool", func() { <-closedChan return &net.TCPConn{}, nil }, - PoolSize: 10, + PoolSize: int32(10), PoolTimeout: time.Hour, DialTimeout: 1 * time.Second, ConnMaxIdleTime: time.Millisecond, - MinIdleConns: minIdleConns, + MinIdleConns: int32(minIdleConns), }) wg.Wait() Expect(connPool.Close()).NotTo(HaveOccurred()) @@ -105,7 +106,7 @@ var _ = Describe("ConnPool", func() { // ok } - connPool.Remove(ctx, cn, nil) + connPool.Remove(ctx, cn, errors.New("test")) // Check that Get is unblocked. 
select { @@ -130,8 +131,8 @@ var _ = Describe("MinIdleConns", func() { newConnPool := func() *pool.ConnPool { connPool := pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: poolSize, - MinIdleConns: minIdleConns, + PoolSize: int32(poolSize), + MinIdleConns: int32(minIdleConns), PoolTimeout: 100 * time.Millisecond, DialTimeout: 1 * time.Second, ConnMaxIdleTime: -1, @@ -168,7 +169,7 @@ var _ = Describe("MinIdleConns", func() { Context("after Remove", func() { BeforeEach(func() { - connPool.Remove(ctx, cn, nil) + connPool.Remove(ctx, cn, errors.New("test")) }) It("has idle connections", func() { @@ -245,7 +246,7 @@ var _ = Describe("MinIdleConns", func() { BeforeEach(func() { perform(len(cns), func(i int) { mu.RLock() - connPool.Remove(ctx, cns[i], nil) + connPool.Remove(ctx, cns[i], errors.New("test")) mu.RUnlock() }) @@ -309,7 +310,7 @@ var _ = Describe("race", func() { It("does not happen on Get, Put, and Remove", func() { connPool = pool.NewConnPool(&pool.Options{ Dialer: dummyDialer, - PoolSize: 10, + PoolSize: int32(10), PoolTimeout: time.Minute, DialTimeout: 1 * time.Second, ConnMaxIdleTime: time.Millisecond, @@ -328,7 +329,7 @@ var _ = Describe("race", func() { cn, err := connPool.Get(ctx) Expect(err).NotTo(HaveOccurred()) if err == nil { - connPool.Remove(ctx, cn, nil) + connPool.Remove(ctx, cn, errors.New("test")) } } }) @@ -339,15 +340,15 @@ var _ = Describe("race", func() { Dialer: func(ctx context.Context) (net.Conn, error) { return &net.TCPConn{}, nil }, - PoolSize: 1000, - MinIdleConns: 50, + PoolSize: int32(1000), + MinIdleConns: int32(50), PoolTimeout: 3 * time.Second, DialTimeout: 1 * time.Second, } p := pool.NewConnPool(opt) var wg sync.WaitGroup - for i := 0; i < opt.PoolSize; i++ { + for i := int32(0); i < opt.PoolSize; i++ { wg.Add(1) go func() { defer wg.Done() @@ -366,8 +367,8 @@ var _ = Describe("race", func() { Dialer: func(ctx context.Context) (net.Conn, error) { panic("test panic") }, - PoolSize: 100, - MinIdleConns: 30, + PoolSize: int32(100), + MinIdleConns: int32(30), } p := pool.NewConnPool(opt) @@ -377,14 +378,14 @@ var _ = Describe("race", func() { state := p.Stats() return state.TotalConns == 0 && state.IdleConns == 0 && p.QueueLen() == 0 }, "3s", "50ms").Should(BeTrue()) - }) - + }) + It("wait", func() { opt := &pool.Options{ Dialer: func(ctx context.Context) (net.Conn, error) { return &net.TCPConn{}, nil }, - PoolSize: 1, + PoolSize: int32(1), PoolTimeout: 3 * time.Second, } p := pool.NewConnPool(opt) @@ -415,7 +416,7 @@ var _ = Describe("race", func() { return &net.TCPConn{}, nil }, - PoolSize: 1, + PoolSize: int32(1), PoolTimeout: testPoolTimeout, } p := pool.NewConnPool(opt) diff --git a/internal/pool/pubsub.go b/internal/pool/pubsub.go new file mode 100644 index 0000000000..a06abcd6b8 --- /dev/null +++ b/internal/pool/pubsub.go @@ -0,0 +1,77 @@ +package pool + +import ( + "context" + "net" + "sync" + "sync/atomic" +) + +type PubSubStats struct { + Created uint32 + Untracked uint32 + Active uint32 +} + +// PubSubPool manages a pool of PubSub connections. 
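The PubSubPool defined just below keeps dedicated, individually tracked connections for PubSub instead of borrowing from the main pool. A usage sketch based on the methods introduced in this file; note that internal/pool is only importable from within go-redis itself, and the dialer and option values here are placeholders:

package main

import (
    "context"
    "fmt"
    "net"

    "github.com/redis/go-redis/v9/internal/pool"
)

func main() {
    ctx := context.Background()

    dialer := func(ctx context.Context, network, addr string) (net.Conn, error) {
        return net.Dial(network, addr)
    }

    p := pool.NewPubSubPool(&pool.Options{
        ReadBufferSize:  32 * 1024,
        WriteBufferSize: 32 * 1024,
    }, dialer)

    cn, err := p.NewConn(ctx, "tcp", "localhost:6379", []string{"news"})
    if err != nil {
        panic(err)
    }
    p.TrackConn(cn) // counted in PubSubStats.Active until untracked

    fmt.Printf("pubsub stats: %+v\n", *p.Stats())

    p.UntrackConn(cn)
    _ = cn.Close()
    _ = p.Close()
}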
+type PubSubPool struct { + opt *Options + netDialer func(ctx context.Context, network, addr string) (net.Conn, error) + + // Map to track active PubSub connections + activeConns sync.Map // map[uint64]*Conn (connID -> conn) + closed atomic.Bool + stats PubSubStats +} + +func NewPubSubPool(opt *Options, netDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *PubSubPool { + return &PubSubPool{ + opt: opt, + netDialer: netDialer, + } +} + +func (p *PubSubPool) NewConn(ctx context.Context, network string, addr string, channels []string) (*Conn, error) { + if p.closed.Load() { + return nil, ErrClosed + } + + netConn, err := p.netDialer(ctx, network, addr) + if err != nil { + return nil, err + } + cn := NewConnWithBufferSize(netConn, p.opt.ReadBufferSize, p.opt.WriteBufferSize) + atomic.AddUint32(&p.stats.Created, 1) + return cn, nil + +} + +func (p *PubSubPool) TrackConn(cn *Conn) { + atomic.AddUint32(&p.stats.Active, 1) + p.activeConns.Store(cn.GetID(), cn) +} + +func (p *PubSubPool) UntrackConn(cn *Conn) { + atomic.AddUint32(&p.stats.Active, ^uint32(0)) + atomic.AddUint32(&p.stats.Untracked, 1) + p.activeConns.Delete(cn.GetID()) +} + +func (p *PubSubPool) Close() error { + p.closed.Store(true) + p.activeConns.Range(func(key, value interface{}) bool { + cn := value.(*Conn) + _ = cn.Close() + return true + }) + return nil +} + +func (p *PubSubPool) Stats() *PubSubStats { + // load stats atomically + return &PubSubStats{ + Created: atomic.LoadUint32(&p.stats.Created), + Untracked: atomic.LoadUint32(&p.stats.Untracked), + Active: atomic.LoadUint32(&p.stats.Active), + } +} diff --git a/internal/redis.go b/internal/redis.go new file mode 100644 index 0000000000..0459e42ba9 --- /dev/null +++ b/internal/redis.go @@ -0,0 +1,3 @@ +package internal + +const RedisNull = "null" diff --git a/internal/util/math.go b/internal/util/math.go new file mode 100644 index 0000000000..e707c47a64 --- /dev/null +++ b/internal/util/math.go @@ -0,0 +1,17 @@ +package util + +// Max returns the maximum of two integers +func Max(a, b int) int { + if a > b { + return a + } + return b +} + +// Min returns the minimum of two integers +func Min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/options.go b/options.go index 237be6be0f..3c5d364cef 100644 --- a/options.go +++ b/options.go @@ -14,9 +14,10 @@ import ( "time" "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/hitless" "github.com/redis/go-redis/v9/internal/pool" - "github.com/redis/go-redis/v9/push" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/push" ) // Limiter is the interface of a rate limiter or a circuit breaker. @@ -153,6 +154,7 @@ type Options struct { // // Note that FIFO has slightly higher overhead compared to LIFO, // but it helps closing idle connections faster reducing the pool size. + // default: false PoolFIFO bool // PoolSize is the base number of socket connections. @@ -244,8 +246,19 @@ type Options struct { // When a node is marked as failing, it will be avoided for this duration. // Default is 15 seconds. FailingTimeoutSeconds int + + // HitlessUpgradeConfig provides custom configuration for hitless upgrades. + // When HitlessUpgradeConfig.Mode is not "disabled", the client will handle + // cluster upgrade notifications gracefully and manage connection/pool state + // transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications. + // If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it. 
+ HitlessUpgradeConfig *HitlessUpgradeConfig } +// HitlessUpgradeConfig provides configuration options for hitless upgrades. +// This is an alias to hitless.Config for convenience. +type HitlessUpgradeConfig = hitless.Config + func (opt *Options) init() { if opt.Addr == "" { opt.Addr = "localhost:6379" @@ -320,13 +333,36 @@ func (opt *Options) init() { case 0: opt.MaxRetryBackoff = 512 * time.Millisecond } + + opt.HitlessUpgradeConfig = opt.HitlessUpgradeConfig.ApplyDefaultsWithPoolSize(opt.PoolSize) + + // auto-detect endpoint type if not specified + endpointType := opt.HitlessUpgradeConfig.EndpointType + if endpointType == "" || endpointType == hitless.EndpointTypeAuto { + // Auto-detect endpoint type if not specified + endpointType = hitless.DetectEndpointType(opt.Addr, opt.TLSConfig != nil) + } + opt.HitlessUpgradeConfig.EndpointType = endpointType } func (opt *Options) clone() *Options { clone := *opt + + // Deep clone HitlessUpgradeConfig to avoid sharing between clients + if opt.HitlessUpgradeConfig != nil { + configClone := *opt.HitlessUpgradeConfig + clone.HitlessUpgradeConfig = &configClone + } + return &clone } +// NewDialer returns a function that will be used as the default dialer +// when none is specified in Options.Dialer. +func (opt *Options) NewDialer() func(context.Context, string, string) (net.Conn, error) { + return NewDialer(opt) +} + // NewDialer returns a function that will be used as the default dialer // when none is specified in Options.Dialer. func NewDialer(opt *Options) func(context.Context, string, string) (net.Conn, error) { @@ -617,18 +653,35 @@ func newConnPool( Dialer: func(ctx context.Context) (net.Conn, error) { return dialer(ctx, opt.Network, opt.Addr) }, - PoolFIFO: opt.PoolFIFO, - PoolSize: opt.PoolSize, - PoolTimeout: opt.PoolTimeout, - DialTimeout: opt.DialTimeout, - MinIdleConns: opt.MinIdleConns, - MaxIdleConns: opt.MaxIdleConns, - MaxActiveConns: opt.MaxActiveConns, - ConnMaxIdleTime: opt.ConnMaxIdleTime, - ConnMaxLifetime: opt.ConnMaxLifetime, - // Pass protocol version for push notification optimization - Protocol: opt.Protocol, - ReadBufferSize: opt.ReadBufferSize, - WriteBufferSize: opt.WriteBufferSize, + PoolFIFO: opt.PoolFIFO, + PoolSize: int32(opt.PoolSize), + PoolTimeout: opt.PoolTimeout, + DialTimeout: opt.DialTimeout, + MinIdleConns: int32(opt.MinIdleConns), + MaxIdleConns: int32(opt.MaxIdleConns), + MaxActiveConns: int32(opt.MaxActiveConns), + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, + ReadBufferSize: opt.ReadBufferSize, + WriteBufferSize: opt.WriteBufferSize, + PushNotificationsEnabled: opt.Protocol == 3, }) } + +func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error), +) *pool.PubSubPool { + return pool.NewPubSubPool(&pool.Options{ + PoolFIFO: opt.PoolFIFO, + PoolSize: int32(opt.PoolSize), + PoolTimeout: opt.PoolTimeout, + DialTimeout: opt.DialTimeout, + MinIdleConns: int32(opt.MinIdleConns), + MaxIdleConns: int32(opt.MaxIdleConns), + MaxActiveConns: int32(opt.MaxActiveConns), + ConnMaxIdleTime: opt.ConnMaxIdleTime, + ConnMaxLifetime: opt.ConnMaxLifetime, + ReadBufferSize: 32 * 1024, + WriteBufferSize: 32 * 1024, + PushNotificationsEnabled: opt.Protocol == 3, + }, dialer) +} diff --git a/osscluster.go b/osscluster.go index ec77a95cde..63c7481961 100644 --- a/osscluster.go +++ b/osscluster.go @@ -38,6 +38,7 @@ type ClusterOptions struct { ClientName string // NewClient creates a cluster node client with provided name and options. 
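With the option wiring above, hitless upgrades default to "auto" mode and the endpoint type is auto-detected from the address and TLS settings. A sketch of opting in explicitly; Mode and EndpointType use the constants referenced elsewhere in this patch (hitless.MaintNotificationsEnabled, hitless.EndpointTypeAuto), and any other hitless.Config fields are left to their defaults since they are not shown here:

package main

import (
    "github.com/redis/go-redis/v9"
    "github.com/redis/go-redis/v9/hitless"
)

func main() {
    rdb := redis.NewClient(&redis.Options{
        Addr:     "localhost:6379",
        Protocol: 3, // RESP3 is required for push notifications
        HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
            // Force the handshake; in the default "auto" mode the client
            // silently falls back when the server lacks support.
            Mode:         hitless.MaintNotificationsEnabled,
            EndpointType: hitless.EndpointTypeAuto,
        },
    })
    defer rdb.Close()
}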
+ // If NewClient is set by the user, the user is responsible for handling hitless upgrades and push notifications. NewClient func(opt *Options) *Client // The maximum number of retries before giving up. Command is retried @@ -129,6 +130,14 @@ type ClusterOptions struct { // When a node is marked as failing, it will be avoided for this duration. // Default is 15 seconds. FailingTimeoutSeconds int + + // HitlessUpgradeConfig provides custom configuration for hitless upgrades. + // When HitlessUpgradeConfig.Mode is not "disabled", the client will handle + // cluster upgrade notifications gracefully and manage connection/pool state + // transitions seamlessly. Requires Protocol: 3 (RESP3) for push notifications. + // If nil, hitless upgrades are in "auto" mode and will be enabled if the server supports it. + // The ClusterClient does not directly work with hitless, it is up to the clients in the Nodes map to work with hitless. + HitlessUpgradeConfig *HitlessUpgradeConfig } func (opt *ClusterOptions) init() { @@ -319,6 +328,13 @@ func setupClusterQueryParams(u *url.URL, o *ClusterOptions) (*ClusterOptions, er } func (opt *ClusterOptions) clientOptions() *Options { + // Clone HitlessUpgradeConfig to avoid sharing between cluster node clients + var hitlessConfig *HitlessUpgradeConfig + if opt.HitlessUpgradeConfig != nil { + configClone := *opt.HitlessUpgradeConfig + hitlessConfig = &configClone + } + return &Options{ ClientName: opt.ClientName, Dialer: opt.Dialer, @@ -360,8 +376,9 @@ func (opt *ClusterOptions) clientOptions() *Options { // much use for ClusterSlots config). This means we cannot execute the // READONLY command against that node -- setting readOnly to false in such // situations in the options below will prevent that from happening. - readOnly: opt.ReadOnly && opt.ClusterSlots == nil, - UnstableResp3: opt.UnstableResp3, + readOnly: opt.ReadOnly && opt.ClusterSlots == nil, + UnstableResp3: opt.UnstableResp3, + HitlessUpgradeConfig: hitlessConfig, } } @@ -1830,12 +1847,12 @@ func (c *ClusterClient) Watch(ctx context.Context, fn func(*Tx) error, keys ...s return err } +// hitless won't work here for now func (c *ClusterClient) pubSub() *PubSub { var node *clusterNode pubsub := &PubSub{ opt: c.opt.clientOptions(), - - newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { + newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) { if node != nil { panic("node != nil") } @@ -1850,18 +1867,25 @@ func (c *ClusterClient) pubSub() *PubSub { if err != nil { return nil, err } - - cn, err := node.Client.newConn(context.TODO()) + cn, err := node.Client.pubSubPool.NewConn(ctx, node.Client.opt.Network, node.Client.opt.Addr, channels) if err != nil { node = nil - return nil, err } - + // will return nil if already initialized + err = node.Client.initConn(ctx, cn) + if err != nil { + _ = cn.Close() + node = nil + return nil, err + } + node.Client.pubSubPool.TrackConn(cn) return cn, nil }, closeConn: func(cn *pool.Conn) error { - err := node.Client.connPool.CloseConn(cn) + // Untrack connection from PubSubPool + node.Client.pubSubPool.UntrackConn(cn) + err := cn.Close() node = nil return err }, diff --git a/pool_pubsub_bench_test.go b/pool_pubsub_bench_test.go new file mode 100644 index 0000000000..0db8ec55fa --- /dev/null +++ b/pool_pubsub_bench_test.go @@ -0,0 +1,375 @@ +// Pool and PubSub Benchmark Suite +// +// This file contains comprehensive benchmarks for both pool operations and PubSub initialization. 
+// It's designed to be run against different branches to compare performance. +// +// Usage Examples: +// # Run all benchmarks +// go test -bench=. -run='^$' -benchtime=1s pool_pubsub_bench_test.go +// +// # Run only pool benchmarks +// go test -bench=BenchmarkPool -run='^$' pool_pubsub_bench_test.go +// +// # Run only PubSub benchmarks +// go test -bench=BenchmarkPubSub -run='^$' pool_pubsub_bench_test.go +// +// # Compare between branches +// git checkout branch1 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch1.txt +// git checkout branch2 && go test -bench=. -run='^$' pool_pubsub_bench_test.go > branch2.txt +// benchcmp branch1.txt branch2.txt +// +// # Run with memory profiling +// go test -bench=BenchmarkPoolGetPut -run='^$' -memprofile=mem.prof pool_pubsub_bench_test.go +// +// # Run with CPU profiling +// go test -bench=BenchmarkPoolGetPut -run='^$' -cpuprofile=cpu.prof pool_pubsub_bench_test.go + +package redis_test + +import ( + "context" + "fmt" + "net" + "sync" + "testing" + "time" + + "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/internal/pool" +) + +// dummyDialer creates a mock connection for benchmarking +func dummyDialer(ctx context.Context) (net.Conn, error) { + return &dummyConn{}, nil +} + +// dummyConn implements net.Conn for benchmarking +type dummyConn struct{} + +func (c *dummyConn) Read(b []byte) (n int, err error) { return len(b), nil } +func (c *dummyConn) Write(b []byte) (n int, err error) { return len(b), nil } +func (c *dummyConn) Close() error { return nil } +func (c *dummyConn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} } +func (c *dummyConn) RemoteAddr() net.Addr { + return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 6379} +} +func (c *dummyConn) SetDeadline(t time.Time) error { return nil } +func (c *dummyConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dummyConn) SetWriteDeadline(t time.Time) error { return nil } + +// ============================================================================= +// POOL BENCHMARKS +// ============================================================================= + +// BenchmarkPoolGetPut benchmarks the core pool Get/Put operations +func BenchmarkPoolGetPut(b *testing.B) { + ctx := context.Background() + + poolSizes := []int{1, 2, 4, 8, 16, 32, 64, 128} + + for _, poolSize := range poolSizes { + b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) { + connPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: int32(poolSize), + PoolTimeout: time.Second, + DialTimeout: time.Second, + ConnMaxIdleTime: time.Hour, + MinIdleConns: int32(0), // Start with no idle connections + }) + defer connPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + cn, err := connPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + connPool.Put(ctx, cn) + } + }) + }) + } +} + +// BenchmarkPoolGetPutWithMinIdle benchmarks pool operations with MinIdleConns +func BenchmarkPoolGetPutWithMinIdle(b *testing.B) { + ctx := context.Background() + + configs := []struct { + poolSize int + minIdleConns int + }{ + {8, 2}, + {16, 4}, + {32, 8}, + {64, 16}, + } + + for _, config := range configs { + b.Run(fmt.Sprintf("Pool_%d_MinIdle_%d", config.poolSize, config.minIdleConns), func(b *testing.B) { + connPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: int32(config.poolSize), + MinIdleConns: int32(config.minIdleConns), + PoolTimeout: time.Second, + DialTimeout: 
time.Second, + ConnMaxIdleTime: time.Hour, + }) + defer connPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + cn, err := connPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + connPool.Put(ctx, cn) + } + }) + }) + } +} + +// BenchmarkPoolConcurrentGetPut benchmarks pool under high concurrency +func BenchmarkPoolConcurrentGetPut(b *testing.B) { + ctx := context.Background() + + connPool := pool.NewConnPool(&pool.Options{ + Dialer: dummyDialer, + PoolSize: int32(32), + PoolTimeout: time.Second, + DialTimeout: time.Second, + ConnMaxIdleTime: time.Hour, + MinIdleConns: int32(0), + }) + defer connPool.Close() + + b.ResetTimer() + b.ReportAllocs() + + // Test with different levels of concurrency + concurrencyLevels := []int{1, 2, 4, 8, 16, 32, 64} + + for _, concurrency := range concurrencyLevels { + b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) { + b.SetParallelism(concurrency) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + cn, err := connPool.Get(ctx) + if err != nil { + b.Fatal(err) + } + connPool.Put(ctx, cn) + } + }) + }) + } +} + +// ============================================================================= +// PUBSUB BENCHMARKS +// ============================================================================= + +// benchmarkClient creates a Redis client for benchmarking with mock dialer +func benchmarkClient(poolSize int) *redis.Client { + return redis.NewClient(&redis.Options{ + Addr: "localhost:6379", // Mock address + DialTimeout: time.Second, + ReadTimeout: time.Second, + WriteTimeout: time.Second, + PoolSize: poolSize, + MinIdleConns: 0, // Start with no idle connections for consistent benchmarks + }) +} + +// BenchmarkPubSubCreation benchmarks PubSub creation and subscription +func BenchmarkPubSubCreation(b *testing.B) { + ctx := context.Background() + + poolSizes := []int{1, 4, 8, 16, 32} + + for _, poolSize := range poolSizes { + b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) { + client := benchmarkClient(poolSize) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + pubsub := client.Subscribe(ctx, "test-channel") + pubsub.Close() + } + }) + } +} + +// BenchmarkPubSubPatternCreation benchmarks PubSub pattern subscription +func BenchmarkPubSubPatternCreation(b *testing.B) { + ctx := context.Background() + + poolSizes := []int{1, 4, 8, 16, 32} + + for _, poolSize := range poolSizes { + b.Run(fmt.Sprintf("PoolSize_%d", poolSize), func(b *testing.B) { + client := benchmarkClient(poolSize) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + pubsub := client.PSubscribe(ctx, "test-*") + pubsub.Close() + } + }) + } +} + +// BenchmarkPubSubConcurrentCreation benchmarks concurrent PubSub creation +func BenchmarkPubSubConcurrentCreation(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(32) + defer client.Close() + + concurrencyLevels := []int{1, 2, 4, 8, 16} + + for _, concurrency := range concurrencyLevels { + b.Run(fmt.Sprintf("Concurrency_%d", concurrency), func(b *testing.B) { + b.ResetTimer() + b.ReportAllocs() + + var wg sync.WaitGroup + semaphore := make(chan struct{}, concurrency) + + for i := 0; i < b.N; i++ { + wg.Add(1) + semaphore <- struct{}{} + + go func() { + defer wg.Done() + defer func() { <-semaphore }() + + pubsub := client.Subscribe(ctx, "test-channel") + pubsub.Close() + }() + } + + wg.Wait() + }) + } +} + +// BenchmarkPubSubMultipleChannels 
benchmarks subscribing to multiple channels +func BenchmarkPubSubMultipleChannels(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(16) + defer client.Close() + + channelCounts := []int{1, 5, 10, 25, 50, 100} + + for _, channelCount := range channelCounts { + b.Run(fmt.Sprintf("Channels_%d", channelCount), func(b *testing.B) { + // Prepare channel names + channels := make([]string, channelCount) + for i := 0; i < channelCount; i++ { + channels[i] = fmt.Sprintf("channel-%d", i) + } + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + pubsub := client.Subscribe(ctx, channels...) + pubsub.Close() + } + }) + } +} + +// BenchmarkPubSubReuse benchmarks reusing PubSub connections +func BenchmarkPubSubReuse(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(16) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + // Benchmark just the creation and closing of PubSub connections + // This simulates reuse patterns without requiring actual Redis operations + pubsub := client.Subscribe(ctx, fmt.Sprintf("test-channel-%d", i)) + pubsub.Close() + } +} + +// ============================================================================= +// COMBINED BENCHMARKS +// ============================================================================= + +// BenchmarkPoolAndPubSubMixed benchmarks mixed pool stats and PubSub operations +func BenchmarkPoolAndPubSubMixed(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(32) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // Mix of pool stats collection and PubSub creation + if pb.Next() { + // Pool stats operation + stats := client.PoolStats() + _ = stats.Hits + stats.Misses // Use the stats to prevent optimization + } + + if pb.Next() { + // PubSub operation + pubsub := client.Subscribe(ctx, "test-channel") + pubsub.Close() + } + } + }) +} + +// BenchmarkPoolStatsCollection benchmarks pool statistics collection +func BenchmarkPoolStatsCollection(b *testing.B) { + client := benchmarkClient(16) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + stats := client.PoolStats() + _ = stats.Hits + stats.Misses + stats.Timeouts // Use the stats to prevent optimization + } +} + +// BenchmarkPoolHighContention tests pool performance under high contention +func BenchmarkPoolHighContention(b *testing.B) { + ctx := context.Background() + client := benchmarkClient(32) + defer client.Close() + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + // High contention Get/Put operations + pubsub := client.Subscribe(ctx, "test-channel") + pubsub.Close() + } + }) +} diff --git a/pubsub.go b/pubsub.go index 75327dd2aa..01188d10cf 100644 --- a/pubsub.go +++ b/pubsub.go @@ -22,7 +22,7 @@ import ( type PubSub struct { opt *Options - newConn func(ctx context.Context, channels []string) (*pool.Conn, error) + newConn func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) closeConn func(*pool.Conn) error mu sync.Mutex @@ -42,6 +42,9 @@ type PubSub struct { // Push notification processor for handling generic push notifications pushProcessor push.NotificationProcessor + + // Cleanup callback for hitless upgrade tracking + onClose func() } func (c *PubSub) init() { @@ -73,10 +76,18 @@ func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, er return c.cn, nil } + if 
c.opt.Addr == "" { + // TODO(hitless): + // this is probably cluster client + // c.newConn will ignore the addr argument + // will be changed when we have hitless upgrades for cluster clients + c.opt.Addr = internal.RedisNull + } + channels := mapKeys(c.channels) channels = append(channels, newChannels...) - cn, err := c.newConn(ctx, channels) + cn, err := c.newConn(ctx, c.opt.Addr, channels) if err != nil { return nil, err } @@ -157,12 +168,28 @@ func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allo if c.cn != cn { return } + + if !cn.IsUsable() || cn.ShouldHandoff() { + c.reconnect(ctx, fmt.Errorf("pubsub: connection is not usable")) + } + if isBadConn(err, allowTimeout, c.opt.Addr) { c.reconnect(ctx, err) } } func (c *PubSub) reconnect(ctx context.Context, reason error) { + if c.cn != nil && c.cn.ShouldHandoff() { + newEndpoint := c.cn.GetHandoffEndpoint() + // If new endpoint is NULL, use the original address + if newEndpoint == internal.RedisNull { + newEndpoint = c.opt.Addr + } + + if newEndpoint != "" { + c.opt.Addr = newEndpoint + } + } _ = c.closeTheCn(reason) _, _ = c.conn(ctx, nil) } @@ -189,6 +216,11 @@ func (c *PubSub) Close() error { c.closed = true close(c.exit) + // Call cleanup callback if set + if c.onClose != nil { + c.onClose() + } + return c.closeTheCn(pool.ErrClosed) } @@ -461,6 +493,7 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int // Receive returns a message as a Subscription, Message, Pong or error. // See PubSub example for details. This is low-level API and in most cases // Channel should be used instead. +// This will block until a message is received. func (c *PubSub) Receive(ctx context.Context) (interface{}, error) { return c.ReceiveTimeout(ctx, 0) } @@ -543,7 +576,8 @@ func (c *PubSub) ChannelWithSubscriptions(opts ...ChannelOption) <-chan interfac } func (c *PubSub) processPendingPushNotificationWithReader(ctx context.Context, cn *pool.Conn, rd *proto.Reader) error { - if c.pushProcessor == nil { + // Only process push notifications for RESP3 connections with a processor + if c.opt.Protocol != 3 || c.pushProcessor == nil { return nil } diff --git a/push/handler_context.go b/push/handler_context.go index 3bcf128f18..f89f87fa1b 100644 --- a/push/handler_context.go +++ b/push/handler_context.go @@ -1,8 +1,6 @@ package push -import ( - "github.com/redis/go-redis/v9/internal/pool" -) +// No imports needed for this file // NotificationHandlerContext provides context information about where a push notification was received. // This struct allows handlers to make informed decisions based on the source of the notification @@ -35,7 +33,12 @@ type NotificationHandlerContext struct { PubSub interface{} // Conn is the specific connection on which the notification was received. - Conn *pool.Conn + // It is interface to both allow for future expansion and to avoid + // circular dependencies. The developer is responsible for type assertion. + // It can be one of the following types: + // - *pool.Conn + // - *connectionAdapter (for hitless upgrades) + Conn interface{} // IsBlocking indicates if the notification was received on a blocking connection. 
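Because NotificationHandlerContext.Conn is now an untyped interface{}, handlers that need connection details must assert it themselves. A hedged sketch of one way to do that: the concrete adapter type used by the client is unexported, so asserting against a small method-set interface (or against *pool.Conn from inside go-redis) is the practical approach. The package name and interface below are illustrative only.

package myhandlers

import (
    "context"

    "github.com/redis/go-redis/v9/push"
)

// usableConn is the minimal method set this handler cares about; both
// *pool.Conn and the client's connection adapter are expected to satisfy it.
type usableConn interface {
    IsUsable() bool
    Close() error
}

type loggingHandler struct{}

func (h loggingHandler) HandlePushNotification(
    ctx context.Context,
    handlerCtx push.NotificationHandlerContext,
    notification []interface{},
) error {
    if cn, ok := handlerCtx.Conn.(usableConn); ok && !cn.IsUsable() {
        // The connection is mid-handoff; a real handler might defer work here.
        return nil
    }
    return nil
}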
IsBlocking bool diff --git a/push/processor_unit_test.go b/push/processor_unit_test.go new file mode 100644 index 0000000000..ce7990489f --- /dev/null +++ b/push/processor_unit_test.go @@ -0,0 +1,315 @@ +package push + +import ( + "context" + "testing" +) + +// TestProcessorCreation tests processor creation and initialization +func TestProcessorCreation(t *testing.T) { + t.Run("NewProcessor", func(t *testing.T) { + processor := NewProcessor() + if processor == nil { + t.Fatal("NewProcessor should not return nil") + } + if processor.registry == nil { + t.Error("Processor should have a registry") + } + }) + + t.Run("NewVoidProcessor", func(t *testing.T) { + voidProcessor := NewVoidProcessor() + if voidProcessor == nil { + t.Fatal("NewVoidProcessor should not return nil") + } + }) +} + +// TestProcessorHandlerManagement tests handler registration and retrieval +func TestProcessorHandlerManagement(t *testing.T) { + processor := NewProcessor() + handler := &UnitTestHandler{name: "test-handler"} + + t.Run("RegisterHandler", func(t *testing.T) { + err := processor.RegisterHandler("TEST", handler, false) + if err != nil { + t.Errorf("RegisterHandler should not error: %v", err) + } + + // Verify handler is registered + retrievedHandler := processor.GetHandler("TEST") + if retrievedHandler != handler { + t.Error("GetHandler should return the registered handler") + } + }) + + t.Run("RegisterProtectedHandler", func(t *testing.T) { + protectedHandler := &UnitTestHandler{name: "protected-handler"} + err := processor.RegisterHandler("PROTECTED", protectedHandler, true) + if err != nil { + t.Errorf("RegisterHandler should not error for protected handler: %v", err) + } + + // Verify handler is registered + retrievedHandler := processor.GetHandler("PROTECTED") + if retrievedHandler != protectedHandler { + t.Error("GetHandler should return the protected handler") + } + }) + + t.Run("GetNonExistentHandler", func(t *testing.T) { + handler := processor.GetHandler("NONEXISTENT") + if handler != nil { + t.Error("GetHandler should return nil for non-existent handler") + } + }) + + t.Run("UnregisterHandler", func(t *testing.T) { + err := processor.UnregisterHandler("TEST") + if err != nil { + t.Errorf("UnregisterHandler should not error: %v", err) + } + + // Verify handler is removed + retrievedHandler := processor.GetHandler("TEST") + if retrievedHandler != nil { + t.Error("GetHandler should return nil after unregistering") + } + }) + + t.Run("UnregisterProtectedHandler", func(t *testing.T) { + err := processor.UnregisterHandler("PROTECTED") + if err == nil { + t.Error("UnregisterHandler should error for protected handler") + } + + // Verify handler is still there + retrievedHandler := processor.GetHandler("PROTECTED") + if retrievedHandler == nil { + t.Error("Protected handler should not be removed") + } + }) +} + +// TestVoidProcessorBehavior tests void processor behavior +func TestVoidProcessorBehavior(t *testing.T) { + voidProcessor := NewVoidProcessor() + handler := &UnitTestHandler{name: "test-handler"} + + t.Run("GetHandler", func(t *testing.T) { + retrievedHandler := voidProcessor.GetHandler("ANY") + if retrievedHandler != nil { + t.Error("VoidProcessor GetHandler should always return nil") + } + }) + + t.Run("RegisterHandler", func(t *testing.T) { + err := voidProcessor.RegisterHandler("TEST", handler, false) + if err == nil { + t.Error("VoidProcessor RegisterHandler should return error") + } + + // Check error type + if !IsVoidProcessorError(err) { + t.Error("Error should be a VoidProcessorError") + } + 
}) + + t.Run("UnregisterHandler", func(t *testing.T) { + err := voidProcessor.UnregisterHandler("TEST") + if err == nil { + t.Error("VoidProcessor UnregisterHandler should return error") + } + + // Check error type + if !IsVoidProcessorError(err) { + t.Error("Error should be a VoidProcessorError") + } + }) +} + +// TestProcessPendingNotificationsNilReader tests handling of nil reader +func TestProcessPendingNotificationsNilReader(t *testing.T) { + t.Run("ProcessorWithNilReader", func(t *testing.T) { + processor := NewProcessor() + ctx := context.Background() + handlerCtx := NotificationHandlerContext{} + + err := processor.ProcessPendingNotifications(ctx, handlerCtx, nil) + if err != nil { + t.Errorf("ProcessPendingNotifications should not error with nil reader: %v", err) + } + }) + + t.Run("VoidProcessorWithNilReader", func(t *testing.T) { + voidProcessor := NewVoidProcessor() + ctx := context.Background() + handlerCtx := NotificationHandlerContext{} + + err := voidProcessor.ProcessPendingNotifications(ctx, handlerCtx, nil) + if err != nil { + t.Errorf("VoidProcessor ProcessPendingNotifications should not error with nil reader: %v", err) + } + }) +} + +// TestWillHandleNotificationInClient tests the notification filtering logic +func TestWillHandleNotificationInClient(t *testing.T) { + testCases := []struct { + name string + notificationType string + shouldHandle bool + }{ + // Pub/Sub notifications (should be handled in client) + {"message", "message", true}, + {"pmessage", "pmessage", true}, + {"subscribe", "subscribe", true}, + {"unsubscribe", "unsubscribe", true}, + {"psubscribe", "psubscribe", true}, + {"punsubscribe", "punsubscribe", true}, + {"smessage", "smessage", true}, + {"ssubscribe", "ssubscribe", true}, + {"sunsubscribe", "sunsubscribe", true}, + + // Push notifications (should be handled by processor) + {"MOVING", "MOVING", false}, + {"MIGRATING", "MIGRATING", false}, + {"MIGRATED", "MIGRATED", false}, + {"FAILING_OVER", "FAILING_OVER", false}, + {"FAILED_OVER", "FAILED_OVER", false}, + {"custom", "custom", false}, + {"unknown", "unknown", false}, + {"empty", "", false}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + result := willHandleNotificationInClient(tc.notificationType) + if result != tc.shouldHandle { + t.Errorf("willHandleNotificationInClient(%q) = %v, want %v", tc.notificationType, result, tc.shouldHandle) + } + }) + } +} + +// TestProcessorErrorHandlingUnit tests error handling scenarios +func TestProcessorErrorHandlingUnit(t *testing.T) { + processor := NewProcessor() + + t.Run("RegisterNilHandler", func(t *testing.T) { + err := processor.RegisterHandler("TEST", nil, false) + if err == nil { + t.Error("RegisterHandler should error with nil handler") + } + + // Check error type + if !IsHandlerNilError(err) { + t.Error("Error should be a HandlerNilError") + } + }) + + t.Run("RegisterDuplicateHandler", func(t *testing.T) { + handler1 := &UnitTestHandler{name: "handler1"} + handler2 := &UnitTestHandler{name: "handler2"} + + // Register first handler + err := processor.RegisterHandler("DUPLICATE", handler1, false) + if err != nil { + t.Errorf("First RegisterHandler should not error: %v", err) + } + + // Try to register second handler with same name + err = processor.RegisterHandler("DUPLICATE", handler2, false) + if err == nil { + t.Error("RegisterHandler should error when registering duplicate handler") + } + + // Verify original handler is still there + retrievedHandler := processor.GetHandler("DUPLICATE") + if retrievedHandler != 
handler1 { + t.Error("Original handler should remain after failed duplicate registration") + } + }) + + t.Run("UnregisterNonExistentHandler", func(t *testing.T) { + err := processor.UnregisterHandler("NONEXISTENT") + if err != nil { + t.Errorf("UnregisterHandler should not error for non-existent handler: %v", err) + } + }) +} + +// TestProcessorConcurrentAccess tests concurrent access to processor +func TestProcessorConcurrentAccess(t *testing.T) { + processor := NewProcessor() + + t.Run("ConcurrentRegisterAndGet", func(t *testing.T) { + done := make(chan bool, 2) + + // Goroutine 1: Register handlers + go func() { + defer func() { done <- true }() + for i := 0; i < 100; i++ { + handler := &UnitTestHandler{name: "concurrent-handler"} + processor.RegisterHandler("CONCURRENT", handler, false) + processor.UnregisterHandler("CONCURRENT") + } + }() + + // Goroutine 2: Get handlers + go func() { + defer func() { done <- true }() + for i := 0; i < 100; i++ { + processor.GetHandler("CONCURRENT") + } + }() + + // Wait for both goroutines to complete + <-done + <-done + }) +} + +// TestProcessorInterfaceCompliance tests interface compliance +func TestProcessorInterfaceCompliance(t *testing.T) { + t.Run("ProcessorImplementsInterface", func(t *testing.T) { + var _ NotificationProcessor = (*Processor)(nil) + }) + + t.Run("VoidProcessorImplementsInterface", func(t *testing.T) { + var _ NotificationProcessor = (*VoidProcessor)(nil) + }) +} + +// UnitTestHandler is a test implementation of NotificationHandler +type UnitTestHandler struct { + name string + lastNotification []interface{} + errorToReturn error + callCount int +} + +func (h *UnitTestHandler) HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error { + h.callCount++ + h.lastNotification = notification + return h.errorToReturn +} + +// Helper methods for UnitTestHandler +func (h *UnitTestHandler) GetCallCount() int { + return h.callCount +} + +func (h *UnitTestHandler) GetLastNotification() []interface{} { + return h.lastNotification +} + +func (h *UnitTestHandler) SetErrorToReturn(err error) { + h.errorToReturn = err +} + +func (h *UnitTestHandler) Reset() { + h.callCount = 0 + h.lastNotification = nil + h.errorToReturn = nil +} diff --git a/push_notifications.go b/push_notifications.go index ceffe04ad5..572955fecb 100644 --- a/push_notifications.go +++ b/push_notifications.go @@ -4,24 +4,6 @@ import ( "github.com/redis/go-redis/v9/push" ) -// Push notification constants for cluster operations -const ( - // MOVING indicates a slot is being moved to a different node - PushNotificationMoving = "MOVING" - - // MIGRATING indicates a slot is being migrated from this node - PushNotificationMigrating = "MIGRATING" - - // MIGRATED indicates a slot has been migrated to this node - PushNotificationMigrated = "MIGRATED" - - // FAILING_OVER indicates a failover is starting - PushNotificationFailingOver = "FAILING_OVER" - - // FAILED_OVER indicates a failover has completed - PushNotificationFailedOver = "FAILED_OVER" -) - // NewPushNotificationProcessor creates a new push notification processor // This processor maintains a registry of handlers and processes push notifications // It is used for RESP3 connections where push notifications are available diff --git a/redis.go b/redis.go index b3608c5ff8..e5cca4af29 100644 --- a/redis.go +++ b/redis.go @@ -10,6 +10,7 @@ import ( "time" "github.com/redis/go-redis/v9/auth" + "github.com/redis/go-redis/v9/hitless" "github.com/redis/go-redis/v9/internal" 
"github.com/redis/go-redis/v9/internal/hscan" "github.com/redis/go-redis/v9/internal/pool" @@ -204,19 +205,35 @@ func (hs *hooksMixin) processTxPipelineHook(ctx context.Context, cmds []Cmder) e //------------------------------------------------------------------------------ type baseClient struct { - opt *Options - connPool pool.Pooler + opt *Options + optLock sync.RWMutex + connPool pool.Pooler + pubSubPool *pool.PubSubPool hooksMixin onClose func() error // hook called when client is closed // Push notification processing pushProcessor push.NotificationProcessor + + // Hitless upgrade manager + hitlessManager *hitless.HitlessManager + hitlessManagerLock sync.RWMutex } func (c *baseClient) clone() *baseClient { - clone := *c - return &clone + c.hitlessManagerLock.RLock() + hitlessManager := c.hitlessManager + c.hitlessManagerLock.RUnlock() + + clone := &baseClient{ + opt: c.opt, + connPool: c.connPool, + onClose: c.onClose, + pushProcessor: c.pushProcessor, + hitlessManager: hitlessManager, + } + return clone } func (c *baseClient) withTimeout(timeout time.Duration) *baseClient { @@ -234,21 +251,6 @@ func (c *baseClient) String() string { return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) } -func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) { - cn, err := c.connPool.NewConn(ctx) - if err != nil { - return nil, err - } - - err = c.initConn(ctx, cn) - if err != nil { - _ = c.connPool.CloseConn(cn) - return nil, err - } - - return cn, nil -} - func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) { if c.opt.Limiter != nil { err := c.opt.Limiter.Allow() @@ -274,7 +276,7 @@ func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) { return nil, err } - if cn.Inited { + if cn.IsInited() { return cn, nil } @@ -356,12 +358,10 @@ func (c *baseClient) wrappedOnClose(newOnClose func() error) func() error { } func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { - if cn.Inited { + if !cn.Inited.CompareAndSwap(false, true) { return nil } - var err error - cn.Inited = true connPool := pool.NewSingleConnPool(c.connPool, cn) conn := newConn(c.opt, connPool, &c.hooksMixin) @@ -430,6 +430,50 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { return fmt.Errorf("failed to initialize connection options: %w", err) } + // Enable maintenance notifications if hitless upgrades are configured + c.optLock.RLock() + hitlessEnabled := c.opt.HitlessUpgradeConfig != nil && c.opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled + protocol := c.opt.Protocol + endpointType := c.opt.HitlessUpgradeConfig.EndpointType + c.optLock.RUnlock() + var hitlessHandshakeErr error + if hitlessEnabled && protocol == 3 { + hitlessHandshakeErr = conn.ClientMaintNotifications( + ctx, + true, + endpointType.String(), + ).Err() + if hitlessHandshakeErr != nil { + if !isRedisError(hitlessHandshakeErr) { + // if not redis error, fail the connection + return hitlessHandshakeErr + } + c.optLock.Lock() + // handshake failed - check and modify config atomically + switch c.opt.HitlessUpgradeConfig.Mode { + case hitless.MaintNotificationsEnabled: + // enabled mode, fail the connection + c.optLock.Unlock() + return fmt.Errorf("failed to enable maintenance notifications: %w", hitlessHandshakeErr) + default: // will handle auto and any other + c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsDisabled + c.optLock.Unlock() + // auto mode, disable hitless upgrades and continue + if err := c.disableHitlessUpgrades(); err != nil 
{ + // Log error but continue - auto mode should be resilient + internal.Logger.Printf(ctx, "hitless: failed to disable hitless upgrades in auto mode: %v", err) + } + } + } else { + // handshake was executed successfully + // to make sure that the handshake will be executed on other connections as well if it was successfully + // executed on this connection, we will force the handshake to be executed on all connections + c.optLock.Lock() + c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsEnabled + c.optLock.Unlock() + } + } + if !c.opt.DisableIdentity && !c.opt.DisableIndentity { libName := "" libVer := Version() @@ -446,6 +490,12 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { } } + cn.SetUsable(true) + cn.Inited.Store(true) + + // Set the connection initialization function for potential reconnections + cn.SetInitConnFunc(c.createInitConnFunc()) + if c.opt.OnConnect != nil { return c.opt.OnConnect(ctx, conn) } @@ -593,19 +643,76 @@ func (c *baseClient) context(ctx context.Context) context.Context { return context.Background() } +// createInitConnFunc creates a connection initialization function that can be used for reconnections. +func (c *baseClient) createInitConnFunc() func(context.Context, *pool.Conn) error { + return func(ctx context.Context, cn *pool.Conn) error { + return c.initConn(ctx, cn) + } +} + +// enableHitlessUpgrades initializes the hitless upgrade manager and pool hook. +// This function is called during client initialization. +// will register push notification handlers for all hitless upgrade events. +// will start background workers for handoff processing in the pool hook. +func (c *baseClient) enableHitlessUpgrades() error { + // Create client adapter + clientAdapterInstance := newClientAdapter(c) + + // Create hitless manager directly + manager, err := hitless.NewHitlessManager(clientAdapterInstance, c.connPool, c.opt.HitlessUpgradeConfig) + if err != nil { + return err + } + // Set the manager reference and initialize pool hook + c.hitlessManagerLock.Lock() + c.hitlessManager = manager + c.hitlessManagerLock.Unlock() + + // Initialize pool hook (safe to call without lock since manager is now set) + manager.InitPoolHook(c.dialHook) + return nil +} + +func (c *baseClient) disableHitlessUpgrades() error { + c.hitlessManagerLock.Lock() + defer c.hitlessManagerLock.Unlock() + + // Close the hitless manager + if c.hitlessManager != nil { + // Closing the manager will also shutdown the pool hook + // and remove it from the pool + c.hitlessManager.Close() + c.hitlessManager = nil + } + return nil +} + // Close closes the client, releasing any open resources. // // It is rare to Close a Client, as the Client is meant to be // long-lived and shared between many goroutines. 
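createInitConnFunc above stores the client's handshake as a closure on each connection (via SetInitConnFunc), so the handshake can be replayed after a handoff swaps the underlying endpoint. The shape of that pattern in isolation, as a standalone sketch with simplified types rather than the pool's actual handoff machinery:

package main

import (
    "context"
    "fmt"
)

// handoffConn carries the initializer captured when the connection was
// first set up, so it can be replayed against a new endpoint later.
type handoffConn struct {
    addr     string
    initConn func(context.Context, *handoffConn) error
}

func (c *handoffConn) SetInitConnFunc(fn func(context.Context, *handoffConn) error) {
    c.initConn = fn
}

// handoff simulates swapping the endpoint and re-running the handshake.
func (c *handoffConn) handoff(ctx context.Context, newAddr string) error {
    c.addr = newAddr
    return c.initConn(ctx, c) // replay AUTH/HELLO/CLIENT SETINFO, etc.
}

func main() {
    cn := &handoffConn{addr: "old:6379"}
    cn.SetInitConnFunc(func(ctx context.Context, c *handoffConn) error {
        fmt.Println("handshake against", c.addr)
        return nil
    })
    _ = cn.handoff(context.Background(), "new:6379")
}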
func (c *baseClient) Close() error { var firstErr error + + // Close hitless manager first + if err := c.disableHitlessUpgrades(); err != nil { + firstErr = err + } + if c.onClose != nil { - if err := c.onClose(); err != nil { + if err := c.onClose(); err != nil && firstErr == nil { firstErr = err } } - if err := c.connPool.Close(); err != nil && firstErr == nil { - firstErr = err + if c.connPool != nil { + if err := c.connPool.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + if c.pubSubPool != nil { + if err := c.pubSubPool.Close(); err != nil && firstErr == nil { + firstErr = err + } } return firstErr } @@ -810,11 +917,24 @@ func NewClient(opt *Options) *Client { // Initialize push notification processor using shared helper // Use void processor for RESP2 connections (push notifications not available) c.pushProcessor = initializePushProcessor(opt) - - // Update options with the initialized push processor for connection pool + // Update options with the initialized push processor opt.PushNotificationProcessor = c.pushProcessor + // Create connection pools c.connPool = newConnPool(opt, c.dialHook) + c.pubSubPool = newPubSubPool(opt, c.dialHook) + + // Initialize hitless upgrades first if enabled and protocol is RESP3 + if opt.HitlessUpgradeConfig != nil && opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled && opt.Protocol == 3 { + err := c.enableHitlessUpgrades() + if err != nil { + internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err) + if opt.HitlessUpgradeConfig.Mode == hitless.MaintNotificationsEnabled { + // panic so we fail fast without breaking existing clients api + panic(fmt.Errorf("failed to enable hitless upgrades: %w", err)) + } + } + } return &c } @@ -851,6 +971,14 @@ func (c *Client) Options() *Options { return c.opt } +// GetHitlessManager returns the hitless manager instance for monitoring and control. +// Returns nil if hitless upgrades are not enabled. +func (c *Client) GetHitlessManager() *hitless.HitlessManager { + c.hitlessManagerLock.RLock() + defer c.hitlessManagerLock.RUnlock() + return c.hitlessManager +} + // initializePushProcessor initializes the push notification processor for any client type. // This is a shared helper to avoid duplication across NewClient, NewFailoverClient, and NewSentinelClient. func initializePushProcessor(opt *Options) push.NotificationProcessor { @@ -887,6 +1015,7 @@ type PoolStats pool.Stats // PoolStats returns connection pool stats. 
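Pool statistics now surface the Unusable counter and the PubSub pool's own stats, and the hitless manager can be inspected through GetHitlessManager. A small monitoring sketch using the fields introduced in this patch (the address is a placeholder):

package main

import (
    "fmt"

    "github.com/redis/go-redis/v9"
)

func main() {
    rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379", Protocol: 3})
    defer rdb.Close()

    stats := rdb.PoolStats()
    fmt.Printf("hits=%d misses=%d unusable=%d pubsub_active=%d\n",
        stats.Hits, stats.Misses, stats.Unusable, stats.PubSubStats.Active)

    // Nil when hitless upgrades are disabled or were never enabled.
    if mgr := rdb.GetHitlessManager(); mgr != nil {
        fmt.Println("hitless upgrades are active")
    }
}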
func (c *Client) PoolStats() *PoolStats { stats := c.connPool.Stats() + stats.PubSubStats = *(c.pubSubPool.Stats()) return (*PoolStats)(stats) } @@ -921,11 +1050,27 @@ func (c *Client) TxPipeline() Pipeliner { func (c *Client) pubSub() *PubSub { pubsub := &PubSub{ opt: c.opt, - - newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { - return c.newConn(ctx) + newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) { + cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels) + if err != nil { + return nil, err + } + // will return nil if already initialized + err = c.initConn(ctx, cn) + if err != nil { + _ = cn.Close() + return nil, err + } + // Track connection in PubSubPool + c.pubSubPool.TrackConn(cn) + return cn, nil + }, + closeConn: func(cn *pool.Conn) error { + // Untrack connection from PubSubPool + c.pubSubPool.UntrackConn(cn) + _ = cn.Close() + return nil }, - closeConn: c.connPool.CloseConn, pushProcessor: c.pushProcessor, } pubsub.init() @@ -1113,6 +1258,6 @@ func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.Notifica return push.NotificationHandlerContext{ Client: c, ConnPool: c.connPool, - Conn: cn, + Conn: &connectionAdapter{conn: cn}, // Wrap in adapter for easier interface access } } diff --git a/redis_test.go b/redis_test.go index 6aaa0a7547..27b69ed14b 100644 --- a/redis_test.go +++ b/redis_test.go @@ -12,7 +12,6 @@ import ( . "github.com/bsm/ginkgo/v2" . "github.com/bsm/gomega" - "github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9/auth" ) diff --git a/sentinel.go b/sentinel.go index 2509d70fe3..8ae284dec1 100644 --- a/sentinel.go +++ b/sentinel.go @@ -16,8 +16,8 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/rand" - "github.com/redis/go-redis/v9/push" "github.com/redis/go-redis/v9/internal/util" + "github.com/redis/go-redis/v9/push" ) //------------------------------------------------------------------------------ @@ -139,6 +139,14 @@ type FailoverOptions struct { FailingTimeoutSeconds int UnstableResp3 bool + + // Hitless is not supported for FailoverClients at the moment + // HitlessUpgradeConfig provides custom configuration for hitless upgrades. + // When HitlessUpgradeConfig.Mode is not "disabled", the client will handle + // upgrade notifications gracefully and manage connection/pool state transitions + // seamlessly. Requires Protocol: 3 (RESP3) for push notifications. + // If nil, hitless upgrades are disabled. 
+ //HitlessUpgradeConfig *HitlessUpgradeConfig } func (opt *FailoverOptions) clientOptions() *Options { @@ -456,8 +464,6 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { opt.Dialer = masterReplicaDialer(failover) opt.init() - var connPool *pool.ConnPool - rdb := &Client{ baseClient: &baseClient{ opt: opt, @@ -469,15 +475,18 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { // Use void processor by default for RESP2 connections rdb.pushProcessor = initializePushProcessor(opt) - connPool = newConnPool(opt, rdb.dialHook) - rdb.connPool = connPool + rdb.connPool = newConnPool(opt, rdb.dialHook) + rdb.pubSubPool = newPubSubPool(opt, rdb.dialHook) + rdb.onClose = rdb.wrappedOnClose(failover.Close) failover.mu.Lock() failover.onFailover = func(ctx context.Context, addr string) { - _ = connPool.Filter(func(cn *pool.Conn) bool { - return cn.RemoteAddr().String() != addr - }) + if connPool, ok := rdb.connPool.(*pool.ConnPool); ok { + _ = connPool.Filter(func(cn *pool.Conn) bool { + return cn.RemoteAddr().String() != addr + }) + } } failover.mu.Unlock() @@ -544,6 +553,7 @@ func NewSentinelClient(opt *Options) *SentinelClient { process: c.baseClient.process, }) c.connPool = newConnPool(opt, c.dialHook) + c.pubSubPool = newPubSubPool(opt, c.dialHook) return c } @@ -570,13 +580,31 @@ func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error { func (c *SentinelClient) pubSub() *PubSub { pubsub := &PubSub{ opt: c.opt, - - newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { - return c.newConn(ctx) + newConn: func(ctx context.Context, addr string, channels []string) (*pool.Conn, error) { + cn, err := c.pubSubPool.NewConn(ctx, c.opt.Network, addr, channels) + if err != nil { + return nil, err + } + // will return nil if already initialized + err = c.initConn(ctx, cn) + if err != nil { + _ = cn.Close() + return nil, err + } + // Track connection in PubSubPool + c.pubSubPool.TrackConn(cn) + return cn, nil + }, + closeConn: func(cn *pool.Conn) error { + // Untrack connection from PubSubPool + c.pubSubPool.UntrackConn(cn) + _ = cn.Close() + return nil }, - closeConn: c.connPool.CloseConn, + pushProcessor: c.pushProcessor, } pubsub.init() + return pubsub } diff --git a/tx.go b/tx.go index 67689f57af..ee7b76e46d 100644 --- a/tx.go +++ b/tx.go @@ -24,7 +24,7 @@ type Tx struct { func (c *Client) newTx() *Tx { tx := Tx{ baseClient: baseClient{ - opt: c.opt, + opt: c.opt.clone(), // Clone options to avoid sharing HitlessUpgradeConfig connPool: pool.NewStickyConnPool(c.connPool), hooksMixin: c.hooksMixin.clone(), pushProcessor: c.pushProcessor, // Copy push processor from parent client diff --git a/universal.go b/universal.go index 02da3be82b..2f4b4a5398 100644 --- a/universal.go +++ b/universal.go @@ -122,6 +122,9 @@ type UniversalOptions struct { // IsClusterMode can be used when only one Addrs is provided (e.g. Elasticache supports setting up cluster mode with configuration endpoint). IsClusterMode bool + + // HitlessUpgradeConfig provides configuration for hitless upgrades. + HitlessUpgradeConfig *HitlessUpgradeConfig } // Cluster returns cluster options created from the universal options. 
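As a usage sketch for the new `HitlessUpgradeConfig` option (names as introduced in this patch series — `redis.Options.HitlessUpgradeConfig`, `hitless.MaintNotificationsEnabled`, `Client.GetHitlessManager`; the final exported shape of the config and its defaults may differ):

```go
package main

import (
	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/hitless"
)

func main() {
	// RESP3 is required because hitless upgrades are driven by push notifications.
	rdb := redis.NewClient(&redis.Options{
		Addr:     "localhost:6379",
		Protocol: 3,
		HitlessUpgradeConfig: &redis.HitlessUpgradeConfig{
			// Enabled mode fails fast (NewClient panics) if the
			// maintenance-notifications handshake cannot be enabled.
			Mode: hitless.MaintNotificationsEnabled,
		},
	})
	defer rdb.Close()

	// The manager is exposed for monitoring; it is nil when upgrades are disabled.
	if mgr := rdb.GetHitlessManager(); mgr != nil {
		_ = mgr
	}
}
```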
@@ -177,6 +180,7 @@ func (o *UniversalOptions) Cluster() *ClusterOptions { IdentitySuffix: o.IdentitySuffix, FailingTimeoutSeconds: o.FailingTimeoutSeconds, UnstableResp3: o.UnstableResp3, + HitlessUpgradeConfig: o.HitlessUpgradeConfig, } } @@ -237,6 +241,7 @@ func (o *UniversalOptions) Failover() *FailoverOptions { DisableIndentity: o.DisableIndentity, IdentitySuffix: o.IdentitySuffix, UnstableResp3: o.UnstableResp3, + // Note: HitlessUpgradeConfig not supported for FailoverOptions } } @@ -284,10 +289,11 @@ func (o *UniversalOptions) Simple() *Options { TLSConfig: o.TLSConfig, - DisableIdentity: o.DisableIdentity, - DisableIndentity: o.DisableIndentity, - IdentitySuffix: o.IdentitySuffix, - UnstableResp3: o.UnstableResp3, + DisableIdentity: o.DisableIdentity, + DisableIndentity: o.DisableIndentity, + IdentitySuffix: o.IdentitySuffix, + UnstableResp3: o.UnstableResp3, + HitlessUpgradeConfig: o.HitlessUpgradeConfig, } } From c6f4820a96b6b5d0dba68a2b4bd5ea22b2408686 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> Date: Tue, 19 Aug 2025 15:38:16 +0300 Subject: [PATCH 02/21] Update pubsub.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- pubsub.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pubsub.go b/pubsub.go index 01188d10cf..6db13a9a61 100644 --- a/pubsub.go +++ b/pubsub.go @@ -493,7 +493,12 @@ func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (int // Receive returns a message as a Subscription, Message, Pong or error. // See PubSub example for details. This is low-level API and in most cases // Channel should be used instead. -// This will block until a message is received. +// Receive returns a message as a Subscription, Message, Pong, or an error. +// See PubSub example for details. This is a low-level API and in most cases +// Channel should be used instead. +// This method blocks until a message is received or an error occurs. +// It may return early with an error if the context is canceled, the connection fails, +// or other internal errors occur. func (c *PubSub) Receive(ctx context.Context) (interface{}, error) { return c.ReceiveTimeout(ctx, 0) } From 55747935a39632ada8701cad61f12b69c065b7c8 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> Date: Tue, 19 Aug 2025 15:39:20 +0300 Subject: [PATCH 03/21] Update redis.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- redis.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/redis.go b/redis.go index e5cca4af29..27bd6b7e76 100644 --- a/redis.go +++ b/redis.go @@ -930,7 +930,16 @@ func NewClient(opt *Options) *Client { if err != nil { internal.Logger.Printf(context.Background(), "hitless: failed to initialize hitless upgrades: %v", err) if opt.HitlessUpgradeConfig.Mode == hitless.MaintNotificationsEnabled { - // panic so we fail fast without breaking existing clients api + /* + Design decision: panic here to fail fast if hitless upgrades cannot be enabled when explicitly requested. + We choose to panic instead of returning an error to avoid breaking the existing client API, which does not expect + an error from NewClient. This ensures that misconfiguration or critical initialization failures are surfaced + immediately, rather than allowing the client to continue in a partially initialized or inconsistent state. 
+ Clients relying on hitless upgrades should be aware that initialization errors will cause a panic, and should + handle this accordingly (e.g., via recover or by validating configuration before calling NewClient). + This approach is only used when HitlessUpgradeConfig.Mode is MaintNotificationsEnabled, indicating that hitless + upgrades are required for correct operation. In other modes, initialization failures are logged but do not panic. + */ panic(fmt.Errorf("failed to enable hitless upgrades: %w", err)) } } From ded98eccb13288b75a8cfa0f8ddf9bc810bf4a82 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 19 Aug 2025 15:56:41 +0300 Subject: [PATCH 04/21] address comments --- hitless/errors.go | 1 + hitless/pool_hook.go | 2 +- internal/pool/pool.go | 14 ++++++-- options.go | 76 ++++++++++++++++++++++++++++++++++++------- redis.go | 14 ++++++-- sentinel.go | 22 ++++++++++--- 6 files changed, 108 insertions(+), 21 deletions(-) diff --git a/hitless/errors.go b/hitless/errors.go index 5beb250aaa..784b41a21b 100644 --- a/hitless/errors.go +++ b/hitless/errors.go @@ -37,6 +37,7 @@ var ( ErrHandoffInProgress = errors.New("hitless: handoff already in progress") ErrNoHandoffInProgress = errors.New("hitless: no handoff in progress") ErrConnectionFailed = errors.New("hitless: failed to establish new connection") + ErrHandoffQueueFull = errors.New("hitless: handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration") ) // Dead error variables removed - unused in simplified architecture diff --git a/hitless/pool_hook.go b/hitless/pool_hook.go index eb3eaf905d..073fdd7378 100644 --- a/hitless/pool_hook.go +++ b/hitless/pool_hook.go @@ -309,7 +309,7 @@ func (ph *PoolHook) queueHandoff(conn *pool.Conn) error { // Ensure we have workers available to handle the load ph.ensureWorkerAvailable() - return errors.New("queue full") + return ErrHandoffQueueFull } // performConnectionHandoffWithPool performs the actual connection handoff with pool for connection removal on failure diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 43c6a81907..0989d38c17 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -24,8 +24,18 @@ var ( // ErrPoolTimeout timed out waiting to get a connection from the connection pool. ErrPoolTimeout = errors.New("redis: connection pool timeout") - popAttempts = 10 - getAttempts = 3 + // popAttempts is the maximum number of attempts to find a usable connection + // when popping from the idle connection pool. This handles cases where connections + // are temporarily marked as unusable (e.g., during hitless upgrades or network issues). + // Value of 10 provides sufficient resilience without excessive overhead. + popAttempts = 10 + + // getAttempts is the maximum number of attempts to get a connection that passes + // hook validation (e.g., hitless upgrade hooks). This protects against race conditions + // where hooks might temporarily reject connections during cluster transitions. + // Value of 3 balances resilience with performance - most hook rejections resolve quickly. 
+ getAttempts = 3 + minTime = time.Unix(-2208988800, 0) // Jan 1, 1900 maxTime = minTime.Add(1<<63 - 1) noExpiration = maxTime diff --git a/options.go b/options.go index 3c5d364cef..45f62b32a0 100644 --- a/options.go +++ b/options.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "fmt" + "math" "net" "net/url" "runtime" @@ -31,6 +32,17 @@ type Limiter interface { ReportResult(result error) } +// safeIntToInt32 safely converts an int to int32, returning an error if overflow would occur. +func safeIntToInt32(value int, fieldName string) (int32, error) { + if value > math.MaxInt32 { + return 0, fmt.Errorf("redis: %s value %d exceeds maximum allowed value %d", fieldName, value, math.MaxInt32) + } + if value < math.MinInt32 { + return 0, fmt.Errorf("redis: %s value %d is below minimum allowed value %d", fieldName, value, math.MinInt32) + } + return int32(value), nil +} + // Options keeps the settings to set up redis connection. type Options struct { @@ -648,40 +660,80 @@ func getUserPassword(u *url.URL) (string, string) { func newConnPool( opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error), -) *pool.ConnPool { +) (*pool.ConnPool, error) { + poolSize, err := safeIntToInt32(opt.PoolSize, "PoolSize") + if err != nil { + return nil, err + } + + minIdleConns, err := safeIntToInt32(opt.MinIdleConns, "MinIdleConns") + if err != nil { + return nil, err + } + + maxIdleConns, err := safeIntToInt32(opt.MaxIdleConns, "MaxIdleConns") + if err != nil { + return nil, err + } + + maxActiveConns, err := safeIntToInt32(opt.MaxActiveConns, "MaxActiveConns") + if err != nil { + return nil, err + } + return pool.NewConnPool(&pool.Options{ Dialer: func(ctx context.Context) (net.Conn, error) { return dialer(ctx, opt.Network, opt.Addr) }, PoolFIFO: opt.PoolFIFO, - PoolSize: int32(opt.PoolSize), + PoolSize: poolSize, PoolTimeout: opt.PoolTimeout, DialTimeout: opt.DialTimeout, - MinIdleConns: int32(opt.MinIdleConns), - MaxIdleConns: int32(opt.MaxIdleConns), - MaxActiveConns: int32(opt.MaxActiveConns), + MinIdleConns: minIdleConns, + MaxIdleConns: maxIdleConns, + MaxActiveConns: maxActiveConns, ConnMaxIdleTime: opt.ConnMaxIdleTime, ConnMaxLifetime: opt.ConnMaxLifetime, ReadBufferSize: opt.ReadBufferSize, WriteBufferSize: opt.WriteBufferSize, PushNotificationsEnabled: opt.Protocol == 3, - }) + }), nil } func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error), -) *pool.PubSubPool { +) (*pool.PubSubPool, error) { + poolSize, err := safeIntToInt32(opt.PoolSize, "PoolSize") + if err != nil { + return nil, err + } + + minIdleConns, err := safeIntToInt32(opt.MinIdleConns, "MinIdleConns") + if err != nil { + return nil, err + } + + maxIdleConns, err := safeIntToInt32(opt.MaxIdleConns, "MaxIdleConns") + if err != nil { + return nil, err + } + + maxActiveConns, err := safeIntToInt32(opt.MaxActiveConns, "MaxActiveConns") + if err != nil { + return nil, err + } + return pool.NewPubSubPool(&pool.Options{ PoolFIFO: opt.PoolFIFO, - PoolSize: int32(opt.PoolSize), + PoolSize: poolSize, PoolTimeout: opt.PoolTimeout, DialTimeout: opt.DialTimeout, - MinIdleConns: int32(opt.MinIdleConns), - MaxIdleConns: int32(opt.MaxIdleConns), - MaxActiveConns: int32(opt.MaxActiveConns), + MinIdleConns: minIdleConns, + MaxIdleConns: maxIdleConns, + MaxActiveConns: maxActiveConns, ConnMaxIdleTime: opt.ConnMaxIdleTime, ConnMaxLifetime: opt.ConnMaxLifetime, ReadBufferSize: 32 * 1024, WriteBufferSize: 32 * 1024, PushNotificationsEnabled: opt.Protocol == 3, - }, dialer) + }, 
dialer), nil } diff --git a/redis.go b/redis.go index 27bd6b7e76..c2791336c8 100644 --- a/redis.go +++ b/redis.go @@ -456,6 +456,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { c.optLock.Unlock() return fmt.Errorf("failed to enable maintenance notifications: %w", hitlessHandshakeErr) default: // will handle auto and any other + internal.Logger.Printf(ctx, "hitless: auto mode fallback: hitless upgrades disabled due to handshake failure: %v", hitlessHandshakeErr) c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsDisabled c.optLock.Unlock() // auto mode, disable hitless upgrades and continue @@ -562,6 +563,8 @@ func (c *baseClient) assertUnstableCommand(cmd Cmder) bool { if c.opt.UnstableResp3 { return true } else { + // TODO: find the best way to remove the panic and return error here + // The client should not panic when executing a command, only when initializing. panic("RESP3 responses for this command are disabled because they may still change. Please set the flag UnstableResp3 . See the [README](https://github.com/redis/go-redis/blob/master/README.md) and the release notes for guidance.") } default: @@ -921,8 +924,15 @@ func NewClient(opt *Options) *Client { opt.PushNotificationProcessor = c.pushProcessor // Create connection pools - c.connPool = newConnPool(opt, c.dialHook) - c.pubSubPool = newPubSubPool(opt, c.dialHook) + var err error + c.connPool, err = newConnPool(opt, c.dialHook) + if err != nil { + panic(fmt.Errorf("redis: failed to create connection pool: %w", err)) + } + c.pubSubPool, err = newPubSubPool(opt, c.dialHook) + if err != nil { + panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err)) + } // Initialize hitless upgrades first if enabled and protocol is RESP3 if opt.HitlessUpgradeConfig != nil && opt.HitlessUpgradeConfig.Mode != hitless.MaintNotificationsDisabled && opt.Protocol == 3 { diff --git a/sentinel.go b/sentinel.go index 8ae284dec1..e52e840722 100644 --- a/sentinel.go +++ b/sentinel.go @@ -475,8 +475,15 @@ func NewFailoverClient(failoverOpt *FailoverOptions) *Client { // Use void processor by default for RESP2 connections rdb.pushProcessor = initializePushProcessor(opt) - rdb.connPool = newConnPool(opt, rdb.dialHook) - rdb.pubSubPool = newPubSubPool(opt, rdb.dialHook) + var err error + rdb.connPool, err = newConnPool(opt, rdb.dialHook) + if err != nil { + panic(fmt.Errorf("redis: failed to create connection pool: %w", err)) + } + rdb.pubSubPool, err = newPubSubPool(opt, rdb.dialHook) + if err != nil { + panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err)) + } rdb.onClose = rdb.wrappedOnClose(failover.Close) @@ -552,8 +559,15 @@ func NewSentinelClient(opt *Options) *SentinelClient { dial: c.baseClient.dial, process: c.baseClient.process, }) - c.connPool = newConnPool(opt, c.dialHook) - c.pubSubPool = newPubSubPool(opt, c.dialHook) + var err error + c.connPool, err = newConnPool(opt, c.dialHook) + if err != nil { + panic(fmt.Errorf("redis: failed to create connection pool: %w", err)) + } + c.pubSubPool, err = newPubSubPool(opt, c.dialHook) + if err != nil { + panic(fmt.Errorf("redis: failed to create pubsub pool: %w", err)) + } return c } From fd98a22aa9fa149426ddbd6459af80a1eeb571ca Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> Date: Tue, 19 Aug 2025 16:21:50 +0300 Subject: [PATCH 05/21] Update internal/pool/pool.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/pool/pool.go | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 0989d38c17..da5ca30280 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -539,7 +539,7 @@ func (p *ConnPool) popIdle() (*Conn, error) { // If we exhausted all attempts without finding a usable connection, return nil if attempts >= popAttempts { - log.Printf("redis: connection pool: failed to get an usable connection after %d attempts", popAttempts) + log.Printf("redis: connection pool: failed to get a usable connection after %d attempts", popAttempts) return nil, nil } From 1f8b660d11533a2ba4d0d77c7f03cd42d65cd1e8 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> Date: Tue, 19 Aug 2025 16:22:25 +0300 Subject: [PATCH 06/21] Update internal/pool/pool.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- internal/pool/pool.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index da5ca30280..337397e0ec 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -386,7 +386,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { attempts := 0 for { if attempts >= getAttempts { - log.Printf("redis: connection pool: failed to get an connection accepted by hook after %d attempts", attempts) + log.Printf("redis: connection pool: failed to get a connection accepted by hook after %d attempts", attempts) break } attempts++ From e846e5a5b29514857ad61adf029f87c088dad39a Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov <1547186+ndyakov@users.noreply.github.com> Date: Tue, 19 Aug 2025 16:23:56 +0300 Subject: [PATCH 07/21] Update tx.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tx.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tx.go b/tx.go index ee7b76e46d..40bc1d6618 100644 --- a/tx.go +++ b/tx.go @@ -24,7 +24,7 @@ type Tx struct { func (c *Client) newTx() *Tx { tx := Tx{ baseClient: baseClient{ - opt: c.opt.clone(), // Clone options to avoid sharing HitlessUpgradeConfig + opt: c.opt.clone(), // Clone options to avoid sharing mutable state between transaction and parent client connPool: pool.NewStickyConnPool(c.connPool), hooksMixin: c.hooksMixin.clone(), pushProcessor: c.pushProcessor, // Copy push processor from parent client From 22da2b590eb59453dd2792b28746e6d6fe8ac721 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 19 Aug 2025 16:38:05 +0300 Subject: [PATCH 08/21] address comments --- .github/copilot-instructions.md | 336 ++++++++++++++++++++++++++++++++ internal/util/convert.go | 11 ++ options.go | 29 +-- 3 files changed, 356 insertions(+), 20 deletions(-) create mode 100644 .github/copilot-instructions.md diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000000..b25460c6a4 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,336 @@ +# GitHub Copilot Instructions for go-redis + +This file provides context and guidelines for GitHub Copilot when working with the go-redis codebase. 
+ +## Project Overview + +go-redis is a Redis client for Go with support for: +- Redis Standalone, Cluster, Sentinel, and Ring topologies +- RESP2 and RESP3 protocols +- Connection pooling and management +- Push notifications (RESP3) +- Hitless upgrades for seamless cluster transitions +- Pub/Sub messaging +- Pipelines and transactions + +## Architecture + +### Core Components + +- **Client Types**: `Client`, `ClusterClient`, `SentinelClient`, `RingClient` +- **Connection Pool**: `internal/pool` package manages connection lifecycle +- **Protocol**: `internal/proto` handles RESP protocol parsing +- **Hitless Upgrades**: `hitless` package provides seamless cluster transitions +- **Push Notifications**: `push` package handles RESP3 push notifications + +### Key Packages + +- `redis.go` - Main client implementation +- `options.go` - Configuration and client options +- `osscluster.go` - Open source cluster client +- `sentinel.go` - Sentinel failover client +- `ring.go` - Ring (sharding) client +- `internal/pool/` - Connection pool management +- `hitless/` - Hitless upgrade functionality + +## Coding Standards + +### General Guidelines + +1. **Error Handling**: Always handle errors explicitly, prefer descriptive error messages +2. **Context**: Use `context.Context` for cancellation and timeouts +3. **Thread Safety**: All public APIs must be thread-safe +4. **Memory Management**: Minimize allocations, reuse buffers where possible +5. **Testing**: Write comprehensive unit tests, prefer table-driven tests + +### Naming Conventions + +- Use Go standard naming (camelCase for private, PascalCase for public) +- Interface names should end with `-er` (e.g., `Pooler`, `Cmder`) +- Error variables should start with `Err` (e.g., `ErrClosed`) +- Constants should be grouped and well-documented + +### Code Organization + +- Keep functions focused and small (prefer < 100 lines) +- Group related functionality in the same file +- Use internal packages for implementation details +- Extract common patterns into helper functions + +## Connection Pool Guidelines + +### Pool Management + +- Connections are managed by `internal/pool/ConnPool` +- Use `pool.Conn` wrapper for Redis connections +- Implement proper connection lifecycle (dial, auth, select DB) +- Handle connection health checks and cleanup + +### Pool Hooks + +- Use `PoolHook` interface for connection processing +- Hooks are called on `OnGet` and `OnPut` operations +- Support for hitless upgrades through pool hooks +- Maintain backward compatibility when adding hooks + +### Connection States + +- `IsUsable()` - Connection can be used for commands +- `ShouldHandoff()` - Connection needs handoff during cluster transition +- Proper state management is critical for hitless upgrades + +## Hitless Upgrades + +### Design Principles + +- Seamless connection handoffs during cluster topology changes +- Event-driven architecture with push notifications +- Atomic state management using `sync/atomic` +- Worker pools for concurrent handoff processing + +### Key Components + +- `HitlessManager` - Orchestrates upgrade operations +- `PoolHook` - Handles connection-level operations +- `NotificationHandler` - Processes push notifications +- Configuration through `hitless.Config` + +### Implementation Guidelines + +- Use atomic operations for state checks (avoid mutex locks) +- Implement proper timeout handling for handoff operations +- Support retry logic with exponential backoff +- Maintain connection pool integrity during transitions + +## Testing Guidelines + +### Unit Tests 
+ +- Use table-driven tests for multiple scenarios +- Test both success and error paths +- Mock external dependencies (Redis server) +- Verify thread safety with race detection + +### Integration Tests + +- Separate integration tests from unit tests +- Use real Redis instances when needed +- Test all client types (standalone, cluster, sentinel) +- Verify hitless upgrade scenarios + +### Test Structure + +```go +func TestFeature(t *testing.T) { + tests := []struct { + name string + input InputType + expected ExpectedType + wantErr bool + }{ + // test cases + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // test implementation + }) + } +} +``` + +## Performance Considerations + +### Memory Optimization + +- Reuse buffers and objects where possible +- Use object pools for frequently allocated types +- Minimize string allocations in hot paths +- Profile memory usage regularly + +### Concurrency + +- Prefer atomic operations over mutexes for simple state +- Use `sync.Map` for concurrent map access +- Implement proper worker pool patterns +- Avoid blocking operations in hot paths + +### Connection Management + +- Implement connection pooling efficiently +- Handle connection timeouts properly +- Support connection health checks +- Minimize connection churn + +## Common Patterns + +### Error Handling + +```go +if err != nil { + return fmt.Errorf("operation failed: %w", err) +} +``` + +### Context Usage + +```go +func (c *Client) operation(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + // continue with operation + } +} +``` + +### Configuration Validation + +```go +func (opt *Options) validate() error { + if opt.PoolSize <= 0 { + return errors.New("PoolSize must be positive") + } + return nil +} +``` + +## Documentation Standards + +- Use Go doc comments for all public APIs +- Include examples for complex functionality +- Document configuration options thoroughly +- Maintain README.md with usage examples + +## Compatibility + +- Maintain backward compatibility for public APIs +- Use build tags for version-specific features +- Support multiple Redis versions +- Handle protocol differences gracefully + +## Security Considerations + +- Validate all user inputs +- Handle authentication securely +- Support TLS connections +- Avoid logging sensitive information + +## go-redis Specific Patterns + +### Command Interface + +All Redis commands implement the `Cmder` interface: + +```go +type Cmder interface { + Name() string + FullName() string + Args() []interface{} + String() string + stringArg(int) string + firstKeyPos() int8 + SetFirstKeyPos(int8) + readTimeout() *time.Duration + readReply(rd *proto.Reader) error + SetErr(error) + Err() error +} +``` + +### Client Initialization Pattern + +```go +func NewClient(opt *Options) *Client { + if opt == nil { + panic("redis: NewClient nil options") + } + opt.init() // Apply defaults + + c := Client{ + baseClient: &baseClient{opt: opt}, + } + c.init() + + // Create pools with error handling + var err error + c.connPool, err = newConnPool(opt, c.dialHook) + if err != nil { + panic(fmt.Errorf("redis: failed to create connection pool: %w", err)) + } + + return &c +} +``` + +### Pool Hook Pattern + +```go +type PoolHook interface { + OnGet(ctx context.Context, conn *Conn, isNewConn bool) error + OnPut(ctx context.Context, conn *Conn) (shouldPool bool, shouldRemove bool, err error) +} +``` + +### Atomic State Management + +Prefer atomic operations for simple state: + +```go +type Manager struct { 
+ closed atomic.Bool + count atomic.Int64 +} + +func (m *Manager) isClosed() bool { + return m.closed.Load() +} + +func (m *Manager) close() { + m.closed.Store(true) +} +``` + +### Configuration Defaults Pattern + +```go +func (opt *Options) init() { + if opt.PoolSize == 0 { + opt.PoolSize = 10 * runtime.GOMAXPROCS(0) + } + if opt.ReadTimeout == 0 { + opt.ReadTimeout = 3 * time.Second + } + // Apply hitless upgrade defaults + opt.HitlessUpgradeConfig = opt.HitlessUpgradeConfig.ApplyDefaultsWithPoolSize(opt.PoolSize) +} +``` + +### Push Notification Handling + +```go +type NotificationProcessor interface { + ProcessPushNotification(ctx context.Context, data []byte) error + RegisterHandler(notificationType string, handler NotificationHandler) error + Close() error +} +``` + +### Error Definitions + +Group related errors in separate files: + +```go +// errors.go +var ( + ErrClosed = errors.New("redis: client is closed") + ErrPoolExhausted = errors.New("redis: connection pool exhausted") + ErrPoolTimeout = errors.New("redis: connection pool timeout") +) +``` + +### Panics +Creating the client (NewClient, NewClusterClient, etc.) is the only time when the library can panic. +This includes initialization of the pool, hitless upgrade manager, and other critical components. +Other than that, the library should never panic. \ No newline at end of file diff --git a/internal/util/convert.go b/internal/util/convert.go index d326d50d35..b743a4f0eb 100644 --- a/internal/util/convert.go +++ b/internal/util/convert.go @@ -28,3 +28,14 @@ func MustParseFloat(s string) float64 { } return f } + +// SafeIntToInt32 safely converts an int to int32, returning an error if overflow would occur. +func SafeIntToInt32(value int, fieldName string) (int32, error) { + if value > math.MaxInt32 { + return 0, fmt.Errorf("redis: %s value %d exceeds maximum allowed value %d", fieldName, value, math.MaxInt32) + } + if value < math.MinInt32 { + return 0, fmt.Errorf("redis: %s value %d is below minimum allowed value %d", fieldName, value, math.MinInt32) + } + return int32(value), nil +} diff --git a/options.go b/options.go index 45f62b32a0..c89f4605f8 100644 --- a/options.go +++ b/options.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "errors" "fmt" - "math" "net" "net/url" "runtime" @@ -18,6 +17,7 @@ import ( "github.com/redis/go-redis/v9/hitless" "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/internal/proto" + "github.com/redis/go-redis/v9/internal/util" "github.com/redis/go-redis/v9/push" ) @@ -32,17 +32,6 @@ type Limiter interface { ReportResult(result error) } -// safeIntToInt32 safely converts an int to int32, returning an error if overflow would occur. -func safeIntToInt32(value int, fieldName string) (int32, error) { - if value > math.MaxInt32 { - return 0, fmt.Errorf("redis: %s value %d exceeds maximum allowed value %d", fieldName, value, math.MaxInt32) - } - if value < math.MinInt32 { - return 0, fmt.Errorf("redis: %s value %d is below minimum allowed value %d", fieldName, value, math.MinInt32) - } - return int32(value), nil -} - // Options keeps the settings to set up redis connection. 
type Options struct { @@ -661,22 +650,22 @@ func newConnPool( opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error), ) (*pool.ConnPool, error) { - poolSize, err := safeIntToInt32(opt.PoolSize, "PoolSize") + poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize") if err != nil { return nil, err } - minIdleConns, err := safeIntToInt32(opt.MinIdleConns, "MinIdleConns") + minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns") if err != nil { return nil, err } - maxIdleConns, err := safeIntToInt32(opt.MaxIdleConns, "MaxIdleConns") + maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns") if err != nil { return nil, err } - maxActiveConns, err := safeIntToInt32(opt.MaxActiveConns, "MaxActiveConns") + maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns") if err != nil { return nil, err } @@ -702,22 +691,22 @@ func newConnPool( func newPubSubPool(opt *Options, dialer func(ctx context.Context, network, addr string) (net.Conn, error), ) (*pool.PubSubPool, error) { - poolSize, err := safeIntToInt32(opt.PoolSize, "PoolSize") + poolSize, err := util.SafeIntToInt32(opt.PoolSize, "PoolSize") if err != nil { return nil, err } - minIdleConns, err := safeIntToInt32(opt.MinIdleConns, "MinIdleConns") + minIdleConns, err := util.SafeIntToInt32(opt.MinIdleConns, "MinIdleConns") if err != nil { return nil, err } - maxIdleConns, err := safeIntToInt32(opt.MaxIdleConns, "MaxIdleConns") + maxIdleConns, err := util.SafeIntToInt32(opt.MaxIdleConns, "MaxIdleConns") if err != nil { return nil, err } - maxActiveConns, err := safeIntToInt32(opt.MaxActiveConns, "MaxActiveConns") + maxActiveConns, err := util.SafeIntToInt32(opt.MaxActiveConns, "MaxActiveConns") if err != nil { return nil, err } From adee0a89eab0aba578330d6a50f3043e6e77a92b Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Tue, 19 Aug 2025 16:50:22 +0300 Subject: [PATCH 09/21] fix logger --- internal/pool/pool.go | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 337397e0ec..32e9221863 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -3,7 +3,6 @@ package pool import ( "context" "errors" - "log" "net" "sync" "sync/atomic" @@ -386,7 +385,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { attempts := 0 for { if attempts >= getAttempts { - log.Printf("redis: connection pool: failed to get a connection accepted by hook after %d attempts", attempts) + internal.Logger.Printf(ctx, "redis: connection pool: failed to get a connection accepted by hook after %d attempts", attempts) break } attempts++ @@ -416,7 +415,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { if hookManager != nil { if err := hookManager.ProcessOnGet(ctx, cn, false); err != nil { - log.Printf("redis: connection pool: failed to process idle connection by hook: %v", err) + internal.Logger.Printf(ctx, "redis: connection pool: failed to process idle connection by hook: %v", err) // Failed to process connection, discard it _ = p.CloseConn(cn) continue @@ -443,7 +442,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { if hookManager != nil { if err := hookManager.ProcessOnGet(ctx, newcn, true); err != nil { // Failed to process connection, discard it - log.Printf("redis: connection pool: failed to process new connection by hook: %v", err) + internal.Logger.Printf(ctx, "redis: connection pool: failed to process new connection by hook: %v", err) _ = 
p.CloseConn(newcn) return nil, err } @@ -539,7 +538,7 @@ func (p *ConnPool) popIdle() (*Conn, error) { // If we exhausted all attempts without finding a usable connection, return nil if attempts >= popAttempts { - log.Printf("redis: connection pool: failed to get a usable connection after %d attempts", popAttempts) + internal.Logger.Printf(context.Background(), "redis: connection pool: failed to get a usable connection after %d attempts", popAttempts) return nil, nil } @@ -764,5 +763,3 @@ func (p *ConnPool) isHealthyConn(cn *Conn, now time.Time) bool { } return true } - - From 2e47e39a9794a53a8794478952feb66a7b4abd16 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 21 Aug 2025 14:26:50 +0300 Subject: [PATCH 10/21] address pr comments --- hitless/errors.go | 44 +------------ hitless/notification_handler.go | 24 +++----- hitless/pool_hook.go | 106 +++++++++++++++----------------- 3 files changed, 58 insertions(+), 116 deletions(-) diff --git a/hitless/errors.go b/hitless/errors.go index 784b41a21b..7e71b13f8c 100644 --- a/hitless/errors.go +++ b/hitless/errors.go @@ -2,7 +2,6 @@ package hitless import ( "errors" - "fmt" ) // Configuration errors @@ -19,12 +18,7 @@ var ( // Configuration validation errors ErrInvalidHandoffRetries = errors.New("hitless: MaxHandoffRetries must be between 1 and 10") - ErrInvalidConnectionValidationTimeout = errors.New("hitless: ConnectionValidationTimeout must be greater than 0 and less than 30 seconds") - ErrInvalidConnectionHealthCheckInterval = errors.New("hitless: ConnectionHealthCheckInterval must be between 0 and 1 hour") - ErrInvalidOperationCleanupInterval = errors.New("hitless: OperationCleanupInterval must be greater than 0 and less than 1 hour") - ErrInvalidMaxActiveOperations = errors.New("hitless: MaxActiveOperations must be between 100 and 100000") - ErrInvalidNotificationBufferSize = errors.New("hitless: NotificationBufferSize must be between 10 and 10000") - ErrInvalidNotificationTimeout = errors.New("hitless: NotificationTimeout must be greater than 0 and less than 30 seconds") + ErrInvalidHandoffState = errors.New("hitless: Conn is in invalid state for handoff") ) // Integration errors @@ -34,44 +28,10 @@ var ( // Handoff errors var ( - ErrHandoffInProgress = errors.New("hitless: handoff already in progress") - ErrNoHandoffInProgress = errors.New("hitless: no handoff in progress") - ErrConnectionFailed = errors.New("hitless: failed to establish new connection") - ErrHandoffQueueFull = errors.New("hitless: handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration") + ErrHandoffQueueFull = errors.New("hitless: handoff queue is full, cannot queue new handoff requests - consider increasing HandoffQueueSize or MaxWorkers in configuration") ) -// Dead error variables removed - unused in simplified architecture - // Notification errors var ( ErrInvalidNotification = errors.New("hitless: invalid notification format") ) - -// Dead error variables removed - unused in simplified architecture - -// HandoffError represents an error that occurred during connection handoff. 
-type HandoffError struct { - Operation string - Endpoint string - Cause error -} - -func (e *HandoffError) Error() string { - if e.Cause != nil { - return fmt.Sprintf("hitless: handoff %s failed for endpoint %s: %v", e.Operation, e.Endpoint, e.Cause) - } - return fmt.Sprintf("hitless: handoff %s failed for endpoint %s", e.Operation, e.Endpoint) -} - -func (e *HandoffError) Unwrap() error { - return e.Cause -} - -// NewHandoffError creates a new HandoffError. -func NewHandoffError(operation, endpoint string, cause error) *HandoffError { - return &HandoffError{ - Operation: operation, - Endpoint: endpoint, - Cause: cause, - } -} diff --git a/hitless/notification_handler.go b/hitless/notification_handler.go index 933e0ea68e..cf28460c3f 100644 --- a/hitless/notification_handler.go +++ b/hitless/notification_handler.go @@ -3,7 +3,6 @@ package hitless import ( "context" "fmt" - "strconv" "time" "github.com/redis/go-redis/v9/internal" @@ -63,27 +62,17 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus if len(notification) < 3 { return ErrInvalidNotification } - seqIDStr, ok := notification[1].(string) + seqID, ok := notification[1].(int64) if !ok { return ErrInvalidNotification } - seqID, err := strconv.ParseInt(seqIDStr, 10, 64) - if err != nil { - return ErrInvalidNotification - } - // Extract timeS - timeSStr, ok := notification[2].(string) + timeS, ok := notification[2].(int64) if !ok { return ErrInvalidNotification } - timeS, err := strconv.ParseInt(timeSStr, 10, 64) - if err != nil { - return ErrInvalidNotification - } - newEndpoint := "" if len(notification) > 3 { // Extract new endpoint @@ -116,16 +105,17 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus newEndpoint = snh.manager.options.GetAddr() // delay the handoff for timeS/2 seconds to the same endpoint // do this in a goroutine to avoid blocking the notification handler - go func() { - time.Sleep(time.Duration(timeS/2) * time.Second) + // NOTE: This timer is started while parsing the notification, so the connection is not marked for handoff + // and there should be no possibility of a race condition or double handoff. 
+ time.AfterFunc(time.Duration(timeS/2)*time.Second, func() { if poolConn == nil || poolConn.IsClosed() { return } if err := snh.markConnForHandoff(poolConn, newEndpoint, seqID, deadline); err != nil { // Log error but don't fail the goroutine - internal.Logger.Printf(context.Background(), "hitless: failed to mark connection for handoff: %v", err) + internal.Logger.Printf(ctx, "hitless: failed to mark connection for handoff: %v", err) } - }() + }) return nil } diff --git a/hitless/pool_hook.go b/hitless/pool_hook.go index 073fdd7378..2a31a5d0fb 100644 --- a/hitless/pool_hook.go +++ b/hitless/pool_hook.go @@ -257,16 +257,45 @@ func (ph *PoolHook) processHandoffRequest(request HandoffRequest) { }() // Perform the handoff with cancellable context - err := ph.performConnectionHandoffWithPool(shutdownCtx, request.Conn, request.Pool) - - // If handoff failed, restore the handoff state for potential retry + shouldRetry, err := ph.performConnectionHandoffWithPool(shutdownCtx, request.Conn, request.Pool) if err != nil { - request.Conn.RestoreHandoffState() - internal.Logger.Printf(context.Background(), "Handoff failed for connection WILL RETRY: %v", err) - } + if shouldRetry { + now := time.Now() + deadline, ok := shutdownCtx.Deadline() + if !ok || deadline.Before(now) { + // wait half the timeout before retrying if no deadline or deadline has passed + deadline = now.Add(handoffTimeout / 2) + } - // No need for scale down scheduling with on-demand workers - // Workers automatically exit when idle + afterTime := deadline.Sub(now) + if afterTime < handoffTimeout/2 { + afterTime = handoffTimeout / 2 + } + + internal.Logger.Printf(context.Background(), "Handoff failed for connection WILL RETRY After %v: %v", afterTime, err) + time.AfterFunc(afterTime, func() { + ph.queueHandoff(request.Conn) + }) + } else { + pooler := request.Pool + conn := request.Conn + if pooler != nil { + go pooler.Remove(ctx, conn, err) + if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level + internal.Logger.Printf(ctx, + "hitless: removed connection %d from pool due to max handoff retries reached", + conn.GetID()) + } + } else { + go conn.Close() + if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level + internal.Logger.Printf(ctx, + "hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v", + conn.GetID(), err) + } + } + } + } } // queueHandoff queues a handoff request for processing @@ -313,8 +342,8 @@ func (ph *PoolHook) queueHandoff(conn *pool.Conn) error { } // performConnectionHandoffWithPool performs the actual connection handoff with pool for connection removal on failure -// if err is returned, connection will be removed from pool -func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn *pool.Conn, pooler pool.Pooler) error { +// When error is returned, the connection handoff should be retried if err is not ErrMaxHandoffRetriesReached +func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn *pool.Conn, pooler pool.Pooler) (shouldRetry bool, err error) { // Clear handoff state after successful handoff seqID := conn.GetMovingSeqID() connID := conn.GetID() @@ -326,11 +355,7 @@ func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn * newEndpoint := conn.GetHandoffEndpoint() if newEndpoint == "" { - // TODO(hitless): Handle by performing the handoff to the current endpoint in N seconds, - // Where N is the time in the moving notification... 
- // For now, clear the handoff state and return - conn.ClearHandoffState() - return nil + return false, ErrInvalidHandoffState } retries := conn.IncrementAndGetHandoffRetries(1) @@ -345,23 +370,8 @@ func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn * "hitless: reached max retries (%d) for handoff of connection %d to %s", maxRetries, conn.GetID(), conn.GetHandoffEndpoint()) } - err := ErrMaxHandoffRetriesReached - if pooler != nil { - go pooler.Remove(ctx, conn, err) - if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level - internal.Logger.Printf(ctx, - "hitless: removed connection %d from pool due to max handoff retries reached", - conn.GetID()) - } - } else { - go conn.Close() - if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level - internal.Logger.Printf(ctx, - "hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v", - conn.GetID(), err) - } - } - return err + // won't retry on ErrMaxHandoffRetriesReached + return false, ErrMaxHandoffRetriesReached } // Create endpoint-specific dialer @@ -370,10 +380,9 @@ func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn * // Create new connection to the new endpoint newNetConn, err := endpointDialer(ctx) if err != nil { - // TODO(hitless): retry - // This is the only case where we should retry the handoff request - // Should we do anything else other than return the error? - return err + // hitless: will retry + // Maybe a network error - retry after a delay + return true, err } // Get the old connection @@ -382,26 +391,9 @@ func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn * // Replace the connection and execute initialization err = conn.SetNetConnAndInitConn(ctx, newNetConn) if err != nil { - // Remove the connection from the pool since it's in a bad state - if pooler != nil { - // Use pool.Pooler interface directly - no adapter needed - go pooler.Remove(ctx, conn, err) - if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level - internal.Logger.Printf(ctx, - "hitless: removed connection %d from pool due to handoff initialization failure: %v", - conn.GetID(), err) - } - } else { - go conn.Close() - if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level - internal.Logger.Printf(ctx, - "hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v", - conn.GetID(), err) - } - } - - // Keep the handoff state for retry - return err + // hitless: won't retry + // Initialization failed - remove the connection + return false, err } defer func() { if oldConn != nil { @@ -428,7 +420,7 @@ func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn * } } - return nil + return false, nil } // createEndpointDialer creates a dialer function that connects to a specific endpoint From 26923a2357d1aee3ee50ec2317dabd8b4ef497d7 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 21 Aug 2025 16:08:57 +0300 Subject: [PATCH 11/21] refactor for readibility --- hitless/pool_hook.go | 60 +++++++++++++++++++++---------------------- internal/pool/conn.go | 8 ------ 2 files changed, 30 insertions(+), 38 deletions(-) diff --git a/hitless/pool_hook.go b/hitless/pool_hook.go index 2a31a5d0fb..86ca9b3156 100644 --- a/hitless/pool_hook.go +++ b/hitless/pool_hook.go @@ -138,40 +138,40 @@ func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, isNewConn bool) // OnPut is called when a connection is returned to the pool func (ph *PoolHook) 
OnPut(ctx context.Context, conn *pool.Conn) (shouldPool bool, shouldRemove bool, err error) { // first check if we should handoff for faster rejection - if conn.ShouldHandoff() { - // check pending handoff to not queue the same connection twice - _, hasPendingHandoff := ph.pending.Load(conn.GetID()) - if !hasPendingHandoff { - // Check for empty endpoint first (synchronous check) - if conn.GetHandoffEndpoint() == "" { - conn.ClearHandoffState() - } else { - if err := ph.queueHandoff(conn); err != nil { - // Failed to queue handoff, remove the connection - internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err) - return false, true, nil // Don't pool, remove connection, no error to caller - } + if !conn.ShouldHandoff() { + // Default behavior (no handoff): pool the connection + return true, false, nil + } - // Check if handoff was already processed by a worker before we can mark it as queued - if !conn.ShouldHandoff() { - // Handoff was already processed - this is normal and the connection should be pooled - return true, false, nil - } + // check pending handoff to not queue the same connection twice + _, hasPendingHandoff := ph.pending.Load(conn.GetID()) + if hasPendingHandoff { + // Default behavior (pending handoff): pool the connection + return true, false, nil + } - if err := conn.MarkQueuedForHandoff(); err != nil { - // If marking fails, check if handoff was processed in the meantime - if !conn.ShouldHandoff() { - // Handoff was processed - this is normal, pool the connection - return true, false, nil - } - // Other error - remove the connection - return false, true, nil - } - return true, false, nil - } + if err := ph.queueHandoff(conn); err != nil { + // Failed to queue handoff, remove the connection + internal.Logger.Printf(ctx, "Failed to queue handoff: %v", err) + // Don't pool, remove connection, no error to caller + return false, true, nil + } + + // Check if handoff was already processed by a worker before we can mark it as queued + if !conn.ShouldHandoff() { + // Handoff was already processed - this is normal and the connection should be pooled + return true, false, nil + } + + if err := conn.MarkQueuedForHandoff(); err != nil { + // If marking fails, check if handoff was processed in the meantime + if !conn.ShouldHandoff() { + // Handoff was processed - this is normal, pool the connection + return true, false, nil } + // Other error - remove the connection + return false, true, nil } - // Default: pool the connection return true, false, nil } diff --git a/internal/pool/conn.go b/internal/pool/conn.go index aa2da01a7f..60875b0409 100644 --- a/internal/pool/conn.go +++ b/internal/pool/conn.go @@ -409,14 +409,6 @@ func (cn *Conn) MarkQueuedForHandoff() error { return nil } -// RestoreHandoffState restores the handoff state after a failed handoff (lock-free). -func (cn *Conn) RestoreHandoffState() { - // Restore shouldHandoff flag for retry - cn.shouldHandoffAtomic.Store(true) - // Keep usable=false to prevent the connection from being used until handoff succeeds - cn.setUsable(false) -} - // ShouldHandoff returns true if the connection needs to be handed off (lock-free). 
func (cn *Conn) ShouldHandoff() bool { return cn.shouldHandoff() From 4940827684c5e4f87cfe2061e3c98d871844b77d Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 21 Aug 2025 16:12:26 +0300 Subject: [PATCH 12/21] refactor for readibility x2 --- hitless/errors.go | 10 +++++++++- hitless/pool_hook.go | 9 ++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/hitless/errors.go b/hitless/errors.go index 7e71b13f8c..0519df4845 100644 --- a/hitless/errors.go +++ b/hitless/errors.go @@ -18,7 +18,6 @@ var ( // Configuration validation errors ErrInvalidHandoffRetries = errors.New("hitless: MaxHandoffRetries must be between 1 and 10") - ErrInvalidHandoffState = errors.New("hitless: Conn is in invalid state for handoff") ) // Integration errors @@ -35,3 +34,12 @@ var ( var ( ErrInvalidNotification = errors.New("hitless: invalid notification format") ) + +// connection handoff errors +var ( + // ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff + // and should not be used until the handoff is complete + ErrConnectionMarkedForHandoff = errors.New("hitless: connection marked for handoff") + // ErrConnectionInvalidHandoffState is returned when a connection is in an invalid state for handoff + ErrConnectionInvalidHandoffState = errors.New("hitless: connection is in invalid state for handoff") +) diff --git a/hitless/pool_hook.go b/hitless/pool_hook.go index 86ca9b3156..6ff22c153b 100644 --- a/hitless/pool_hook.go +++ b/hitless/pool_hook.go @@ -355,7 +355,7 @@ func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn * newEndpoint := conn.GetHandoffEndpoint() if newEndpoint == "" { - return false, ErrInvalidHandoffState + return false, ErrConnectionInvalidHandoffState } retries := conn.IncrementAndGetHandoffRetries(1) @@ -445,8 +445,7 @@ func (ph *PoolHook) createEndpointDialer(endpoint string) func(context.Context) func (ph *PoolHook) Shutdown(ctx context.Context) error { ph.shutdownOnce.Do(func() { close(ph.shutdown) - - // No timers to clean up with on-demand workers + // workers will exit when they finish their current request }) // Wait for workers to complete @@ -463,7 +462,3 @@ func (ph *PoolHook) Shutdown(ctx context.Context) error { return ctx.Err() } } - -// ErrConnectionMarkedForHandoff is returned when a connection is marked for handoff -// and should not be used until the handoff is complete -var ErrConnectionMarkedForHandoff = errors.New("connection marked for handoff") From 0fd9871d479167c5d9bd5bfda8fb729d68432eff Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 21 Aug 2025 16:27:13 +0300 Subject: [PATCH 13/21] filter out logging --- internal/log.go | 10 ---------- internal/pool/pool.go | 2 +- main_test.go | 21 +++++++++++++++++++++ redis.go | 12 +++++++++++- 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/internal/log.go b/internal/log.go index 4fe3d7db9c..dd0e3d0c74 100644 --- a/internal/log.go +++ b/internal/log.go @@ -27,13 +27,3 @@ func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) { var Logger Logging = &logger{ log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile), } - -// VoidLogger is a logger that does nothing. -// Used to disable logging and thus speed up the library. 
-type VoidLogger struct{} - -func (v *VoidLogger) Printf(_ context.Context, _ string, _ ...interface{}) { - // do nothing -} - -var _ Logging = (*VoidLogger)(nil) diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 32e9221863..7005222075 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -385,7 +385,7 @@ func (p *ConnPool) getConn(ctx context.Context) (*Conn, error) { attempts := 0 for { if attempts >= getAttempts { - internal.Logger.Printf(ctx, "redis: connection pool: failed to get a connection accepted by hook after %d attempts", attempts) + internal.Logger.Printf(ctx, "redis: connection pool: was not able to get a connection accepted by hook after %d attempts", attempts) break } attempts++ diff --git a/main_test.go b/main_test.go index 29e6014b9b..dc7786031e 100644 --- a/main_test.go +++ b/main_test.go @@ -103,6 +103,9 @@ var _ = BeforeSuite(func() { fmt.Printf("REDIS_VERSION: %.1f\n", RedisVersion) fmt.Printf("CLIENT_LIBS_TEST_IMAGE: %v\n", os.Getenv("CLIENT_LIBS_TEST_IMAGE")) + tlogger := &TestLogger{} + tlogger.Filter("ERR unknown subcommand 'maint_notifications'") + redis.SetLogger(tlogger) if RedisVersion < 7.0 || RedisVersion > 9 { panic("incorrect or not supported redis version") } @@ -399,3 +402,21 @@ func (h *hook) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.Process } return hook } + +// TestLogger is a logger that filters out specific substrings so +// the test output is not polluted with noise. +type TestLogger struct { + filteredSugstrings []string +} + +func (t *TestLogger) Filter(substr string) { + t.filteredSugstrings = append(t.filteredSugstrings, substr) +} +func (t *TestLogger) Printf(ctx context.Context, format string, v ...interface{}) { + for _, substr := range t.filteredSugstrings { + if strings.Contains(format, substr) { + return + } + } + fmt.Printf(format, v...) +} diff --git a/redis.go b/redis.go index c2791336c8..13c7feca69 100644 --- a/redis.go +++ b/redis.go @@ -456,7 +456,7 @@ func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error { c.optLock.Unlock() return fmt.Errorf("failed to enable maintenance notifications: %w", hitlessHandshakeErr) default: // will handle auto and any other - internal.Logger.Printf(ctx, "hitless: auto mode fallback: hitless upgrades disabled due to handshake failure: %v", hitlessHandshakeErr) + internal.Logger.Printf(ctx, "hitless: auto mode fallback: hitless upgrades disabled due to handshake error: %v", hitlessHandshakeErr) c.opt.HitlessUpgradeConfig.Mode = hitless.MaintNotificationsDisabled c.optLock.Unlock() // auto mode, disable hitless upgrades and continue @@ -1280,3 +1280,13 @@ func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.Notifica Conn: &connectionAdapter{conn: cn}, // Wrap in adapter for easier interface access } } + +// VoidLogger is a logger that does nothing. +// Used to disable logging and thus speed up the library. 
+type VoidLogger struct{} + +func (v *VoidLogger) Printf(_ context.Context, _ string, _ ...interface{}) { + // do nothing +} + +var _ internal.Logging = (*VoidLogger)(nil) From c31ec645abc3f0d7f2c2c6333b321771f3023e0d Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 21 Aug 2025 16:41:04 +0300 Subject: [PATCH 14/21] check err on requeue --- hitless/pool_hook.go | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/hitless/pool_hook.go b/hitless/pool_hook.go index 6ff22c153b..e1b7d60361 100644 --- a/hitless/pool_hook.go +++ b/hitless/pool_hook.go @@ -274,26 +274,33 @@ func (ph *PoolHook) processHandoffRequest(request HandoffRequest) { internal.Logger.Printf(context.Background(), "Handoff failed for connection WILL RETRY After %v: %v", afterTime, err) time.AfterFunc(afterTime, func() { - ph.queueHandoff(request.Conn) + if err := ph.queueHandoff(request.Conn); err != nil { + internal.Logger.Printf(context.Background(), "can't queue handoff for retry: %v", err) + ph.removeConn(ctx, request, err) + } }) } else { - pooler := request.Pool - conn := request.Conn - if pooler != nil { - go pooler.Remove(ctx, conn, err) - if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level - internal.Logger.Printf(ctx, - "hitless: removed connection %d from pool due to max handoff retries reached", - conn.GetID()) - } - } else { - go conn.Close() - if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level - internal.Logger.Printf(ctx, - "hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v", - conn.GetID(), err) - } - } + go ph.removeConn(ctx, request, err) + } + } +} + +func (ph *PoolHook) removeConn(ctx context.Context, request HandoffRequest, err error) { + pooler := request.Pool + conn := request.Conn + if pooler != nil { + pooler.Remove(ctx, conn, err) + if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level + internal.Logger.Printf(ctx, + "hitless: removed connection %d from pool due to max handoff retries reached", + conn.GetID()) + } + } else { + conn.Close() + if ph.config != nil && ph.config.LogLevel >= 1 { // Warning level + internal.Logger.Printf(ctx, + "hitless: no pool provided for connection %d, cannot remove due to handoff initialization failure: %v", + conn.GetID(), err) } } } From e596dd7c50cb16f1c590d8d886951fe9b4a33120 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 21 Aug 2025 16:48:56 +0300 Subject: [PATCH 15/21] fix test logger --- main_test.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/main_test.go b/main_test.go index dc7786031e..d7e1e10fbf 100644 --- a/main_test.go +++ b/main_test.go @@ -1,6 +1,7 @@ package redis_test import ( + "context" "fmt" "net" "os" @@ -13,6 +14,7 @@ import ( . "github.com/bsm/ginkgo/v2" . 
"github.com/bsm/gomega" "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/internal" ) const ( @@ -103,9 +105,11 @@ var _ = BeforeSuite(func() { fmt.Printf("REDIS_VERSION: %.1f\n", RedisVersion) fmt.Printf("CLIENT_LIBS_TEST_IMAGE: %v\n", os.Getenv("CLIENT_LIBS_TEST_IMAGE")) - tlogger := &TestLogger{} + // set logger that will filter some of the noise from the tests + tlogger := NewTestLogger() tlogger.Filter("ERR unknown subcommand 'maint_notifications'") redis.SetLogger(tlogger) + if RedisVersion < 7.0 || RedisVersion > 9 { panic("incorrect or not supported redis version") } @@ -403,20 +407,30 @@ func (h *hook) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.Process return hook } +func NewTestLogger() *TestLogger { + intLogger := internal.Logger + + return &TestLogger{ + intLogger, + []string{}, + } +} + // TestLogger is a logger that filters out specific substrings so // the test output is not polluted with noise. type TestLogger struct { + intLogger internal.Logging filteredSugstrings []string } func (t *TestLogger) Filter(substr string) { t.filteredSugstrings = append(t.filteredSugstrings, substr) } -func (t *TestLogger) Printf(ctx context.Context, format string, v ...interface{}) { +func (t *TestLogger) Printf(_ context.Context, format string, v ...interface{}) { for _, substr := range t.filteredSugstrings { if strings.Contains(format, substr) { return } } - fmt.Printf(format, v...) + t.intLogger.Printf(ctx, format, v...) } From 6031458454d3f6add11dafd053266949256b3d45 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 21 Aug 2025 17:00:24 +0300 Subject: [PATCH 16/21] fix test --- hitless/pool_hook_test.go | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/hitless/pool_hook_test.go b/hitless/pool_hook_test.go index 6dbb7a0472..e1f47bb46d 100644 --- a/hitless/pool_hook_test.go +++ b/hitless/pool_hook_test.go @@ -267,8 +267,8 @@ func TestConnectionHook(t *testing.T) { EndpointType: EndpointTypeAuto, MaxWorkers: 2, HandoffQueueSize: 10, - MaxHandoffRetries: 3, - HandoffTimeout: 1 * time.Second, // Shorter timeout for faster test + MaxHandoffRetries: 2, // Reduced retries for faster test + HandoffTimeout: 500 * time.Millisecond, // Shorter timeout for faster test LogLevel: 2, } processor := NewPoolHook(failingDialer, "tcp", config, nil) @@ -294,13 +294,12 @@ func TestConnectionHook(t *testing.T) { } // Wait for handoff to complete and fail with proper timeout and polling - // Use longer timeout to account for handoff timeout + processing time - timeout := time.After(5 * time.Second) + timeout := time.After(3 * time.Second) ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() // wait for handoff to start - time.Sleep(100 * time.Millisecond) + time.Sleep(50 * time.Millisecond) handoffCompleted := false for !handoffCompleted { select { @@ -318,10 +317,17 @@ func TestConnectionHook(t *testing.T) { t.Error("Connection should be removed from pending map after failed handoff") } - // Handoff state should still be set (since handoff failed) - if !conn.ShouldHandoff() { - t.Error("Connection should still be marked for handoff after failed handoff") + // Wait for retries to complete (with MaxHandoffRetries=2, it will retry twice then give up) + // Each retry has a delay of handoffTimeout/2 = 250ms, so wait for all retries to complete + time.Sleep(800 * time.Millisecond) + + // After max retries are reached, the connection should be removed from pool + // and handoff state should be cleared + if conn.ShouldHandoff() { 
+ t.Error("Connection should not be marked for handoff after max retries reached") } + + t.Logf("EventDrivenHandoffDialerError test completed successfully") }) t.Run("BufferedDataRESP2", func(t *testing.T) { From b2228f4ddbb5145499982aa562a41ea902bb8155 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Thu, 21 Aug 2025 17:07:21 +0300 Subject: [PATCH 17/21] fix tests --- main_test.go | 43 +++++++++++++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 14 deletions(-) diff --git a/main_test.go b/main_test.go index d7e1e10fbf..2237106b34 100644 --- a/main_test.go +++ b/main_test.go @@ -3,6 +3,7 @@ package redis_test import ( "context" "fmt" + "log" "net" "os" "strconv" @@ -14,7 +15,6 @@ import ( . "github.com/bsm/ginkgo/v2" . "github.com/bsm/gomega" "github.com/redis/go-redis/v9" - "github.com/redis/go-redis/v9/internal" ) const ( @@ -55,6 +55,8 @@ var ( sentinel1, sentinel2, sentinel3 *redis.Client ) +var TLogger *TestLogger + var cluster = &clusterScenario{ ports: []string{"16600", "16601", "16602", "16603", "16604", "16605"}, nodeIDs: make([]string, 6), @@ -106,9 +108,10 @@ var _ = BeforeSuite(func() { fmt.Printf("CLIENT_LIBS_TEST_IMAGE: %v\n", os.Getenv("CLIENT_LIBS_TEST_IMAGE")) // set logger that will filter some of the noise from the tests - tlogger := NewTestLogger() - tlogger.Filter("ERR unknown subcommand 'maint_notifications'") - redis.SetLogger(tlogger) + TLogger := NewTestLogger() + TLogger.Filter("ERR unknown subcommand 'maint_notifications'") + TLogger.Filter("test panic") + redis.SetLogger(TLogger) if RedisVersion < 7.0 || RedisVersion > 9 { panic("incorrect or not supported redis version") @@ -408,8 +411,7 @@ func (h *hook) ProcessPipelineHook(hook redis.ProcessPipelineHook) redis.Process } func NewTestLogger() *TestLogger { - intLogger := internal.Logger - + intLogger := log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile) return &TestLogger{ intLogger, []string{}, @@ -419,18 +421,31 @@ func NewTestLogger() *TestLogger { // TestLogger is a logger that filters out specific substrings so // the test output is not polluted with noise. type TestLogger struct { - intLogger internal.Logging - filteredSugstrings []string + log *log.Logger + filteredSubstrings []string +} + +// Filter adds a substring to the filter list. +func (tl *TestLogger) Filter(substr string) { + tl.filteredSubstrings = append(tl.filteredSubstrings, substr) } -func (t *TestLogger) Filter(substr string) { - t.filteredSugstrings = append(t.filteredSugstrings, substr) +// Unfilter removes a substring from the filter list. +func (tl *TestLogger) Unfilter(substr string) { + for i, s := range tl.filteredSubstrings { + if s == substr { + tl.filteredSubstrings = append(tl.filteredSubstrings[:i], tl.filteredSubstrings[i+1:]...) + return + } + } } -func (t *TestLogger) Printf(_ context.Context, format string, v ...interface{}) { - for _, substr := range t.filteredSugstrings { - if strings.Contains(format, substr) { + +func (tl *TestLogger) Printf(_ context.Context, format string, v ...interface{}) { + msg := fmt.Sprintf(format, v...) + for _, substr := range tl.filteredSubstrings { + if strings.Contains(msg, substr) { return } } - t.intLogger.Printf(ctx, format, v...) 
+ _ = tl.log.Output(2, msg) } From 30fceb81b4f7cba2cd602f21d0cbc127460f4e13 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Fri, 22 Aug 2025 17:54:40 +0300 Subject: [PATCH 18/21] fix hooks and add logging, logging will be removed before merge --- adapters.go | 4 +- hitless/config.go | 97 ++++++++++++++++++++++++--------- hitless/hitless_manager.go | 8 ++- hitless/hooks.go | 23 ++------ hitless/notification_handler.go | 20 +++++++ 5 files changed, 104 insertions(+), 48 deletions(-) diff --git a/adapters.go b/adapters.go index 6f123e212b..801b86d43f 100644 --- a/adapters.go +++ b/adapters.go @@ -103,8 +103,8 @@ func (ca *connectionAdapter) IsUsable() bool { return ca.conn.IsUsable() } -// GetPoolConnection returns the underlying pool connection. -func (ca *connectionAdapter) GetPoolConnection() *pool.Conn { +// GetPoolConn returns the underlying pool connection. +func (ca *connectionAdapter) GetPoolConn() *pool.Conn { return ca.conn } diff --git a/hitless/config.go b/hitless/config.go index b35a0d7185..ddb10954cd 100644 --- a/hitless/config.go +++ b/hitless/config.go @@ -3,6 +3,7 @@ package hitless import ( "net" "runtime" + "strings" "time" "github.com/redis/go-redis/v9/internal/util" @@ -183,8 +184,6 @@ func (c *Config) Validate() error { return ErrInvalidHandoffRetries } - - return nil } @@ -284,8 +283,6 @@ func (c *Config) ApplyDefaultsWithPoolSize(poolSize int) *Config { result.MaxHandoffRetries = c.MaxHandoffRetries } - - return result } @@ -334,44 +331,92 @@ func (c *Config) applyWorkerDefaults(poolSize int) { // DetectEndpointType automatically detects the appropriate endpoint type // based on the connection address and TLS configuration. +// +// For IP addresses: +// - If TLS is enabled: requests FQDN for proper certificate validation +// - If TLS is disabled: requests IP for better performance +// +// For hostnames: +// - If TLS is enabled: always requests FQDN for proper certificate validation +// - If TLS is disabled: requests IP for better performance +// +// Internal vs External detection: +// - For IPs: uses private IP range detection +// - For hostnames: uses heuristics based on common internal naming patterns func DetectEndpointType(addr string, tlsEnabled bool) EndpointType { - // Parse the address to determine if it's an IP or hostname - isPrivate := isPrivateIP(addr) + // Extract host from "host:port" format + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr // Assume no port + } + // Check if the host is an IP address or hostname + ip := net.ParseIP(host) + isIPAddress := ip != nil var endpointType EndpointType - if tlsEnabled { - // TLS requires FQDN for certificate validation - if isPrivate { - endpointType = EndpointTypeInternalFQDN + if isIPAddress { + // Address is an IP - determine if it's private or public + isPrivate := ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() + + if tlsEnabled { + // TLS with IP addresses - still prefer FQDN for certificate validation + if isPrivate { + endpointType = EndpointTypeInternalFQDN + } else { + endpointType = EndpointTypeExternalFQDN + } } else { - endpointType = EndpointTypeExternalFQDN + // No TLS - can use IP addresses directly + if isPrivate { + endpointType = EndpointTypeInternalIP + } else { + endpointType = EndpointTypeExternalIP + } } } else { - // No TLS, can use IP addresses - if isPrivate { - endpointType = EndpointTypeInternalIP + // Address is a hostname + isInternalHostname := isInternalHostname(host) + if isInternalHostname { + endpointType = EndpointTypeInternalFQDN } 
else { - endpointType = EndpointTypeExternalIP + endpointType = EndpointTypeExternalFQDN } } return endpointType } -// isPrivateIP checks if the given address is in a private IP range. -func isPrivateIP(addr string) bool { - // Extract host from "host:port" format - host, _, err := net.SplitHostPort(addr) - if err != nil { - host = addr // Assume no port +// isInternalHostname determines if a hostname appears to be internal/private. +// This is a heuristic based on common naming patterns. +func isInternalHostname(hostname string) bool { + // Convert to lowercase for comparison + hostname = strings.ToLower(hostname) + + // Common internal hostname patterns + internalPatterns := []string{ + "localhost", + ".local", + ".internal", + ".corp", + ".lan", + ".intranet", + ".private", } - ip := net.ParseIP(host) - if ip == nil { - return false // Not an IP address (likely hostname) + // Check for exact match or suffix match + for _, pattern := range internalPatterns { + if hostname == pattern || strings.HasSuffix(hostname, pattern) { + return true + } + } + + // Check for RFC 1918 style hostnames (e.g., redis-1, db-server, etc.) + // If hostname doesn't contain dots, it's likely internal + if !strings.Contains(hostname, ".") { + return true } - // Check for private/loopback ranges - return ip.IsPrivate() || ip.IsLoopback() || ip.IsLinkLocalUnicast() + // Default to external for fully qualified domain names + return false } diff --git a/hitless/hitless_manager.go b/hitless/hitless_manager.go index 26c379a5b7..309ac6430c 100644 --- a/hitless/hitless_manager.go +++ b/hitless/hitless_manager.go @@ -13,8 +13,6 @@ import ( "github.com/redis/go-redis/v9/internal/pool" ) - - // Push notification type constants for hitless upgrades const ( NotificationMoving = "MOVING" @@ -297,3 +295,9 @@ func (hm *HitlessManager) createPoolHook(baseDialer func(context.Context, string return hm.poolHooksRef } + +func (hm *HitlessManager) AddNotificationHook(notificationHook NotificationHook) { + hm.hooksMu.Lock() + defer hm.hooksMu.Unlock() + hm.hooks = append(hm.hooks, notificationHook) +} diff --git a/hitless/hooks.go b/hitless/hooks.go index 7e84e032d2..7d1b646360 100644 --- a/hitless/hooks.go +++ b/hitless/hooks.go @@ -22,27 +22,14 @@ func (lh *LoggingHook) PreHook(ctx context.Context, notificationType string, not // PostHook logs the result after processing. func (lh *LoggingHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) { if result != nil && lh.LogLevel >= 1 { // Warning level - internal.Logger.Printf(ctx, "hitless: %s notification processing failed: %v", notificationType, result) + internal.Logger.Printf(ctx, "hitless: %s notification processing failed: %v - %v", notificationType, result, notification) } else if lh.LogLevel >= 3 { // Debug level internal.Logger.Printf(ctx, "hitless: %s notification processed successfully", notificationType) } } -// FilterHook is an example hook that can filter out certain notifications. -type FilterHook struct { - BlockedTypes map[string]bool -} - -// PreHook filters notifications based on type. -func (fh *FilterHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) { - if fh.BlockedTypes[notificationType] { - internal.Logger.Printf(ctx, "hitless: filtering out %s notification", notificationType) - return notification, false // Skip processing - } - return notification, true -} - -// PostHook does nothing for filter hook. 
-func (fh *FilterHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) { - // No post-processing needed for filter hook +// NewLoggingHook creates a new logging hook with the specified log level. +// Log levels: 0=errors, 1=warnings, 2=info, 3=debug +func NewLoggingHook(logLevel int) *LoggingHook { + return &LoggingHook{LogLevel: logLevel} } diff --git a/hitless/notification_handler.go b/hitless/notification_handler.go index cf28460c3f..7222644261 100644 --- a/hitless/notification_handler.go +++ b/hitless/notification_handler.go @@ -19,11 +19,13 @@ type NotificationHandler struct { // HandlePushNotification processes push notifications with hook support. func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { if len(notification) == 0 { + internal.Logger.Printf(ctx, "hitless: invalid notification format: %v", notification) return ErrInvalidNotification } notificationType, ok := notification[0].(string) if !ok { + internal.Logger.Printf(ctx, "hitless: invalid notification type format: %v", notification[0]) return ErrInvalidNotification } @@ -60,16 +62,19 @@ func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, hand // ["MOVING", seqNum, timeS, endpoint] - per-connection handoff func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx push.NotificationHandlerContext, notification []interface{}) error { if len(notification) < 3 { + internal.Logger.Printf(ctx, "hitless: invalid MOVING notification: %v", notification) return ErrInvalidNotification } seqID, ok := notification[1].(int64) if !ok { + internal.Logger.Printf(ctx, "hitless: invalid seqID in MOVING notification: %v", notification[1]) return ErrInvalidNotification } // Extract timeS timeS, ok := notification[2].(int64) if !ok { + internal.Logger.Printf(ctx, "hitless: invalid timeS in MOVING notification: %v", notification[2]) return ErrInvalidNotification } @@ -78,6 +83,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus // Extract new endpoint newEndpoint, ok = notification[3].(string) if !ok { + internal.Logger.Printf(ctx, "hitless: invalid newEndpoint in MOVING notification: %v", notification[3]) return ErrInvalidNotification } } @@ -85,6 +91,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus // Get the connection that received this notification conn := handlerCtx.Conn if conn == nil { + internal.Logger.Printf(ctx, "hitless: no connection in handler context for MOVING notification") return ErrInvalidNotification } @@ -95,6 +102,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus } else if pc, ok := conn.(*pool.Conn); ok { poolConn = pc } else { + internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MOVING notification - %T %#v", conn, handlerCtx) return ErrInvalidNotification } @@ -145,17 +153,20 @@ func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx // MIGRATING notifications indicate that a connection is about to be migrated // Apply relaxed timeouts to the specific connection that received this notification if len(notification) < 2 { + internal.Logger.Printf(ctx, "hitless: invalid MIGRATING notification: %v", notification) return ErrInvalidNotification } // Get the connection from handler context and type assert to connectionAdapter if handlerCtx.Conn == nil { + 
internal.Logger.Printf(ctx, "hitless: no connection in handler context for MIGRATING notification") return ErrInvalidNotification } // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) if !ok { + internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MIGRATING notification") return ErrInvalidNotification } @@ -169,17 +180,20 @@ func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx p // MIGRATED notifications indicate that a connection migration has completed // Restore normal timeouts for the specific connection that received this notification if len(notification) < 2 { + internal.Logger.Printf(ctx, "hitless: invalid MIGRATED notification: %v", notification) return ErrInvalidNotification } // Get the connection from handler context and type assert to connectionAdapter if handlerCtx.Conn == nil { + internal.Logger.Printf(ctx, "hitless: no connection in handler context for MIGRATED notification") return ErrInvalidNotification } // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) if !ok { + internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MIGRATED notification") return ErrInvalidNotification } @@ -193,17 +207,20 @@ func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCt // FAILING_OVER notifications indicate that a connection is about to failover // Apply relaxed timeouts to the specific connection that received this notification if len(notification) < 2 { + internal.Logger.Printf(ctx, "hitless: invalid FAILING_OVER notification: %v", notification) return ErrInvalidNotification } // Get the connection from handler context and type assert to connectionAdapter if handlerCtx.Conn == nil { + internal.Logger.Printf(ctx, "hitless: no connection in handler context for FAILING_OVER notification") return ErrInvalidNotification } // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) if !ok { + internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for FAILING_OVER notification") return ErrInvalidNotification } @@ -217,17 +234,20 @@ func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx // FAILED_OVER notifications indicate that a connection failover has completed // Restore normal timeouts for the specific connection that received this notification if len(notification) < 2 { + internal.Logger.Printf(ctx, "hitless: invalid FAILED_OVER notification: %v", notification) return ErrInvalidNotification } // Get the connection from handler context and type assert to connectionAdapter if handlerCtx.Conn == nil { + internal.Logger.Printf(ctx, "hitless: no connection in handler context for FAILED_OVER notification") return ErrInvalidNotification } // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) if !ok { + internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for FAILED_OVER notification") return ErrInvalidNotification } From bfca15a9c53d008c49ac82bfe6ee2ab089c455c6 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Fri, 22 Aug 2025 19:48:45 +0300 Subject: [PATCH 19/21] update example and tests, drop 
connectionAdapter --- .github/copilot-instructions.md | 47 ++++++++++++++++++++++++++-- adapters.go | 38 ----------------------- example/pubsub/main.go | 40 +++++++++++++++++++----- hitless/README.md | 55 ++++++++++++++++++++++++++++----- hitless/example_hooks.go | 18 +++++++++-- hitless/hitless_manager.go | 13 ++++---- hitless/hooks.go | 21 ++++++++++--- hitless/notification_handler.go | 16 ++-------- hitless/pool_hook.go | 6 +++- internal/pool/pubsub.go | 5 +++ pubsub.go | 11 ++++++- push/handler_context.go | 1 - redis.go | 3 +- 13 files changed, 187 insertions(+), 87 deletions(-) diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index b25460c6a4..384e0d8220 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -311,9 +311,50 @@ func (opt *Options) init() { ```go type NotificationProcessor interface { - ProcessPushNotification(ctx context.Context, data []byte) error - RegisterHandler(notificationType string, handler NotificationHandler) error - Close() error + RegisterHandler(pushNotificationName string, handler interface{}, protected bool) error + UnregisterHandler(pushNotificationName string) error + GetHandler(pushNotificationName string) interface{} +} + +type NotificationHandler interface { + HandlePushNotification(ctx context.Context, handlerCtx NotificationHandlerContext, notification []interface{}) error +} +``` + +### Notification Hooks + +```go +type NotificationHook interface { + PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) + PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) +} + +// NotificationHandlerContext provides context for notification processing +type NotificationHandlerContext struct { + Client interface{} // Redis client instance + Pool interface{} // Connection pool + Conn interface{} // Specific connection (*pool.Conn) + IsBlocking bool // Whether notification was on blocking connection +} +``` + +### Hook Implementation Pattern + +```go +func (h *CustomHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) { + // Access connection information + if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { + connID := conn.GetID() + // Process with connection context + } + return notification, true // Continue processing +} + +func (h *CustomHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) { + // Handle processing result + if result != nil { + // Log or handle error + } } ``` diff --git a/adapters.go b/adapters.go index 801b86d43f..4146153bf3 100644 --- a/adapters.go +++ b/adapters.go @@ -7,7 +7,6 @@ import ( "time" "github.com/redis/go-redis/v9/internal/interfaces" - "github.com/redis/go-redis/v9/internal/pool" "github.com/redis/go-redis/v9/push" ) @@ -88,43 +87,6 @@ func (oa *optionsAdapter) NewDialer() func(context.Context) (net.Conn, error) { } } -// connectionAdapter adapts a Redis connection to interfaces.ConnectionWithRelaxedTimeout -type connectionAdapter struct { - conn *pool.Conn -} - -// Close closes the connection. -func (ca *connectionAdapter) Close() error { - return ca.conn.Close() -} - -// IsUsable returns true if the connection is safe to use for new commands. 
-func (ca *connectionAdapter) IsUsable() bool { - return ca.conn.IsUsable() -} - -// GetPoolConn returns the underlying pool connection. -func (ca *connectionAdapter) GetPoolConn() *pool.Conn { - return ca.conn -} - -// SetRelaxedTimeout sets relaxed timeouts for this connection during hitless upgrades. -// These timeouts remain active until explicitly cleared. -func (ca *connectionAdapter) SetRelaxedTimeout(readTimeout, writeTimeout time.Duration) { - ca.conn.SetRelaxedTimeout(readTimeout, writeTimeout) -} - -// SetRelaxedTimeoutWithDeadline sets relaxed timeouts with an expiration deadline. -// After the deadline, timeouts automatically revert to normal values. -func (ca *connectionAdapter) SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout time.Duration, deadline time.Time) { - ca.conn.SetRelaxedTimeoutWithDeadline(readTimeout, writeTimeout, deadline) -} - -// ClearRelaxedTimeout clears relaxed timeouts for this connection. -func (ca *connectionAdapter) ClearRelaxedTimeout() { - ca.conn.ClearRelaxedTimeout() -} - // pushProcessorAdapter adapts a push.NotificationProcessor to implement interfaces.NotificationProcessor. type pushProcessorAdapter struct { processor push.NotificationProcessor diff --git a/example/pubsub/main.go b/example/pubsub/main.go index ddc0604d0e..c733055486 100644 --- a/example/pubsub/main.go +++ b/example/pubsub/main.go @@ -5,6 +5,7 @@ import ( "fmt" "log" "sync" + "sync/atomic" "time" "github.com/redis/go-redis/v9" @@ -12,11 +13,15 @@ import ( ) var ctx = context.Background() +var cntErrors atomic.Int64 +var cntSuccess atomic.Int64 +var startTime = time.Now() // This example is not supposed to be run as is. It is just a test to see how pubsub behaves in relation to pool management. // It was used to find regressions in pool management in hitless mode. // Please don't use it as a reference for how to use pubsub. 
func main() { + startTime = time.Now() wg := &sync.WaitGroup{} rdb := redis.NewClient(&redis.Options{ Addr: ":6379", @@ -25,6 +30,12 @@ func main() { }, }) _ = rdb.FlushDB(ctx).Err() + hitlessManager := rdb.GetHitlessManager() + if hitlessManager == nil { + panic("hitless manager is nil") + } + loggingHook := hitless.NewLoggingHook(3) + hitlessManager.AddNotificationHook(loggingHook) go func() { for { @@ -62,7 +73,8 @@ func main() { subCtx, cancelSubCtx = context.WithCancel(ctx) for i := 0; i < 10; i++ { if err := rdb.Incr(ctx, "publishers").Err(); err != nil { - panic(err) + fmt.Println("incr error:", err) + cntErrors.Add(1) } wg.Add(1) go floodThePool(pubCtx, rdb, wg) @@ -70,12 +82,14 @@ func main() { for i := 0; i < 500; i++ { if err := rdb.Incr(ctx, "subscribers").Err(); err != nil { - panic(err) + fmt.Println("incr error:", err) + cntErrors.Add(1) } + wg.Add(1) go subscribe(subCtx, rdb, "test2", i, wg) } - time.Sleep(5 * time.Second) + time.Sleep(120 * time.Second) fmt.Println("canceling publishers") cancelPublishers() time.Sleep(10 * time.Second) @@ -95,6 +109,9 @@ func main() { fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt) time.Sleep(2 * time.Second) + fmt.Println("errors:", cntErrors.Load()) + fmt.Println("success:", cntSuccess.Load()) + fmt.Println("time:", time.Since(startTime)) } func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) { @@ -107,14 +124,18 @@ func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) { } err := rdb.Publish(ctx, "test2", "hello").Err() if err != nil { - // noop - //log.Println("publish error:", err) + if err.Error() != "context canceled" { + log.Println("publish error:", err) + cntErrors.Add(1) + } } err = rdb.Incr(ctx, "published").Err() if err != nil { - // noop - //log.Println("incr error:", err) + if err.Error() != "context canceled" { + log.Println("incr error:", err) + cntErrors.Add(1) + } } time.Sleep(10 * time.Nanosecond) } @@ -137,7 +158,10 @@ func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberI case msg := <-recChan: err := rdb.Incr(ctx, "received").Err() if err != nil { - log.Println("incr error:", err) + if err.Error() != "context canceled" { + log.Printf("%s\n", err.Error()) + cntErrors.Add(1) + } } _ = msg // Use the message to avoid unused variable warning } diff --git a/hitless/README.md b/hitless/README.md index b82b33a3d2..7d117a2aaa 100644 --- a/hitless/README.md +++ b/hitless/README.md @@ -49,23 +49,62 @@ Config: &hitless.Config{ - **Auto-calculated**: `10 × MaxWorkers`, capped by pool size - **Always capped**: Queue size never exceeds pool size -## Metrics Hook Example +## Notification Hooks -A metrics collection hook is available in `example_hooks.go` that demonstrates how to monitor hitless upgrade operations: +Notification hooks allow you to monitor and customize hitless upgrade operations. 
The `NotificationHook` interface provides pre and post processing hooks: + +```go +type NotificationHook interface { + PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) + PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) +} +``` + +### Example: Metrics Collection Hook + +A metrics collection hook is available in `example_hooks.go`: ```go import "github.com/redis/go-redis/v9/hitless" metricsHook := hitless.NewMetricsHook() -// Use with your monitoring system +manager.AddNotificationHook(metricsHook) + +// Access metrics +metrics := metricsHook.GetMetrics() ``` -The metrics hook tracks: +### Example: Custom Logging Hook + +```go +type CustomHook struct{} + +func (h *CustomHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) { + // Log notification with connection details + if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { + log.Printf("Processing %s on connection %d", notificationType, conn.GetID()) + } + return notification, true // Continue processing +} + +func (h *CustomHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) { + if result != nil { + log.Printf("Failed to process %s: %v", notificationType, result) + } +} +``` + +The notification context provides access to: +- **Client**: The Redis client instance +- **Pool**: The connection pool +- **Conn**: The specific connection that received the notification +- **IsBlocking**: Whether the notification was received on a blocking connection + +Hooks can track: - Handoff success/failure rates -- Handoff duration -- Queue depth -- Worker utilization -- Connection lifecycle events +- Processing duration +- Connection-specific metrics +- Custom business logic ## Requirements diff --git a/hitless/example_hooks.go b/hitless/example_hooks.go index f03ea3ed59..0b65f1f5b5 100644 --- a/hitless/example_hooks.go +++ b/hitless/example_hooks.go @@ -3,6 +3,10 @@ package hitless import ( "context" "time" + + "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" ) // contextKey is a custom type for context keys to avoid collisions @@ -29,9 +33,14 @@ func NewMetricsHook() *MetricsHook { } // PreHook records the start time for processing metrics. -func (mh *MetricsHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) { +func (mh *MetricsHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) { mh.NotificationCounts[notificationType]++ + // Log connection information if available + if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { + internal.Logger.Printf(ctx, "hitless: metrics hook processing %s notification on connection %d", notificationType, conn.GetID()) + } + // Store start time in context for duration calculation startTime := time.Now() _ = context.WithValue(ctx, startTimeKey, startTime) // Context not used further @@ -40,7 +49,7 @@ func (mh *MetricsHook) PreHook(ctx context.Context, notificationType string, not } // PostHook records processing completion and any errors. 
-func (mh *MetricsHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) { +func (mh *MetricsHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) { // Calculate processing duration if startTime, ok := ctx.Value(startTimeKey).(time.Time); ok { duration := time.Since(startTime) @@ -50,6 +59,11 @@ func (mh *MetricsHook) PostHook(ctx context.Context, notificationType string, no // Record errors if result != nil { mh.ErrorCounts[notificationType]++ + + // Log error details with connection information + if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { + internal.Logger.Printf(ctx, "hitless: metrics hook recorded error for %s notification on connection %d: %v", notificationType, conn.GetID(), result) + } } } diff --git a/hitless/hitless_manager.go b/hitless/hitless_manager.go index 309ac6430c..364ba5a48e 100644 --- a/hitless/hitless_manager.go +++ b/hitless/hitless_manager.go @@ -11,6 +11,7 @@ import ( "github.com/redis/go-redis/v9/internal" "github.com/redis/go-redis/v9/internal/interfaces" "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" ) // Push notification type constants for hitless upgrades @@ -35,8 +36,8 @@ var hitlessNotificationTypes = []string{ // PreHook can modify the notification and return false to skip processing // PostHook is called after successful processing type NotificationHook interface { - PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) - PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) + PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) + PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) } // MovingOperationKey provides a unique key for tracking MOVING operations @@ -252,14 +253,14 @@ func (hm *HitlessManager) GetState() State { } // processPreHooks calls all pre-hooks and returns the modified notification and whether to continue processing. -func (hm *HitlessManager) processPreHooks(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) { +func (hm *HitlessManager) processPreHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) { hm.hooksMu.RLock() defer hm.hooksMu.RUnlock() currentNotification := notification for _, hook := range hm.hooks { - modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationType, currentNotification) + modifiedNotification, shouldContinue := hook.PreHook(ctx, notificationCtx, notificationType, currentNotification) if !shouldContinue { return modifiedNotification, false } @@ -270,12 +271,12 @@ func (hm *HitlessManager) processPreHooks(ctx context.Context, notificationType } // processPostHooks calls all post-hooks with the processing result. 
-func (hm *HitlessManager) processPostHooks(ctx context.Context, notificationType string, notification []interface{}, result error) { +func (hm *HitlessManager) processPostHooks(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) { hm.hooksMu.RLock() defer hm.hooksMu.RUnlock() for _, hook := range hm.hooks { - hook.PostHook(ctx, notificationType, notification, result) + hook.PostHook(ctx, notificationCtx, notificationType, notification, result) } } diff --git a/hitless/hooks.go b/hitless/hooks.go index 7d1b646360..d0093bca75 100644 --- a/hitless/hooks.go +++ b/hitless/hooks.go @@ -4,6 +4,8 @@ import ( "context" "github.com/redis/go-redis/v9/internal" + "github.com/redis/go-redis/v9/internal/pool" + "github.com/redis/go-redis/v9/push" ) // LoggingHook is an example hook implementation that logs all notifications. @@ -12,19 +14,28 @@ type LoggingHook struct { } // PreHook logs the notification before processing and allows modification. -func (lh *LoggingHook) PreHook(ctx context.Context, notificationType string, notification []interface{}) ([]interface{}, bool) { +func (lh *LoggingHook) PreHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) { if lh.LogLevel >= 2 { // Info level - internal.Logger.Printf(ctx, "hitless: processing %s notification: %v", notificationType, notification) + // Log the notification type and content + connID := uint64(0) + if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { + connID = conn.GetID() + } + internal.Logger.Printf(ctx, "hitless: conn[%d] processing %s notification: %v", connID, notificationType, notification) } return notification, true // Continue processing with unmodified notification } // PostHook logs the result after processing. 
-func (lh *LoggingHook) PostHook(ctx context.Context, notificationType string, notification []interface{}, result error) { +func (lh *LoggingHook) PostHook(ctx context.Context, notificationCtx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) { + connID := uint64(0) + if conn, ok := notificationCtx.Conn.(*pool.Conn); ok { + connID = conn.GetID() + } if result != nil && lh.LogLevel >= 1 { // Warning level - internal.Logger.Printf(ctx, "hitless: %s notification processing failed: %v - %v", notificationType, result, notification) + internal.Logger.Printf(ctx, "hitless: conn[%d] %s notification processing failed: %v - %v", connID, notificationType, result, notification) } else if lh.LogLevel >= 3 { // Debug level - internal.Logger.Printf(ctx, "hitless: %s notification processed successfully", notificationType) + internal.Logger.Printf(ctx, "hitless: conn[%d] %s notification processed successfully", connID, notificationType) } } diff --git a/hitless/notification_handler.go b/hitless/notification_handler.go index 7222644261..246e887ac0 100644 --- a/hitless/notification_handler.go +++ b/hitless/notification_handler.go @@ -30,7 +30,7 @@ func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, hand } // Process pre-hooks - they can modify the notification or skip processing - modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, notificationType, notification) + modifiedNotification, shouldContinue := snh.manager.processPreHooks(ctx, handlerCtx, notificationType, notification) if !shouldContinue { return nil // Hooks decided to skip processing } @@ -53,7 +53,7 @@ func (snh *NotificationHandler) HandlePushNotification(ctx context.Context, hand } // Process post-hooks with the result - snh.manager.processPostHooks(ctx, notificationType, modifiedNotification, err) + snh.manager.processPostHooks(ctx, handlerCtx, notificationType, modifiedNotification, err) return err } @@ -97,9 +97,7 @@ func (snh *NotificationHandler) handleMoving(ctx context.Context, handlerCtx pus // Type assert to get the underlying pool connection var poolConn *pool.Conn - if connAdapter, ok := conn.(interface{ GetPoolConn() *pool.Conn }); ok { - poolConn = connAdapter.GetPoolConn() - } else if pc, ok := conn.(*pool.Conn); ok { + if pc, ok := conn.(*pool.Conn); ok { poolConn = pc } else { internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MOVING notification - %T %#v", conn, handlerCtx) @@ -157,13 +155,11 @@ func (snh *NotificationHandler) handleMigrating(ctx context.Context, handlerCtx return ErrInvalidNotification } - // Get the connection from handler context and type assert to connectionAdapter if handlerCtx.Conn == nil { internal.Logger.Printf(ctx, "hitless: no connection in handler context for MIGRATING notification") return ErrInvalidNotification } - // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) if !ok { internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MIGRATING notification") @@ -184,13 +180,11 @@ func (snh *NotificationHandler) handleMigrated(ctx context.Context, handlerCtx p return ErrInvalidNotification } - // Get the connection from handler context and type assert to connectionAdapter if handlerCtx.Conn == nil { internal.Logger.Printf(ctx, "hitless: no connection in handler context for MIGRATED notification") return ErrInvalidNotification } - // Type 
assert to connectionAdapter which implements ConnectionWithRelaxedTimeout connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) if !ok { internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for MIGRATED notification") @@ -211,13 +205,11 @@ func (snh *NotificationHandler) handleFailingOver(ctx context.Context, handlerCt return ErrInvalidNotification } - // Get the connection from handler context and type assert to connectionAdapter if handlerCtx.Conn == nil { internal.Logger.Printf(ctx, "hitless: no connection in handler context for FAILING_OVER notification") return ErrInvalidNotification } - // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) if !ok { internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for FAILING_OVER notification") @@ -238,13 +230,11 @@ func (snh *NotificationHandler) handleFailedOver(ctx context.Context, handlerCtx return ErrInvalidNotification } - // Get the connection from handler context and type assert to connectionAdapter if handlerCtx.Conn == nil { internal.Logger.Printf(ctx, "hitless: no connection in handler context for FAILED_OVER notification") return ErrInvalidNotification } - // Type assert to connectionAdapter which implements ConnectionWithRelaxedTimeout connAdapter, ok := handlerCtx.Conn.(interfaces.ConnectionWithRelaxedTimeout) if !ok { internal.Logger.Printf(ctx, "hitless: invalid connection type in handler context for FAILED_OVER notification") diff --git a/hitless/pool_hook.go b/hitless/pool_hook.go index e1b7d60361..9a1d7d35bf 100644 --- a/hitless/pool_hook.go +++ b/hitless/pool_hook.go @@ -117,7 +117,7 @@ func (ph *PoolHook) IsHandoffPending(conn *pool.Conn) bool { } // OnGet is called when a connection is retrieved from the pool -func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, isNewConn bool) error { +func (ph *PoolHook) OnGet(ctx context.Context, conn *pool.Conn, _ bool) error { // NOTE: There are two conditions to make sure we don't return a connection that should be handed off or is // in a handoff state at the moment. 
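
Taken together, the hook plumbing in this patch — the `NotificationHook` interface now carrying a `push.NotificationHandlerContext`, `HitlessManager.AddNotificationHook`, and the connection exposed on the context's `Conn` field — can be exercised from application code roughly as sketched below. This is only a sketch: `traceHook` and the `connIDer` interface are illustrative names that are not part of the patch, and the client is assumed to have a hitless manager available.

```go
package main

import (
	"context"
	"log"

	"github.com/redis/go-redis/v9"
	"github.com/redis/go-redis/v9/push"
)

// connIDer is the minimal surface we need from the concrete connection in the
// handler context; the library's internal *pool.Conn satisfies it via GetID().
type connIDer interface{ GetID() uint64 }

// traceHook is a hypothetical NotificationHook that logs every upgrade
// notification together with the connection it arrived on.
type traceHook struct{}

func (traceHook) PreHook(ctx context.Context, nctx push.NotificationHandlerContext, notificationType string, notification []interface{}) ([]interface{}, bool) {
	if c, ok := nctx.Conn.(connIDer); ok {
		log.Printf("conn[%d] %s notification: %v", c.GetID(), notificationType, notification)
	}
	return notification, true // continue processing with the notification unchanged
}

func (traceHook) PostHook(ctx context.Context, nctx push.NotificationHandlerContext, notificationType string, notification []interface{}, result error) {
	if result != nil {
		log.Printf("%s notification failed: %v", notificationType, result)
	}
}

func main() {
	rdb := redis.NewClient(&redis.Options{Addr: ":6379", Protocol: 3})
	defer rdb.Close()

	// GetHitlessManager returns nil when hitless upgrades are not active.
	if hm := rdb.GetHitlessManager(); hm != nil {
		hm.AddNotificationHook(traceHook{})
	}
}
```

Because Go interfaces are satisfied structurally, a hook written outside the module does not need to import the internal pool package to read the connection ID.
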
@@ -234,6 +234,7 @@ func (ph *PoolHook) onDemandWorker() { func (ph *PoolHook) processHandoffRequest(request HandoffRequest) { // Remove from pending map defer ph.pending.Delete(request.Conn.GetID()) + internal.Logger.Printf(context.Background(), "hitless: conn[%d] Processing handoff request start", request.Conn.GetID()) // Create a context with handoff timeout from config handoffTimeout := 30 * time.Second // Default fallback @@ -366,6 +367,7 @@ func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn * } retries := conn.IncrementAndGetHandoffRetries(1) + internal.Logger.Printf(ctx, "hitless: conn[%d] Retry %d: Performing handoff to %s(was %s)", conn.GetID(), retries, newEndpoint, conn.RemoteAddr().String()) maxRetries := 3 // Default fallback if ph.config != nil { maxRetries = ph.config.MaxHandoffRetries @@ -387,6 +389,7 @@ func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn * // Create new connection to the new endpoint newNetConn, err := endpointDialer(ctx) if err != nil { + internal.Logger.Printf(ctx, "hitless: conn[%d] Failed to dial new endpoint %s: %v", conn.GetID(), newEndpoint, err) // hitless: will retry // Maybe a network error - retry after a delay return true, err @@ -409,6 +412,7 @@ func (ph *PoolHook) performConnectionHandoffWithPool(ctx context.Context, conn * }() conn.ClearHandoffState() + internal.Logger.Printf(ctx, "hitless: conn[%d] Handoff to %s successful", conn.GetID(), newEndpoint) // Apply relaxed timeout to the new connection for the configured post-handoff duration // This gives the new connection more time to handle operations during cluster transition diff --git a/internal/pool/pubsub.go b/internal/pool/pubsub.go index a06abcd6b8..c616300f8b 100644 --- a/internal/pool/pubsub.go +++ b/internal/pool/pubsub.go @@ -5,6 +5,8 @@ import ( "net" "sync" "sync/atomic" + + "github.com/redis/go-redis/v9/internal" ) type PubSubStats struct { @@ -52,6 +54,9 @@ func (p *PubSubPool) TrackConn(cn *Conn) { } func (p *PubSubPool) UntrackConn(cn *Conn) { + if !cn.IsUsable() || cn.ShouldHandoff() { + internal.Logger.Printf(context.Background(), "pubsub: untracking connection %d [usable, handoff] = [%v, %v]", cn.GetID(), cn.IsUsable(), cn.ShouldHandoff()) + } atomic.AddUint32(&p.stats.Active, ^uint32(0)) atomic.AddUint32(&p.stats.Untracked, 1) p.activeConns.Delete(cn.GetID()) diff --git a/pubsub.go b/pubsub.go index 6db13a9a61..506ce1e6ed 100644 --- a/pubsub.go +++ b/pubsub.go @@ -170,10 +170,16 @@ func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allo } if !cn.IsUsable() || cn.ShouldHandoff() { + if cn.ShouldHandoff() { + internal.Logger.Printf(ctx, "pubsub: connection[%d] is marked for handoff, reconnecting", cn.GetID()) + } else { + internal.Logger.Printf(ctx, "pubsub: connection[%d] is not usable, reconnecting", cn.GetID()) + } c.reconnect(ctx, fmt.Errorf("pubsub: connection is not usable")) } if isBadConn(err, allowTimeout, c.opt.Addr) { + internal.Logger.Printf(ctx, "pubsub: releasing connection[%d]: %v", cn.GetID(), err) c.reconnect(ctx, err) } } @@ -187,7 +193,10 @@ func (c *PubSub) reconnect(ctx context.Context, reason error) { } if newEndpoint != "" { + // Update the address in the options + oldAddr := c.cn.RemoteAddr().String() c.opt.Addr = newEndpoint + internal.Logger.Printf(ctx, "pubsub: reconnecting to new endpoint %s (was %s)", newEndpoint, oldAddr) } } _ = c.closeTheCn(reason) @@ -199,7 +208,7 @@ func (c *PubSub) closeTheCn(reason error) error { return nil } if !c.closed { - 
internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason) + internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection[%d]: %s, %v", c.cn.GetID(), reason, c.cn.RemoteAddr()) } err := c.closeConn(c.cn) c.cn = nil diff --git a/push/handler_context.go b/push/handler_context.go index f89f87fa1b..c39e186b0d 100644 --- a/push/handler_context.go +++ b/push/handler_context.go @@ -37,7 +37,6 @@ type NotificationHandlerContext struct { // circular dependencies. The developer is responsible for type assertion. // It can be one of the following types: // - *pool.Conn - // - *connectionAdapter (for hitless upgrades) Conn interface{} // IsBlocking indicates if the notification was received on a blocking connection. diff --git a/redis.go b/redis.go index 13c7feca69..d6e7b4d320 100644 --- a/redis.go +++ b/redis.go @@ -1077,6 +1077,7 @@ func (c *Client) pubSub() *PubSub { // will return nil if already initialized err = c.initConn(ctx, cn) if err != nil { + internal.Logger.Printf(ctx, "pubsub: conn[%d] to ADDR %s [usable, handoff] = [%v, %v] after initConn returned %v", cn.GetID(), addr, cn.IsUsable(), cn.ShouldHandoff(), err) _ = cn.Close() return nil, err } @@ -1277,7 +1278,7 @@ func (c *baseClient) pushNotificationHandlerContext(cn *pool.Conn) push.Notifica return push.NotificationHandlerContext{ Client: c, ConnPool: c.connPool, - Conn: &connectionAdapter{conn: cn}, // Wrap in adapter for easier interface access + Conn: cn, // Wrap in adapter for easier interface access } } From 36157798ca494416fc57a53bc97be7eb7143cab8 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 25 Aug 2025 10:28:46 +0300 Subject: [PATCH 20/21] fix example tests, pop more connections --- example_instrumentation_test.go | 12 ++++++------ internal/pool/pool.go | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/example_instrumentation_test.go b/example_instrumentation_test.go index 73248e4c53..fa776fcf3b 100644 --- a/example_instrumentation_test.go +++ b/example_instrumentation_test.go @@ -57,8 +57,8 @@ func Example_instrumentation() { // finished dialing tcp :6379 // starting processing: <[hello 3]> // finished processing: <[hello 3]> - // starting processing: <[client maint_notifications on moving-endpoint-type external-ip]> - // finished processing: <[client maint_notifications on moving-endpoint-type external-ip]> + // starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]> + // finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]> // finished processing: <[ping]> } @@ -80,8 +80,8 @@ func ExamplePipeline_instrumentation() { // finished dialing tcp :6379 // starting processing: <[hello 3]> // finished processing: <[hello 3]> - // starting processing: <[client maint_notifications on moving-endpoint-type external-ip]> - // finished processing: <[client maint_notifications on moving-endpoint-type external-ip]> + // starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]> + // finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]> // pipeline finished processing: [[ping] [ping]] } @@ -103,8 +103,8 @@ func ExampleClient_Watch_instrumentation() { // finished dialing tcp :6379 // starting processing: <[hello 3]> // finished processing: <[hello 3]> - // starting processing: <[client maint_notifications on moving-endpoint-type external-ip]> - // finished processing: <[client maint_notifications on moving-endpoint-type 
external-ip]> + // starting processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]> + // finished processing: <[client maint_notifications on moving-endpoint-type internal-fqdn]> // finished processing: <[watch foo]> // starting processing: <[ping]> // finished processing: <[ping]> diff --git a/internal/pool/pool.go b/internal/pool/pool.go index 7005222075..359b6442dc 100644 --- a/internal/pool/pool.go +++ b/internal/pool/pool.go @@ -26,8 +26,8 @@ var ( // popAttempts is the maximum number of attempts to find a usable connection // when popping from the idle connection pool. This handles cases where connections // are temporarily marked as unusable (e.g., during hitless upgrades or network issues). - // Value of 10 provides sufficient resilience without excessive overhead. - popAttempts = 10 + // Value of 20 provides sufficient resilience without excessive overhead. + popAttempts = 20 // getAttempts is the maximum number of attempts to get a connection that passes // hook validation (e.g., hitless upgrade hooks). This protects against race conditions From 5f608caf663d80f328ab451cef194d5d72e11e41 Mon Sep 17 00:00:00 2001 From: Nedyalko Dyakov Date: Mon, 25 Aug 2025 11:26:49 +0300 Subject: [PATCH 21/21] fix push processor exposed in opts --- redis.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/redis.go b/redis.go index d6e7b4d320..b896196f25 100644 --- a/redis.go +++ b/redis.go @@ -906,6 +906,8 @@ func NewClient(opt *Options) *Client { if opt == nil { panic("redis: NewClient nil options") } + // clone to not share options with the caller + opt = opt.clone() opt.init() // Push notifications are always enabled for RESP3 (cannot be disabled) @@ -920,8 +922,8 @@ func NewClient(opt *Options) *Client { // Initialize push notification processor using shared helper // Use void processor for RESP2 connections (push notifications not available) c.pushProcessor = initializePushProcessor(opt) - // Update options with the initialized push processor - opt.PushNotificationProcessor = c.pushProcessor + // set opt push processor for child clients + c.opt.PushNotificationProcessor = c.pushProcessor // Create connection pools var err error
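
With `NewClient` cloning its `Options`, a single `Options` value can now be reused across client constructions without the first construction writing the initialized push processor back into the caller's struct. A minimal sketch of the behavior this patch is after (Addr and Protocol are placeholder values, and no commands are issued):

```go
package main

import (
	"fmt"

	"github.com/redis/go-redis/v9"
)

func main() {
	opt := &redis.Options{Addr: ":6379", Protocol: 3}

	// Each client works on its own clone of opt, so constructing one client
	// no longer leaks the initialized push processor into the caller's Options.
	rdb1 := redis.NewClient(opt)
	rdb2 := redis.NewClient(opt)
	defer rdb1.Close()
	defer rdb2.Close()

	// The caller's struct is left exactly as it was passed in.
	fmt.Println(opt.PushNotificationProcessor == nil) // true after this patch
}
```

Child clients still receive the processor through the cloned options, as the "set opt push processor for child clients" comment in the hunk above notes.
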