Skip to content

Commit 10533c3

Browse files
josharianzx2c4
authored andcommitted
all: make conn.Bind.Open return a slice of receive functions
Instead of hard-coding exactly two sources from which to receive packets (an IPv4 source and an IPv6 source), allow the conn.Bind to specify a set of sources. Beneficial consequences: * If there's no IPv6 support on a system, conn.Bind.Open can choose not to return a receive function for it, which is simpler than tracking that state in the bind. This simplification removes existing data races from both conn.StdNetBind and bindtest.ChannelBind. * If there are more than two sources on a system, the conn.Bind no longer needs to add a separate muxing layer. Signed-off-by: Josh Bleecher Snyder <[email protected]>
1 parent 8ed83e0 commit 10533c3

File tree

7 files changed

+138
-142
lines changed

7 files changed

+138
-142
lines changed

conn/bind_linux.go

Lines changed: 54 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,11 @@ func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
5555

5656
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
5757
type LinuxSocketBind struct {
58-
sock4 int
59-
sock6 int
60-
lastMark uint32
61-
closing sync.RWMutex
58+
// mu guards sock4 and sock6 and the associated fds.
59+
// As long as someone holds mu (read or write), the associated fds are valid.
60+
mu sync.RWMutex
61+
sock4 int
62+
sock6 int
6263
}
6364

6465
func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
@@ -102,54 +103,67 @@ func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
102103
return nil, errors.New("invalid IP address")
103104
}
104105

