@@ -55,10 +55,11 @@ func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
55
55
56
56
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
57
57
type LinuxSocketBind struct {
58
- sock4 int
59
- sock6 int
60
- lastMark uint32
61
- closing sync.RWMutex
58
+ // mu guards sock4 and sock6 and the associated fds.
59
+ // As long as someone holds mu (read or write), the associated fds are valid.
60
+ mu sync.RWMutex
61
+ sock4 int
62
+ sock6 int
62
63
}
63
64
64
65
func NewLinuxSocketBind () Bind { return & LinuxSocketBind {sock4 : - 1 , sock6 : - 1 } }
@@ -102,54 +103,67 @@ func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
102
103
return nil , errors .New ("invalid IP address" )
103
104
}
104
105
105
- func (bind * LinuxSocketBind ) Open (port uint16 ) (uint16 , error ) {
106
+ func (bind * LinuxSocketBind ) Open (port uint16 ) ([]ReceiveFunc , uint16 , error ) {
107
+ bind .mu .Lock ()
108
+ defer bind .mu .Unlock ()
109
+
106
110
var err error
107
111
var newPort uint16
108
112
var tries int
109
113
110
114
if bind .sock4 != - 1 || bind .sock6 != - 1 {
111
- return 0 , ErrBindAlreadyOpen
115
+ return nil , 0 , ErrBindAlreadyOpen
112
116
}
113
117
114
118
originalPort := port
115
119
116
120
again:
117
121
port = originalPort
122
+ var sock4 , sock6 int
118
123
// Attempt ipv6 bind, update port if successful.
119
- bind . sock6 , newPort , err = create6 (port )
124
+ sock6 , newPort , err = create6 (port )
120
125
if err != nil {
121
- if err != syscall .EAFNOSUPPORT {
122
- return 0 , err
126
+ if ! errors . Is ( err , syscall .EAFNOSUPPORT ) {
127
+ return nil , 0 , err
123
128
}
124
129
} else {
125
130
port = newPort
126
131
}
127
132
128
133
// Attempt ipv4 bind, update port if successful.
129
- bind . sock4 , newPort , err = create4 (port )
134
+ sock4 , newPort , err = create4 (port )
130
135
if err != nil {
131
- if originalPort == 0 && err == syscall .EADDRINUSE && tries < 100 {
132
- unix .Close (bind . sock6 )
136
+ if originalPort == 0 && errors . Is ( err , syscall .EADDRINUSE ) && tries < 100 {
137
+ unix .Close (sock6 )
133
138
tries ++
134
139
goto again
135
140
}
136
- if err != syscall .EAFNOSUPPORT {
137
- unix .Close (bind . sock6 )
138
- return 0 , err
141
+ if ! errors . Is ( err , syscall .EAFNOSUPPORT ) {
142
+ unix .Close (sock6 )
143
+ return nil , 0 , err
139
144
}
140
145
} else {
141
146
port = newPort
142
147
}
143
148
144
- if bind .sock4 == - 1 && bind .sock6 == - 1 {
145
- return 0 , syscall .EAFNOSUPPORT
149
+ var fns []ReceiveFunc
150
+ if sock4 != - 1 {
151
+ fns = append (fns , makeReceiveIPv4 (sock4 ))
152
+ bind .sock4 = sock4
153
+ }
154
+ if sock6 != - 1 {
155
+ fns = append (fns , makeReceiveIPv6 (sock6 ))
156
+ bind .sock6 = sock6
157
+ }
158
+ if len (fns ) == 0 {
159
+ return nil , 0 , syscall .EAFNOSUPPORT
146
160
}
147
- return port , nil
161
+ return fns , port , nil
148
162
}
149
163
150
164
func (bind * LinuxSocketBind ) SetMark (value uint32 ) error {
151
- bind .closing .RLock ()
152
- defer bind .closing .RUnlock ()
165
+ bind .mu .RLock ()
166
+ defer bind .mu .RUnlock ()
153
167
154
168
if bind .sock6 != - 1 {
155
169
err := unix .SetsockoptInt (
@@ -177,21 +191,24 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error {
177
191
}
178
192
}
179
193
180
- bind .lastMark = value
181
194
return nil
182
195
}
183
196
184
197
func (bind * LinuxSocketBind ) Close () error {
185
- var err1 , err2 error
186
- bind .closing .RLock ()
198
+ // Take a readlock to shut down the sockets...
199
+ bind .mu .RLock ()
187
200
if bind .sock6 != - 1 {
188
201
unix .Shutdown (bind .sock6 , unix .SHUT_RDWR )
189
202
}
190
203
if bind .sock4 != - 1 {
191
204
unix .Shutdown (bind .sock4 , unix .SHUT_RDWR )
192
205
}
193
- bind .closing .RUnlock ()
194
- bind .closing .Lock ()
206
+ bind .mu .RUnlock ()
207
+ // ...and a write lock to close the fd.
208
+ // This ensures that no one else is using the fd.
209
+ bind .mu .Lock ()
210
+ defer bind .mu .Unlock ()
211
+ var err1 , err2 error
195
212
if bind .sock6 != - 1 {
196
213
err1 = unix .Close (bind .sock6 )
197
214
bind .sock6 = - 1
@@ -200,54 +217,36 @@ func (bind *LinuxSocketBind) Close() error {
200
217
err2 = unix .Close (bind .sock4 )
201
218
bind .sock4 = - 1
202
219
}
203
- bind .closing .Unlock ()
204
220
205
221
if err1 != nil {
206
222
return err1
207
223
}
208
224
return err2
209
225
}
210
226
211
- func (bind * LinuxSocketBind ) ReceiveIPv6 (buff []byte ) (int , Endpoint , error ) {
212
- bind .closing .RLock ()
213
- defer bind .closing .RUnlock ()
214
-
215
- var end LinuxSocketEndpoint
216
- if bind .sock6 == - 1 {
217
- return 0 , nil , net .ErrClosed
227
+ func makeReceiveIPv6 (sock int ) ReceiveFunc {
228
+ return func (buff []byte ) (int , Endpoint , error ) {
229
+ var end LinuxSocketEndpoint
230
+ n , err := receive6 (sock , buff , & end )
231
+ return n , & end , err
218
232
}
219
- n , err := receive6 (
220
- bind .sock6 ,
221
- buff ,
222
- & end ,
223
- )
224
- return n , & end , err
225
233
}
226
234
227
- func (bind * LinuxSocketBind ) ReceiveIPv4 (buff []byte ) (int , Endpoint , error ) {
228
- bind .closing .RLock ()
229
- defer bind .closing .RUnlock ()
230
-
231
- var end LinuxSocketEndpoint
232
- if bind .sock4 == - 1 {
233
- return 0 , nil , net .ErrClosed
235
+ func makeReceiveIPv4 (sock int ) ReceiveFunc {
236
+ return func (buff []byte ) (int , Endpoint , error ) {
237
+ var end LinuxSocketEndpoint
238
+ n , err := receive4 (sock , buff , & end )
239
+ return n , & end , err
234
240
}
235
- n , err := receive4 (
236
- bind .sock4 ,
237
- buff ,
238
- & end ,
239
- )
240
- return n , & end , err
241
241
}
242
242
243
243
func (bind * LinuxSocketBind ) Send (buff []byte , end Endpoint ) error {
244
- bind .closing .RLock ()
245
- defer bind .closing .RUnlock ()
246
-
247
244
nend , ok := end .(* LinuxSocketEndpoint )
248
245
if ! ok {
249
246
return ErrWrongEndpointType
250
247
}
248
+ bind .mu .RLock ()
249
+ defer bind .mu .RUnlock ()
251
250
if ! nend .isV6 {
252
251
if bind .sock4 == - 1 {
253
252
return net .ErrClosed
0 commit comments