Skip to content

Commit f98dffa

Browse files
committed
fix: shard queue panic
1 parent 1d17d4d commit f98dffa

File tree

1 file changed

+73
-47
lines changed

1 file changed

+73
-47
lines changed

mux/shard_queue.go

Lines changed: 73 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ package mux
1717
import (
1818
"fmt"
1919
"runtime"
20-
"sync"
2120
"sync/atomic"
2221

2322
"github.com/bytedance/gopkg/util/gopool"
@@ -43,16 +42,18 @@ func init() {
4342
// NewShardQueue .
4443
func NewShardQueue(size int, conn netpoll.Connection) (queue *ShardQueue) {
4544
queue = &ShardQueue{
46-
conn: conn,
47-
size: int32(size),
48-
getters: make([][]WriterGetter, size),
49-
swap: make([]WriterGetter, 0, 64),
50-
locks: make([]int32, size),
45+
conn: conn,
46+
size: int32(size),
47+
getters: make([][]WriterGetter, size),
48+
swap: make([]WriterGetter, 0, 64),
49+
locks: make([]int32, size),
50+
closeNotif: make(chan struct{}),
5151
}
5252
for i := range queue.getters {
5353
queue.getters[i] = make([]WriterGetter, 0, 64)
5454
}
55-
queue.list = make([]int32, size)
55+
// To avoid w equals to r when loop writing, make list larger than size.
56+
queue.list = make([]int32, size+1)
5657
return queue
5758
}
5859

@@ -69,6 +70,8 @@ type ShardQueue struct {
6970
getters [][]WriterGetter // len(getters) = size
7071
swap []WriterGetter // use for swap
7172
locks []int32 // len(locks) = size
73+
74+
closeNotif chan struct{}
7275
queueTrigger
7376
}
7477

@@ -81,17 +84,29 @@ const (
8184

8285
// here for trigger
8386
type queueTrigger struct {
84-
trigger int32
85-
state int32 // 0: active, 1: closing, 2: closed
86-
runNum int32
87-
w, r int32 // ptr of list
88-
list []int32 // record the triggered shard
89-
listLock sync.Mutex // list total lock
87+
bufNum int32
88+
state int32 // 0: active, 1: closing, 2: closed
89+
runNum int32
90+
w, r int32 // ptr of list
91+
list []int32 // record the triggered shard
92+
}
93+
94+
func (q *queueTrigger) length() int {
95+
w := int(atomic.LoadInt32(&q.w))
96+
r := int(atomic.LoadInt32(&q.r))
97+
if w < r {
98+
w += len(q.list)
99+
}
100+
return w - r
90101
}
91102

92103
// Add adds to q.getters[shard]
93104
func (q *ShardQueue) Add(gts ...WriterGetter) {
105+
atomic.AddInt32(&q.bufNum, 1)
94106
if atomic.LoadInt32(&q.state) != active {
107+
if atomic.AddInt32(&q.bufNum, -1) <= 0 {
108+
close(q.closeNotif)
109+
}
95110
return
96111
}
97112
shard := atomic.AddInt32(&q.idx, 1) % q.size
@@ -109,90 +124,101 @@ func (q *ShardQueue) Close() error {
109124
return fmt.Errorf("shardQueue has been closed")
110125
}
111126
// wait for all tasks finished
112-
for atomic.LoadInt32(&q.state) != closed {
113-
if atomic.LoadInt32(&q.trigger) == 0 {
114-
atomic.StoreInt32(&q.state, closed)
115-
return nil
127+
if atomic.LoadInt32(&q.bufNum) == 0 {
128+
atomic.StoreInt32(&q.state, closed)
129+
} else {
130+
select {
131+
case <-q.closeNotif:
116132
}
117-
runtime.Gosched()
133+
atomic.StoreInt32(&q.state, closed)
118134
}
119135
return nil
120136
}
121137

122138
// triggering shard.
123139
func (q *ShardQueue) triggering(shard int32) {
124-
q.listLock.Lock()
125-
q.w = (q.w + 1) % q.size
126-
q.list[q.w] = shard
127-
q.listLock.Unlock()
128-
129-
if atomic.AddInt32(&q.trigger, 1) > 1 {
130-
return
140+
for {
141+
ow := atomic.LoadInt32(&q.w)
142+
nw := (ow + 1) % int32(len(q.list))
143+
if atomic.CompareAndSwapInt32(&q.w, ow, nw) {
144+
q.list[nw] = shard
145+
break
146+
}
131147
}
132148
q.foreach()
133149
}
134150

135-
// foreach swap r & w. It's not concurrency safe.
151+
// foreach swap r & w.
136152
func (q *ShardQueue) foreach() {
137153
if atomic.AddInt32(&q.runNum, 1) > 1 {
138154
return
139155
}
140156
gopool.CtxGo(nil, func() {
141-
var negNum int32 // is negative number of triggerNum
142-
for triggerNum := atomic.LoadInt32(&q.trigger); triggerNum > 0; {
143-
q.r = (q.r + 1) % q.size
144-
shared := q.list[q.r]
157+
var negBufNum int32 // is negative number of bufNum
158+
for q.length() > 0 {
159+
nr := (atomic.LoadInt32(&q.r) + 1) % int32(len(q.list))
160+
atomic.StoreInt32(&q.r, nr)
161+
shard := q.list[nr]
145162

146163
// lock & swap
147-
q.lock(shared)
148-
tmp := q.getters[shared]
149-
q.getters[shared] = q.swap[:0]
164+
q.lock(shard)
165+
tmp := q.getters[shard]
166+
q.getters[shard] = q.swap[:0]
150167
q.swap = tmp
151-
q.unlock(shared)
168+
q.unlock(shard)
152169

153170
// deal
154-
q.deal(q.swap)
155-
negNum--
156-
if triggerNum+negNum == 0 {
157-
triggerNum = atomic.AddInt32(&q.trigger, negNum)
158-
negNum = 0
171+
if err := q.deal(q.swap); err != nil {
172+
close(q.closeNotif)
173+
return
174+
}
175+
negBufNum -= int32(len(q.swap))
176+
}
177+
if negBufNum < 0 {
178+
if err := q.flush(); err != nil {
179+
close(q.closeNotif)
180+
return
159181
}
160182
}
161-
q.flush()
183+
184+
// MUST decrease bufNum first.
185+
if atomic.AddInt32(&q.bufNum, negBufNum) <= 0 && atomic.LoadInt32(&q.state) != active {
186+
close(q.closeNotif)
187+
return
188+
}
162189

163190
// quit & check again
164191
atomic.StoreInt32(&q.runNum, 0)
165-
if atomic.LoadInt32(&q.trigger) > 0 {
192+
if q.length() > 0 {
166193
q.foreach()
167194
return
168195
}
169-
// if state is closing, change it to closed
170-
atomic.CompareAndSwapInt32(&q.state, closing, closed)
171196
})
172197
}
173198

174199
// deal is used to get deal of netpoll.Writer.
175-
func (q *ShardQueue) deal(gts []WriterGetter) {
200+
func (q *ShardQueue) deal(gts []WriterGetter) error {
176201
writer := q.conn.Writer()
177202
for _, gt := range gts {
178203
buf, isNil := gt()
179204
if !isNil {
180205
err := writer.Append(buf)
181206
if err != nil {
182207
q.conn.Close()
183-
return
208+
return err
184209
}
185210
}
186211
}
212+
return nil
187213
}
188214

189215
// flush is used to flush netpoll.Writer.
190-
func (q *ShardQueue) flush() {
216+
func (q *ShardQueue) flush() error {
191217
err := q.conn.Writer().Flush()
192218
if err != nil {
193219
q.conn.Close()
194-
return
195220
}
221+
return err
196222
}
197223

198224
// lock shard.

0 commit comments

Comments
 (0)