105-
func (bind *LinuxSocketBind) Open(port uint16) (uint16, error) {
106+
func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
107+
bind.mu.Lock()
108+
defer bind.mu.Unlock()
109+
106110
var err error
107111
var newPort uint16
108112
var tries int
109113

110114
if bind.sock4 != -1 || bind.sock6 != -1 {
111-
return 0, ErrBindAlreadyOpen
115+
return nil, 0, ErrBindAlreadyOpen
112116
}
113117

114118
originalPort := port
115119

116120
again:
117121
port = originalPort
122+
var sock4, sock6 int
118123
// Attempt ipv6 bind, update port if successful.
119-
bind.sock6, newPort, err = create6(port)
124+
sock6, newPort, err = create6(port)
120125
if err != nil {
121-
if err != syscall.EAFNOSUPPORT {
122-
return 0, err
126+
if !errors.Is(err, syscall.EAFNOSUPPORT) {
127+
return nil, 0, err
123128
}
124129
} else {
125130
port = newPort
126131
}
127132

128133
// Attempt ipv4 bind, update port if successful.
129-
bind.sock4, newPort, err = create4(port)
134+
sock4, newPort, err = create4(port)
130135
if err != nil {
131-
if originalPort == 0 && err == syscall.EADDRINUSE && tries < 100 {
132-
unix.Close(bind.sock6)
136+
if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
137+
unix.Close(sock6)
133138
tries++
134139
goto again
135140
}
136-
if err != syscall.EAFNOSUPPORT {
137-
unix.Close(bind.sock6)
138-
return 0, err
141+
if !errors.Is(err, syscall.EAFNOSUPPORT) {
142+
unix.Close(sock6)
143+
return nil, 0, err
139144
}
140145
} else {
141146
port = newPort
142147
}
143148

144-
if bind.sock4 == -1 && bind.sock6 == -1 {
145-
return 0, syscall.EAFNOSUPPORT
149+
var fns []ReceiveFunc
150+
if sock4 != -1 {
151+
fns = append(fns, makeReceiveIPv4(sock4))
152+
bind.sock4 = sock4
153+
}
154+
if sock6 != -1 {
155+
fns = append(fns, makeReceiveIPv6(sock6))
156+
bind.sock6 = sock6
157+
}
158+
if len(fns) == 0 {
159+
return nil, 0, syscall.EAFNOSUPPORT
146160
}
147-
return port, nil
161+
return fns, port, nil
148162
}
149163

150164
func (bind *LinuxSocketBind) SetMark(value uint32) error {
151-
bind.closing.RLock()
152-
defer bind.closing.RUnlock()
165+
bind.mu.RLock()
166+
defer bind.mu.RUnlock()
153167

154168
if bind.sock6 != -1 {
155169
err := unix.SetsockoptInt(
@@ -177,21 +191,24 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error {
177191
}
178192
}
179193

180-
bind.lastMark = value
181194
return nil
182195
}
183196

184197
func (bind *LinuxSocketBind) Close() error {
185-
var err1, err2 error
186-
bind.closing.RLock()
198+
// Take a readlock to shut down the sockets...
199+
bind.mu.RLock()
187200
if bind.sock6 != -1 {
188201
unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
189202
}
190203
if bind.sock4 != -1 {
191204
unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
192205
}
193-
bind.closing.RUnlock()
194-
bind.closing.Lock()
206+
bind.mu.RUnlock()
207+
// ...and a write lock to close the fd.
208+
// This ensures that no one else is using the fd.
209+
bind.mu.Lock()
210+
defer bind.mu.Unlock()
211+
var err1, err2 error
195212
if bind.sock6 != -1 {
196213
err1 = unix.Close(bind.sock6)
197214
bind.sock6 = -1
@@ -200,54 +217,36 @@ func (bind *LinuxSocketBind) Close() error {
200217
err2 = unix.Close(bind.sock4)
201218
bind.sock4 = -1
202219
}
203-
bind.closing.Unlock()
204220

205221
if err1 != nil {
206222
return err1
207223
}
208224
return err2
209225
}
210226

211-
func (bind *LinuxSocketBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
212-
bind.closing.RLock()
213-
defer bind.closing.RUnlock()
214-
215-
var end LinuxSocketEndpoint
216-
if bind.sock6 == -1 {
217-
return 0, nil, net.ErrClosed
227+
func makeReceiveIPv6(sock int) ReceiveFunc {
228+
return func(buff []byte) (int, Endpoint, error) {
229+
var end LinuxSocketEndpoint
230+
n, err := receive6(sock, buff, &end)
231+
return n, &end, err
218232
}
219-
n, err := receive6(
220-
bind.sock6,
221-
buff,
222-
&end,
223-
)
224-
return n, &end, err
225233
}
226234

227-
func (bind *LinuxSocketBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
228-
bind.closing.RLock()
229-
defer bind.closing.RUnlock()
230-
231-
var end LinuxSocketEndpoint
232-
if bind.sock4 == -1 {
233-
return 0, nil, net.ErrClosed
235+
func makeReceiveIPv4(sock int) ReceiveFunc {
236+
return func(buff []byte) (int, Endpoint, error) {
237+
var end LinuxSocketEndpoint
238+
n, err := receive4(sock, buff, &end)
239+
return n, &end, err
234240
}
235-
n, err := receive4(
236-
bind.sock4,
237-
buff,
238-
&end,
239-
)
240-
return n, &end, err
241241
}
242242

243243
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
244-
bind.closing.RLock()
245-
defer bind.closing.RUnlock()
246-
247244
nend, ok := end.(*LinuxSocketEndpoint)
248245
if !ok {
249246
return ErrWrongEndpointType
250247
}
248+
bind.mu.RLock()
249+
defer bind.mu.RUnlock()
251250
if !nend.isV6 {
252251
if bind.sock4 == -1 {
253252
return net.ErrClosed

conn/bind_std.go

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package conn
88
import (
99
"errors"
1010
"net"
11+
"sync"
1112
"syscall"
1213
)
1314

@@ -16,6 +17,7 @@ import (
1617
// It uses the Go's net package to implement networking.
1718
// See LinuxSocketBind for a proper implementation on the Linux platform.
1819
type StdNetBind struct {
20+
mu sync.Mutex // protects following fields
1921
ipv4 *net.UDPConn
2022
ipv6 *net.UDPConn
2123
blackhole4 bool
@@ -81,12 +83,15 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
8183
return conn, uaddr.Port, nil
8284
}
8385

84-
func (bind *StdNetBind) Open(uport uint16) (uint16, error) {
86+
func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
87+
bind.mu.Lock()
88+
defer bind.mu.Unlock()
89+
8590
var err error
8691
var tries int
8792

8893
if bind.ipv4 != nil || bind.ipv6 != nil {
89-
return 0, ErrBindAlreadyOpen
94+
return nil, 0, ErrBindAlreadyOpen
9095
}
9196

9297
// Attempt to open ipv4 and ipv6 listeners on the same port.
@@ -97,7 +102,7 @@ again:
97102

98103
ipv4, port, err = listenNet("udp4", port)
99104
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
100-
return 0, err
105+
return nil, 0, err
101106
}
102107

103108
// Listen on the same port as we're using for ipv4.
@@ -109,17 +114,27 @@ again:
109114
}
110115
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
111116
ipv4.Close()
112-
return 0, err
117+
return nil, 0, err
113118
}
114-
if ipv4 == nil && ipv6 == nil {
115-
return 0, syscall.EAFNOSUPPORT
119+
var fns []ReceiveFunc
120+
if ipv4 != nil {
121+
fns = append(fns, makeReceiveFunc(ipv4, true))
122+
bind.ipv4 = ipv4
116123
}
117-
bind.ipv4 = ipv4
118-
bind.ipv6 = ipv6
119-
return uint16(port), nil
124+
if ipv6 != nil {
125+
fns = append(fns, makeReceiveFunc(ipv6, false))
126+
bind.ipv6 = ipv6
127+
}
128+
if len(fns) == 0 {
129+
return nil, 0, syscall.EAFNOSUPPORT
130+
}
131+
return fns, uint16(port), nil
120132
}
121133

122134
func (bind *StdNetBind) Close() error {
135+
bind.mu.Lock()
136+
defer bind.mu.Unlock()
137+
123138
var err1, err2 error
124139
if bind.ipv4 != nil {
125140
err1 = bind.ipv4.Close()
@@ -137,23 +152,14 @@ func (bind *StdNetBind) Close() error {
137152
return err2
138153
}
139154

140-
func (bind *StdNetBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
141-
if bind.ipv4 == nil {
142-
return 0, nil, syscall.EAFNOSUPPORT
155+
func makeReceiveFunc(conn *net.UDPConn, isIPv4 bool) ReceiveFunc {
156+
return func(buff []byte) (int, Endpoint, error) {
157+
n, endpoint, err := conn.ReadFromUDP(buff)
158+
if isIPv4 && endpoint != nil {
159+
endpoint.IP = endpoint.IP.To4()
160+
}
161+
return n, (*StdNetEndpoint)(endpoint), err
143162
}
144-
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
145-
if endpoint != nil {
146-
endpoint.IP = endpoint.IP.To4()
147-
}
148-
return n, (*StdNetEndpoint)(endpoint), err
149-
}
150-
151-
func (bind *StdNetBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
152-
if bind.ipv6 == nil {
153-
return 0, nil, syscall.EAFNOSUPPORT
154-
}
155-
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
156-
return n, (*StdNetEndpoint)(endpoint), err
157163
}
158164

159165
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
@@ -162,15 +168,16 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
162168
if !ok {
163169
return ErrWrongEndpointType
164170
}
165-
var conn *net.UDPConn
166-
var blackhole bool
167-
if nend.IP.To4() != nil {
168-
blackhole = bind.blackhole4
169-
conn = bind.ipv4
170-
} else {
171+
172+
bind.mu.Lock()
173+
blackhole := bind.blackhole4
174+
conn := bind.ipv4
175+
if nend.IP.To4() == nil {
171176
blackhole = bind.blackhole6
172177
conn = bind.ipv6
173178
}
179+
bind.mu.Unlock()
180+
174181
if blackhole {
175182
return nil
176183
}

0 commit comments

Comments
 (0)