Skip to content

Commit a39a23a

Browse files
committed
pubsub pool
1 parent e9f32f0 commit a39a23a

15 files changed

+765
-191
lines changed

example/pubsub/main.go

Lines changed: 34 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,15 @@ import (
1111
)
1212

1313
var ctx = context.Background()
14-
var consStopped = false
1514

15+
// This example is not supposed to be run as is. It is just a test to see how pubsub behaves in relation to pool management.
16+
// It was used to find regressions in pool management in hitless mode.
17+
// Please don't use it as a reference for how to use pubsub.
1618
func main() {
1719
wg := &sync.WaitGroup{}
1820
rdb := redis.NewClient(&redis.Options{
19-
Addr: ":6379",
21+
Addr: ":6379",
22+
HitlessUpgrades: true,
2023
})
2124
_ = rdb.FlushDB(ctx).Err()
2225

@@ -30,21 +33,22 @@ func main() {
3033
if err != nil {
3134
panic(err)
3235
}
33-
if err := rdb.Set(ctx, "prods", "0", 0).Err(); err != nil {
36+
if err := rdb.Set(ctx, "publishers", "0", 0).Err(); err != nil {
3437
panic(err)
3538
}
36-
if err := rdb.Set(ctx, "cons", "0", 0).Err(); err != nil {
39+
if err := rdb.Set(ctx, "subscribers", "0", 0).Err(); err != nil {
3740
panic(err)
3841
}
39-
if err := rdb.Set(ctx, "cntr", "0", 0).Err(); err != nil {
42+
if err := rdb.Set(ctx, "published", "0", 0).Err(); err != nil {
4043
panic(err)
4144
}
42-
if err := rdb.Set(ctx, "recs", "0", 0).Err(); err != nil {
45+
if err := rdb.Set(ctx, "received", "0", 0).Err(); err != nil {
4346
panic(err)
4447
}
45-
fmt.Println("cntr", rdb.Get(ctx, "cntr").Val())
46-
fmt.Println("recs", rdb.Get(ctx, "recs").Val())
48+
fmt.Println("published", rdb.Get(ctx, "published").Val())
49+
fmt.Println("received", rdb.Get(ctx, "received").Val())
4750
subCtx, cancelSubCtx := context.WithCancel(ctx)
51+
pubCtx, cancelPublishers := context.WithCancel(ctx)
4852
for i := 0; i < 10; i++ {
4953
wg.Add(1)
5054
go subscribe(subCtx, rdb, "test", i, wg)
@@ -54,32 +58,39 @@ func main() {
5458
time.Sleep(time.Second)
5559
subCtx, cancelSubCtx = context.WithCancel(ctx)
5660
for i := 0; i < 10; i++ {
57-
if err := rdb.Incr(ctx, "prods").Err(); err != nil {
61+
if err := rdb.Incr(ctx, "publishers").Err(); err != nil {
5862
panic(err)
5963
}
6064
wg.Add(1)
61-
go floodThePool(subCtx, rdb, wg)
65+
go floodThePool(pubCtx, rdb, wg)
6266
}
6367

6468
for i := 0; i < 500; i++ {
65-
if err := rdb.Incr(ctx, "cons").Err(); err != nil {
69+
if err := rdb.Incr(ctx, "subscribers").Err(); err != nil {
6670
panic(err)
6771
}
6872
wg.Add(1)
6973
go subscribe(subCtx, rdb, "test2", i, wg)
7074
}
75+
time.Sleep(5 * time.Second)
76+
fmt.Println("canceling publishers")
77+
cancelPublishers()
7178
time.Sleep(10 * time.Second)
72-
fmt.Println("canceling")
79+
fmt.Println("canceling subscribers")
7380
cancelSubCtx()
7481
wg.Wait()
75-
cntr, err := rdb.Get(ctx, "cntr").Result()
76-
recs, err := rdb.Get(ctx, "recs").Result()
77-
prods, err := rdb.Get(ctx, "prods").Result()
78-
cons, err := rdb.Get(ctx, "cons").Result()
79-
fmt.Printf("cntr: %s\n", cntr)
80-
fmt.Printf("recs: %s\n", recs)
81-
fmt.Printf("prods: %s\n", prods)
82-
fmt.Printf("cons: %s\n", cons)
82+
published, err := rdb.Get(ctx, "published").Result()
83+
received, err := rdb.Get(ctx, "received").Result()
84+
publishers, err := rdb.Get(ctx, "publishers").Result()
85+
subscribers, err := rdb.Get(ctx, "subscribers").Result()
86+
fmt.Printf("publishers: %s\n", publishers)
87+
fmt.Printf("published: %s\n", published)
88+
fmt.Printf("subscribers: %s\n", subscribers)
89+
fmt.Printf("received: %s\n", received)
90+
publishedInt, err := rdb.Get(ctx, "published").Int()
91+
subscribersInt, err := rdb.Get(ctx, "subscribers").Int()
92+
fmt.Printf("if drained = published*subscribers: %d\n", publishedInt*subscribersInt)
93+
8394
time.Sleep(2 * time.Second)
8495
}
8596

@@ -88,8 +99,6 @@ func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
8899
for {
89100
select {
90101
case <-ctx.Done():
91-
fmt.Println("floodThePool stopping")
92-
consStopped = true
93102
return
94103
default:
95104
}
@@ -99,7 +108,7 @@ func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
99108
//log.Println("publish error:", err)
100109
}
101110

102-
err = rdb.Incr(ctx, "cntr").Err()
111+
err = rdb.Incr(ctx, "published").Err()
103112
if err != nil {
104113
// noop
105114
//log.Println("incr error:", err)
@@ -110,36 +119,24 @@ func floodThePool(ctx context.Context, rdb *redis.Client, wg *sync.WaitGroup) {
110119

111120
func subscribe(ctx context.Context, rdb *redis.Client, topic string, subscriberId int, wg *sync.WaitGroup) {
112121
defer wg.Done()
113-
defer fmt.Printf("subscriber %d stopping\n", subscriberId)
114122
rec := rdb.Subscribe(ctx, topic)
115123
recChan := rec.Channel()
116124
for {
117125
select {
118126
case <-ctx.Done():
119127
rec.Close()
120-
if subscriberId == 199 {
121-
fmt.Printf("subscriber %d done\n", subscriberId)
122-
}
123128
return
124129
default:
125130
select {
126131
case <-ctx.Done():
127132
rec.Close()
128-
if subscriberId == 199 {
129-
fmt.Printf("subscriber %d done\n", subscriberId)
130-
}
131133
return
132134
case msg := <-recChan:
133-
err := rdb.Incr(ctx, "recs").Err()
135+
err := rdb.Incr(ctx, "received").Err()
134136
if err != nil {
135137
log.Println("incr error:", err)
136138
}
137-
if consStopped {
138-
fmt.Printf("subscriber %d received %s\n", subscriberId, msg.Payload)
139-
}
140-
if subscriberId == 199 {
141-
fmt.Printf("subscriber %d received %s\n", subscriberId, msg.Payload)
142-
}
139+
_ = msg // Use the message to avoid unused variable warning
143140
}
144141
}
145142
}

hitless/config.go

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -314,13 +314,14 @@ func isPrivateIP(addr string) bool {
314314
// Simplified check for common private IP ranges
315315
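// Intended ranges: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16. NOTE: the prefix match below only recognizes 172.16.x.x, not the rest of 172.16.0.0/12 (172.17.x.x through 172.31.x.x).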
316316
// This is a simplified string-prefix implementation; a full implementation would parse the address (e.g. with net.ParseIP) and test it against each CIDR range
317-
if len(addr) >= 3 {
318-
if addr[:3] == "10." || addr[:8] == "192.168." {
319-
return true
320-
}
321-
if len(addr) >= 7 && addr[:7] == "172.16." {
322-
return true
323-
}
317+
if len(addr) >= 3 && addr[:3] == "10." {
318+
return true
319+
}
320+
if len(addr) >= 8 && addr[:8] == "192.168." {
321+
return true
322+
}
323+
if len(addr) >= 7 && addr[:7] == "172.16." {
324+
return true
324325
}
325326
return false
326327
}

hitless/redis_connection_processor_test.go

Lines changed: 108 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,6 @@ func (mp *mockPool) Get(ctx context.Context) (*pool.Conn, error) {
6060
return nil, errors.New("not implemented")
6161
}
6262

63-
func (mp *mockPool) GetPubSub(ctx context.Context) (*pool.Conn, error) {
64-
return nil, errors.New("not implemented")
65-
}
66-
6763
func (mp *mockPool) Put(ctx context.Context, conn *pool.Conn) {
6864
// Not implemented for testing
6965
}
@@ -107,18 +103,27 @@ func TestRedisConnectionProcessor(t *testing.T) {
107103
}
108104

109105
t.Run("SuccessfulEventDrivenHandoff", func(t *testing.T) {
110-
processor := NewRedisConnectionProcessor(3, baseDialer, nil, nil)
106+
config := &Config{
107+
MinWorkers: 1,
108+
MaxWorkers: 2,
109+
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
110+
LogLevel: 2,
111+
}
112+
processor := NewRedisConnectionProcessor(3, baseDialer, config, nil)
111113
defer processor.Shutdown(context.Background())
112114

113115
conn := createMockPoolConnection()
114116
if err := conn.MarkForHandoff("new-endpoint:6379", 12345); err != nil {
115117
t.Fatalf("Failed to mark connection for handoff: %v", err)
116118
}
117119

118-
// Set a mock initialization function
119-
initConnCalled := false
120+
// Set a mock initialization function with synchronization
121+
initConnCalled := make(chan bool, 1)
120122
initConnFunc := func(ctx context.Context, cn *pool.Conn) error {
121-
initConnCalled = true
123+
select {
124+
case initConnCalled <- true:
125+
default:
126+
}
122127
return nil
123128
}
124129
conn.SetInitConnFunc(initConnFunc)
@@ -142,22 +147,44 @@ func TestRedisConnectionProcessor(t *testing.T) {
142147
t.Error("Connection should be in pending handoffs map")
143148
}
144149

145-
// Wait for handoff to complete
146-
time.Sleep(100 * time.Millisecond)
150+
// Wait for initialization to be called (indicates handoff started)
151+
select {
152+
case <-initConnCalled:
153+
// Good, initialization was called
154+
case <-time.After(1 * time.Second):
155+
t.Fatal("Timeout waiting for initialization function to be called")
156+
}
157+
158+
// Wait for handoff to complete with proper timeout and polling
159+
timeout := time.After(2 * time.Second)
160+
ticker := time.NewTicker(10 * time.Millisecond)
161+
defer ticker.Stop()
162+
163+
handoffCompleted := false
164+
for !handoffCompleted {
165+
select {
166+
case <-timeout:
167+
t.Fatal("Timeout waiting for handoff to complete")
168+
case <-ticker.C:
169+
if _, pending := processor.pending.Load(conn); !pending {
170+
handoffCompleted = true
171+
}
172+
}
173+
}
147174

148175
// Verify handoff completed (removed from pending map)
149176
if _, pending := processor.pending.Load(conn); pending {
150177
t.Error("Connection should be removed from pending map after handoff")
151178
}
152179

153-
// Verify handoff state was cleared
154-
if conn.ShouldHandoff() {
155-
t.Error("Connection should not be marked for handoff after successful handoff")
180+
// Verify connection is usable again
181+
if !conn.IsUsable() {
182+
t.Error("Connection should be usable after successful handoff")
156183
}
157184

158-
// Verify initialization was called
159-
if !initConnCalled {
160-
t.Error("InitConn should have been called")
185+
// Verify handoff state is cleared
186+
if conn.ShouldHandoff() {
187+
t.Error("Connection should not be marked for handoff after completion")
161188
}
162189
})
163190

@@ -214,7 +241,13 @@ func TestRedisConnectionProcessor(t *testing.T) {
214241
return nil, errors.New("dial failed")
215242
}
216243

217-
processor := NewRedisConnectionProcessor(3, failingDialer, nil, nil)
244+
config := &Config{
245+
MinWorkers: 1,
246+
MaxWorkers: 2,
247+
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
248+
LogLevel: 2,
249+
}
250+
processor := NewRedisConnectionProcessor(3, failingDialer, config, nil)
218251
defer processor.Shutdown(context.Background())
219252

220253
conn := createMockPoolConnection()
@@ -236,8 +269,22 @@ func TestRedisConnectionProcessor(t *testing.T) {
236269
t.Error("Connection should not be removed when queuing handoff")
237270
}
238271

239-
// Wait for handoff to complete and fail
240-
time.Sleep(100 * time.Millisecond)
272+
// Wait for handoff to complete and fail with proper timeout and polling
273+
timeout := time.After(2 * time.Second)
274+
ticker := time.NewTicker(10 * time.Millisecond)
275+
defer ticker.Stop()
276+
277+
handoffCompleted := false
278+
for !handoffCompleted {
279+
select {
280+
case <-timeout:
281+
t.Fatal("Timeout waiting for failed handoff to complete")
282+
case <-ticker.C:
283+
if _, pending := processor.pending.Load(conn); !pending {
284+
handoffCompleted = true
285+
}
286+
}
287+
}
241288

242289
// Connection should be removed from pending map after failed handoff
243290
if _, pending := processor.pending.Load(conn); pending {
@@ -285,7 +332,13 @@ func TestRedisConnectionProcessor(t *testing.T) {
285332
})
286333

287334
t.Run("ProcessConnectionOnGetWithPendingHandoff", func(t *testing.T) {
288-
processor := NewRedisConnectionProcessor(3, baseDialer, nil, nil)
335+
config := &Config{
336+
MinWorkers: 1,
337+
MaxWorkers: 2,
338+
HandoffQueueSize: 10, // Explicit queue size to avoid 0-size queue
339+
LogLevel: 2,
340+
}
341+
processor := NewRedisConnectionProcessor(3, baseDialer, config, nil)
289342
defer processor.Shutdown(context.Background())
290343

291344
conn := createMockPoolConnection()
@@ -468,8 +521,22 @@ func TestRedisConnectionProcessor(t *testing.T) {
468521
t.Error("Connection should be pooled after handoff")
469522
}
470523

471-
// Wait for handoff to complete
472-
time.Sleep(50 * time.Millisecond)
524+
// Wait for handoff to complete with proper timeout and polling
525+
timeout := time.After(1 * time.Second)
526+
ticker := time.NewTicker(5 * time.Millisecond)
527+
defer ticker.Stop()
528+
529+
handoffCompleted := false
530+
for !handoffCompleted {
531+
select {
532+
case <-timeout:
533+
t.Fatal("Timeout waiting for handoff to complete")
534+
case <-ticker.C:
535+
if _, pending := processor.pending.Load(conn); !pending {
536+
handoffCompleted = true
537+
}
538+
}
539+
}
473540

474541
// Verify relaxed timeout is set with deadline
475542
if !conn.HasRelaxedTimeout() {
@@ -626,17 +693,15 @@ func TestRedisConnectionProcessor(t *testing.T) {
626693
}
627694
}
628695

629-
// Verify queue has items but capacity remains static
630-
currentQueueSize := len(processor.handoffQueue)
631-
if currentQueueSize == 0 {
632-
t.Error("Expected some items in queue after processing connections")
633-
}
634-
696+
// Verify queue capacity remains static (the main purpose of this test)
635697
finalCapacity := cap(processor.handoffQueue)
636698
if finalCapacity != 50 {
637699
t.Errorf("Queue capacity should remain static at 50, got %d", finalCapacity)
638700
}
639701

702+
// Note: We don't check queue size here because workers process items quickly
703+
// The important thing is that the capacity remains static regardless of pool size
704+
currentQueueSize := len(processor.handoffQueue)
640705
t.Logf("Static queue test completed - Capacity: %d, Current size: %d",
641706
finalCapacity, currentQueueSize)
642707
})
@@ -738,7 +803,21 @@ func TestRedisConnectionProcessor(t *testing.T) {
738803
}
739804

740805
// Wait for the handoff to complete (it happens asynchronously)
741-
time.Sleep(50 * time.Millisecond)
806+
timeout := time.After(1 * time.Second)
807+
ticker := time.NewTicker(5 * time.Millisecond)
808+
defer ticker.Stop()
809+
810+
handoffCompleted := false
811+
for !handoffCompleted {
812+
select {
813+
case <-timeout:
814+
t.Fatal("Timeout waiting for handoff to complete")
815+
case <-ticker.C:
816+
if _, pending := processor.pending.Load(conn); !pending {
817+
handoffCompleted = true
818+
}
819+
}
820+
}
742821

743822
// Verify that relaxed timeout was applied to the new connection
744823
if !conn.HasRelaxedTimeout() {

0 commit comments

Comments
 (0)