Skip to content

Commit 11d37e7

Browse files
committed
fix: shard queue panic
1 parent 1d17d4d commit 11d37e7

File tree

1 file changed

+78
-46
lines changed

1 file changed

+78
-46
lines changed

mux/shard_queue.go

Lines changed: 78 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ package mux
1717
import (
1818
"fmt"
1919
"runtime"
20-
"sync"
2120
"sync/atomic"
21+
"time"
2222

2323
"github.com/bytedance/gopkg/util/gopool"
2424

@@ -43,16 +43,18 @@ func init() {
4343
// NewShardQueue creates a ShardQueue with size shards for conn.
4444
func NewShardQueue(size int, conn netpoll.Connection) (queue *ShardQueue) {
4545
queue = &ShardQueue{
46-
conn: conn,
47-
size: int32(size),
48-
getters: make([][]WriterGetter, size),
49-
swap: make([]WriterGetter, 0, 64),
50-
locks: make([]int32, size),
46+
conn: conn,
47+
size: int32(size),
48+
getters: make([][]WriterGetter, size),
49+
swap: make([]WriterGetter, 0, 64),
50+
locks: make([]int32, size),
51+
closeNotif: make(chan struct{}),
5152
}
5253
for i := range queue.getters {
5354
queue.getters[i] = make([]WriterGetter, 0, 64)
5455
}
55-
queue.list = make([]int32, size)
56+
// To keep w from becoming equal to r when writes wrap around, make list one element larger than size.
57+
queue.list = make([]int32, size+1)
5658
return queue
5759
}
5860

@@ -69,6 +71,8 @@ type ShardQueue struct {
6971
getters [][]WriterGetter // len(getters) = size
7072
swap []WriterGetter // use for swap
7173
locks []int32 // len(locks) = size
74+
75+
closeNotif chan struct{}
7276
queueTrigger
7377
}
7478

@@ -81,17 +85,29 @@ const (
8185

8286
// here for trigger
8387
type queueTrigger struct {
84-
trigger int32
85-
state int32 // 0: active, 1: closing, 2: closed
86-
runNum int32
87-
w, r int32 // ptr of list
88-
list []int32 // record the triggered shard
89-
listLock sync.Mutex // list total lock
88+
bufNum int32
89+
state int32 // 0: active, 1: closing, 2: closed
90+
runNum int32
91+
w, r int32 // ptr of list
92+
list []int32 // record the triggered shard
93+
}
94+
95+
func (q *queueTrigger) length() int {
96+
w := int(atomic.LoadInt32(&q.w))
97+
r := int(atomic.LoadInt32(&q.r))
98+
if w < r {
99+
w += len(q.list)
100+
}
101+
return w - r
90102
}
91103

92104
// Add adds to q.getters[shard]
93105
func (q *ShardQueue) Add(gts ...WriterGetter) {
106+
atomic.AddInt32(&q.bufNum, 1)
94107
if atomic.LoadInt32(&q.state) != active {
108+
if atomic.AddInt32(&q.bufNum, -1) <= 0 {
109+
close(q.closeNotif)
110+
}
95111
return
96112
}
97113
shard := atomic.AddInt32(&q.idx, 1) % q.size
@@ -109,90 +125,106 @@ func (q *ShardQueue) Close() error {
109125
return fmt.Errorf("shardQueue has been closed")
110126
}
111127
// wait for all tasks finished
112-
for atomic.LoadInt32(&q.state) != closed {
113-
if atomic.LoadInt32(&q.trigger) == 0 {
128+
if atomic.LoadInt32(&q.bufNum) == 0 {
129+
atomic.StoreInt32(&q.state, closed)
130+
} else {
131+
timeout := time.NewTimer(3 * time.Second)
132+
select {
133+
case <-q.closeNotif:
114134
atomic.StoreInt32(&q.state, closed)
115-
return nil
135+
timeout.Stop()
136+
case <-timeout.C:
137+
atomic.StoreInt32(&q.state, closed)
138+
return fmt.Errorf("shardQueue close timeout")
116139
}
117-
runtime.Gosched()
118140
}
119141
return nil
120142
}
121143

122144
// triggering shard.
123145
func (q *ShardQueue) triggering(shard int32) {
124-
q.listLock.Lock()
125-
q.w = (q.w + 1) % q.size
126-
q.list[q.w] = shard
127-
q.listLock.Unlock()
128-
129-
if atomic.AddInt32(&q.trigger, 1) > 1 {
130-
return
146+
for {
147+
ow := atomic.LoadInt32(&q.w)
148+
nw := (ow + 1) % int32(len(q.list))
149+
if atomic.CompareAndSwapInt32(&q.w, ow, nw) {
150+
q.list[nw] = shard
151+
break
152+
}
131153
}
132154
q.foreach()
133155
}
134156

135-
// foreach swap r & w. It's not concurrency safe.
157+
// foreach swap r & w.
136158
func (q *ShardQueue) foreach() {
137159
if atomic.AddInt32(&q.runNum, 1) > 1 {
138160
return
139161
}
140162
gopool.CtxGo(nil, func() {
141-
var negNum int32 // is negative number of triggerNum
142-
for triggerNum := atomic.LoadInt32(&q.trigger); triggerNum > 0; {
143-
q.r = (q.r + 1) % q.size
144-
shared := q.list[q.r]
163+
var negBufNum int32 // is negative number of bufNum
164+
for q.length() > 0 {
165+
nr := (atomic.LoadInt32(&q.r) + 1) % int32(len(q.list))
166+
atomic.StoreInt32(&q.r, nr)
167+
shard := q.list[nr]
145168

146169
// lock & swap
147-
q.lock(shared)
148-
tmp := q.getters[shared]
149-
q.getters[shared] = q.swap[:0]
170+
q.lock(shard)
171+
tmp := q.getters[shard]
172+
q.getters[shard] = q.swap[:0]
150173
q.swap = tmp
151-
q.unlock(shared)
174+
q.unlock(shard)
152175

153176
// deal
154-
q.deal(q.swap)
155-
negNum--
156-
if triggerNum+negNum == 0 {
157-
triggerNum = atomic.AddInt32(&q.trigger, negNum)
158-
negNum = 0
177+
if err := q.deal(q.swap); err != nil {
178+
close(q.closeNotif)
179+
return
159180
}
181+
negBufNum -= int32(len(q.swap))
182+
}
183+
if negBufNum < 0 {
184+
if err := q.flush(); err != nil {
185+
close(q.closeNotif)
186+
return
187+
}
188+
}
189+
190+
// MUST decrease bufNum first.
191+
if atomic.AddInt32(&q.bufNum, negBufNum) <= 0 && atomic.LoadInt32(&q.state) != active {
192+
close(q.closeNotif)
193+
return
160194
}
161-
q.flush()
162195

163196
// quit & check again
164197
atomic.StoreInt32(&q.runNum, 0)
165-
if atomic.LoadInt32(&q.trigger) > 0 {
198+
if q.length() > 0 {
166199
q.foreach()
167200
return
168201
}
169-
// if state is closing, change it to closed
170-
atomic.CompareAndSwapInt32(&q.state, closing, closed)
171202
})
172203
}
173204

174205
// deal appends the buffers returned by gts to the connection's netpoll.Writer; it closes the connection if Append fails.
175-
func (q *ShardQueue) deal(gts []WriterGetter) {
206+
func (q *ShardQueue) deal(gts []WriterGetter) error {
176207
writer := q.conn.Writer()
177208
for _, gt := range gts {
178209
buf, isNil := gt()
179210
if !isNil {
180211
err := writer.Append(buf)
181212
if err != nil {
182213
q.conn.Close()
183-
return
214+
return err
184215
}
185216
}
186217
}
218+
return nil
187219
}
188220

189221
// flush is used to flush netpoll.Writer.
190-
func (q *ShardQueue) flush() {
222+
func (q *ShardQueue) flush() error {
191223
err := q.conn.Writer().Flush()
192224
if err != nil {
193225
q.conn.Close()
194-
return
195226
}
227+
return err
196228
}
197229

198230
// lock shard.

0 commit comments

Comments
 (0)