@@ -17,7 +17,6 @@ package mux
17
17
import (
18
18
"fmt"
19
19
"runtime"
20
- "sync"
21
20
"sync/atomic"
22
21
23
22
"github.com/bytedance/gopkg/util/gopool"
@@ -43,16 +42,18 @@ func init() {
43
42
// NewShardQueue .
44
43
func NewShardQueue (size int , conn netpoll.Connection ) (queue * ShardQueue ) {
45
44
queue = & ShardQueue {
46
- conn : conn ,
47
- size : int32 (size ),
48
- getters : make ([][]WriterGetter , size ),
49
- swap : make ([]WriterGetter , 0 , 64 ),
50
- locks : make ([]int32 , size ),
45
+ conn : conn ,
46
+ size : int32 (size ),
47
+ getters : make ([][]WriterGetter , size ),
48
+ swap : make ([]WriterGetter , 0 , 64 ),
49
+ locks : make ([]int32 , size ),
50
+ closeNotif : make (chan struct {}),
51
51
}
52
52
for i := range queue .getters {
53
53
queue .getters [i ] = make ([]WriterGetter , 0 , 64 )
54
54
}
55
- queue .list = make ([]int32 , size )
55
+ // To avoid w equals to r when loop writing, make list larger than size.
56
+ queue .list = make ([]int32 , size + 1 )
56
57
return queue
57
58
}
58
59
@@ -69,6 +70,8 @@ type ShardQueue struct {
69
70
getters [][]WriterGetter // len(getters) = size
70
71
swap []WriterGetter // use for swap
71
72
locks []int32 // len(locks) = size
73
+
74
+ closeNotif chan struct {}
72
75
queueTrigger
73
76
}
74
77
@@ -81,17 +84,29 @@ const (
81
84
82
85
// here for trigger
83
86
type queueTrigger struct {
84
- trigger int32
85
- state int32 // 0: active, 1: closing, 2: closed
86
- runNum int32
87
- w , r int32 // ptr of list
88
- list []int32 // record the triggered shard
89
- listLock sync.Mutex // list total lock
87
+ bufNum int32
88
+ state int32 // 0: active, 1: closing, 2: closed
89
+ runNum int32
90
+ w , r int32 // ptr of list
91
+ list []int32 // record the triggered shard
92
+ }
93
+
94
+ func (q * queueTrigger ) length () int {
95
+ w := int (atomic .LoadInt32 (& q .w ))
96
+ r := int (atomic .LoadInt32 (& q .r ))
97
+ if w < r {
98
+ w += len (q .list )
99
+ }
100
+ return w - r
90
101
}
91
102
92
103
// Add adds to q.getters[shard]
93
104
func (q * ShardQueue ) Add (gts ... WriterGetter ) {
105
+ atomic .AddInt32 (& q .bufNum , 1 )
94
106
if atomic .LoadInt32 (& q .state ) != active {
107
+ if atomic .AddInt32 (& q .bufNum , - 1 ) <= 0 {
108
+ close (q .closeNotif )
109
+ }
95
110
return
96
111
}
97
112
shard := atomic .AddInt32 (& q .idx , 1 ) % q .size
@@ -109,90 +124,101 @@ func (q *ShardQueue) Close() error {
109
124
return fmt .Errorf ("shardQueue has been closed" )
110
125
}
111
126
// wait for all tasks finished
112
- for atomic .LoadInt32 (& q .state ) != closed {
113
- if atomic .LoadInt32 (& q .trigger ) == 0 {
114
- atomic .StoreInt32 (& q .state , closed )
115
- return nil
127
+ if atomic .LoadInt32 (& q .bufNum ) == 0 {
128
+ atomic .StoreInt32 (& q .state , closed )
129
+ } else {
130
+ select {
131
+ case <- q .closeNotif :
116
132
}
117
- runtime . Gosched ( )
133
+ atomic . StoreInt32 ( & q . state , closed )
118
134
}
119
135
return nil
120
136
}
121
137
122
138
// triggering shard.
123
139
func (q * ShardQueue ) triggering (shard int32 ) {
124
- q . listLock . Lock ()
125
- q . w = ( q .w + 1 ) % q . size
126
- q . list [ q . w ] = shard
127
- q . listLock . Unlock ()
128
-
129
- if atomic . AddInt32 ( & q . trigger , 1 ) > 1 {
130
- return
140
+ for {
141
+ ow := atomic . LoadInt32 ( & q .w )
142
+ nw := ( ow + 1 ) % int32 ( len ( q . list ))
143
+ if atomic . CompareAndSwapInt32 ( & q . w , ow , nw ) {
144
+ q . list [ nw ] = shard
145
+ break
146
+ }
131
147
}
132
148
q .foreach ()
133
149
}
134
150
135
- // foreach swap r & w. It's not concurrency safe.
151
+ // foreach swap r & w.
136
152
func (q * ShardQueue ) foreach () {
137
153
if atomic .AddInt32 (& q .runNum , 1 ) > 1 {
138
154
return
139
155
}
140
156
gopool .CtxGo (nil , func () {
141
- var negNum int32 // is negative number of triggerNum
142
- for triggerNum := atomic .LoadInt32 (& q .trigger ); triggerNum > 0 ; {
143
- q .r = (q .r + 1 ) % q .size
144
- shared := q .list [q .r ]
157
+ var negBufNum int32 // is negative number of bufNum
158
+ for q .length () > 0 {
159
+ nr := (atomic .LoadInt32 (& q .r ) + 1 ) % int32 (len (q .list ))
160
+ atomic .StoreInt32 (& q .r , nr )
161
+ shard := q .list [nr ]
145
162
146
163
// lock & swap
147
- q .lock (shared )
148
- tmp := q .getters [shared ]
149
- q .getters [shared ] = q .swap [:0 ]
164
+ q .lock (shard )
165
+ tmp := q .getters [shard ]
166
+ q .getters [shard ] = q .swap [:0 ]
150
167
q .swap = tmp
151
- q .unlock (shared )
168
+ q .unlock (shard )
152
169
153
170
// deal
154
- q .deal (q .swap )
155
- negNum --
156
- if triggerNum + negNum == 0 {
157
- triggerNum = atomic .AddInt32 (& q .trigger , negNum )
158
- negNum = 0
171
+ if err := q .deal (q .swap ); err != nil {
172
+ close (q .closeNotif )
173
+ return
174
+ }
175
+ negBufNum -= int32 (len (q .swap ))
176
+ }
177
+ if negBufNum < 0 {
178
+ if err := q .flush (); err != nil {
179
+ close (q .closeNotif )
180
+ return
159
181
}
160
182
}
161
- q .flush ()
183
+
184
+ // MUST decrease bufNum first.
185
+ if atomic .AddInt32 (& q .bufNum , negBufNum ) <= 0 && atomic .LoadInt32 (& q .state ) != active {
186
+ close (q .closeNotif )
187
+ return
188
+ }
162
189
163
190
// quit & check again
164
191
atomic .StoreInt32 (& q .runNum , 0 )
165
- if atomic . LoadInt32 ( & q . trigger ) > 0 {
192
+ if q . length ( ) > 0 {
166
193
q .foreach ()
167
194
return
168
195
}
169
- // if state is closing, change it to closed
170
- atomic .CompareAndSwapInt32 (& q .state , closing , closed )
171
196
})
172
197
}
173
198
174
199
// deal is used to get deal of netpoll.Writer.
175
- func (q * ShardQueue ) deal (gts []WriterGetter ) {
200
+ func (q * ShardQueue ) deal (gts []WriterGetter ) error {
176
201
writer := q .conn .Writer ()
177
202
for _ , gt := range gts {
178
203
buf , isNil := gt ()
179
204
if ! isNil {
180
205
err := writer .Append (buf )
181
206
if err != nil {
182
207
q .conn .Close ()
183
- return
208
+ return err
184
209
}
185
210
}
186
211
}
212
+ return nil
187
213
}
188
214
189
215
// flush is used to flush netpoll.Writer.
190
- func (q * ShardQueue ) flush () {
216
+ func (q * ShardQueue ) flush () error {
191
217
err := q .conn .Writer ().Flush ()
192
218
if err != nil {
193
219
q .conn .Close ()
194
- return
195
220
}
221
+ return err
196
222
}
197
223
198
224
// lock shard.
0 commit comments