@@ -17,8 +17,8 @@ package mux
17
17
import (
18
18
"fmt"
19
19
"runtime"
20
- "sync"
21
20
"sync/atomic"
21
+ "time"
22
22
23
23
"github.com/bytedance/gopkg/util/gopool"
24
24
@@ -43,16 +43,18 @@ func init() {
43
43
// NewShardQueue .
44
44
func NewShardQueue (size int , conn netpoll.Connection ) (queue * ShardQueue ) {
45
45
queue = & ShardQueue {
46
- conn : conn ,
47
- size : int32 (size ),
48
- getters : make ([][]WriterGetter , size ),
49
- swap : make ([]WriterGetter , 0 , 64 ),
50
- locks : make ([]int32 , size ),
46
+ conn : conn ,
47
+ size : int32 (size ),
48
+ getters : make ([][]WriterGetter , size ),
49
+ swap : make ([]WriterGetter , 0 , 64 ),
50
+ locks : make ([]int32 , size ),
51
+ closeNotif : make (chan struct {}),
51
52
}
52
53
for i := range queue .getters {
53
54
queue .getters [i ] = make ([]WriterGetter , 0 , 64 )
54
55
}
55
- queue .list = make ([]int32 , size )
56
+ // To avoid w equals to r when loop writing, make list larger than size.
57
+ queue .list = make ([]int32 , size + 1 )
56
58
return queue
57
59
}
58
60
@@ -69,6 +71,8 @@ type ShardQueue struct {
69
71
getters [][]WriterGetter // len(getters) = size
70
72
swap []WriterGetter // use for swap
71
73
locks []int32 // len(locks) = size
74
+
75
+ closeNotif chan struct {}
72
76
queueTrigger
73
77
}
74
78
@@ -81,17 +85,29 @@ const (
81
85
82
86
// here for trigger
83
87
type queueTrigger struct {
84
- trigger int32
85
- state int32 // 0: active, 1: closing, 2: closed
86
- runNum int32
87
- w , r int32 // ptr of list
88
- list []int32 // record the triggered shard
89
- listLock sync.Mutex // list total lock
88
+ bufNum int32
89
+ state int32 // 0: active, 1: closing, 2: closed
90
+ runNum int32
91
+ w , r int32 // ptr of list
92
+ list []int32 // record the triggered shard
93
+ }
94
+
95
+ func (q * queueTrigger ) length () int {
96
+ w := int (atomic .LoadInt32 (& q .w ))
97
+ r := int (atomic .LoadInt32 (& q .r ))
98
+ if w < r {
99
+ w += len (q .list )
100
+ }
101
+ return w - r
90
102
}
91
103
92
104
// Add adds to q.getters[shard]
93
105
func (q * ShardQueue ) Add (gts ... WriterGetter ) {
106
+ atomic .AddInt32 (& q .bufNum , 1 )
94
107
if atomic .LoadInt32 (& q .state ) != active {
108
+ if atomic .AddInt32 (& q .bufNum , - 1 ) <= 0 {
109
+ close (q .closeNotif )
110
+ }
95
111
return
96
112
}
97
113
shard := atomic .AddInt32 (& q .idx , 1 ) % q .size
@@ -109,90 +125,106 @@ func (q *ShardQueue) Close() error {
109
125
return fmt .Errorf ("shardQueue has been closed" )
110
126
}
111
127
// wait for all tasks finished
112
- for atomic .LoadInt32 (& q .state ) != closed {
113
- if atomic .LoadInt32 (& q .trigger ) == 0 {
128
+ if atomic .LoadInt32 (& q .bufNum ) == 0 {
129
+ atomic .StoreInt32 (& q .state , closed )
130
+ } else {
131
+ timeout := time .NewTimer (3 * time .Second )
132
+ select {
133
+ case <- q .closeNotif :
114
134
atomic .StoreInt32 (& q .state , closed )
115
- return nil
135
+ timeout .Stop ()
136
+ case <- timeout .C :
137
+ atomic .StoreInt32 (& q .state , closed )
138
+ return fmt .Errorf ("shardQueue close timeout" )
116
139
}
117
- runtime .Gosched ()
118
140
}
119
141
return nil
120
142
}
121
143
122
144
// triggering shard.
123
145
func (q * ShardQueue ) triggering (shard int32 ) {
124
- q . listLock . Lock ()
125
- q . w = ( q .w + 1 ) % q . size
126
- q . list [ q . w ] = shard
127
- q . listLock . Unlock ()
128
-
129
- if atomic . AddInt32 ( & q . trigger , 1 ) > 1 {
130
- return
146
+ for {
147
+ ow := atomic . LoadInt32 ( & q .w )
148
+ nw := ( ow + 1 ) % int32 ( len ( q . list ))
149
+ if atomic . CompareAndSwapInt32 ( & q . w , ow , nw ) {
150
+ q . list [ nw ] = shard
151
+ break
152
+ }
131
153
}
132
154
q .foreach ()
133
155
}
134
156
135
- // foreach swap r & w. It's not concurrency safe.
157
+ // foreach swap r & w.
136
158
func (q * ShardQueue ) foreach () {
137
159
if atomic .AddInt32 (& q .runNum , 1 ) > 1 {
138
160
return
139
161
}
140
162
gopool .CtxGo (nil , func () {
141
- var negNum int32 // is negative number of triggerNum
142
- for triggerNum := atomic .LoadInt32 (& q .trigger ); triggerNum > 0 ; {
143
- q .r = (q .r + 1 ) % q .size
144
- shared := q .list [q .r ]
163
+ var negBufNum int32 // is negative number of bufNum
164
+ for q .length () > 0 {
165
+ nr := (atomic .LoadInt32 (& q .r ) + 1 ) % int32 (len (q .list ))
166
+ atomic .StoreInt32 (& q .r , nr )
167
+ shard := q .list [nr ]
145
168
146
169
// lock & swap
147
- q .lock (shared )
148
- tmp := q .getters [shared ]
149
- q .getters [shared ] = q .swap [:0 ]
170
+ q .lock (shard )
171
+ tmp := q .getters [shard ]
172
+ q .getters [shard ] = q .swap [:0 ]
150
173
q .swap = tmp
151
- q .unlock (shared )
174
+ q .unlock (shard )
152
175
153
176
// deal
154
- q .deal (q .swap )
155
- negNum --
156
- if triggerNum + negNum == 0 {
157
- triggerNum = atomic .AddInt32 (& q .trigger , negNum )
158
- negNum = 0
177
+ if err := q .deal (q .swap ); err != nil {
178
+ close (q .closeNotif )
179
+ return
159
180
}
181
+ negBufNum -= int32 (len (q .swap ))
182
+ }
183
+ if negBufNum < 0 {
184
+ if err := q .flush (); err != nil {
185
+ close (q .closeNotif )
186
+ return
187
+ }
188
+ }
189
+
190
+ // MUST decrease bufNum first.
191
+ if atomic .AddInt32 (& q .bufNum , negBufNum ) <= 0 && atomic .LoadInt32 (& q .state ) != active {
192
+ close (q .closeNotif )
193
+ return
160
194
}
161
- q .flush ()
162
195
163
196
// quit & check again
164
197
atomic .StoreInt32 (& q .runNum , 0 )
165
- if atomic . LoadInt32 ( & q . trigger ) > 0 {
198
+ if q . length ( ) > 0 {
166
199
q .foreach ()
167
200
return
168
201
}
169
- // if state is closing, change it to closed
170
- atomic .CompareAndSwapInt32 (& q .state , closing , closed )
171
202
})
172
203
}
173
204
174
205
// deal is used to get deal of netpoll.Writer.
175
- func (q * ShardQueue ) deal (gts []WriterGetter ) {
206
+ func (q * ShardQueue ) deal (gts []WriterGetter ) error {
176
207
writer := q .conn .Writer ()
177
208
for _ , gt := range gts {
178
209
buf , isNil := gt ()
179
210
if ! isNil {
180
211
err := writer .Append (buf )
181
212
if err != nil {
182
213
q .conn .Close ()
183
- return
214
+ return err
184
215
}
185
216
}
186
217
}
218
+ return nil
187
219
}
188
220
189
221
// flush is used to flush netpoll.Writer.
190
- func (q * ShardQueue ) flush () {
222
+ func (q * ShardQueue ) flush () error {
191
223
err := q .conn .Writer ().Flush ()
192
224
if err != nil {
193
225
q .conn .Close ()
194
- return
195
226
}
227
+ return err
196
228
}
197
229
198
230
// lock shard.
0 commit comments