Skip to content

Commit 66d9f18

Browse files
authored
Merge pull request #621 from djs55/timeout
go: allow concurrent Write calls
2 parents 4786234 + 0b3746d commit 66d9f18

File tree

1 file changed

+42
-48
lines changed

1 file changed

+42
-48
lines changed

go/pkg/libproxy/multiplexed.go

Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,15 @@ func (c *channel) sendWindowUpdate() error {
100100
c.read.advance()
101101
seq := c.read.allowed
102102
c.m.Unlock()
103-
return c.multiplexer.send(NewWindow(c.ID, seq))
103+
return c.multiplexer.send(NewWindow(c.ID, seq), nil)
104104
}
105105

106106
func (c *channel) recvWindowUpdate(seq uint64) {
107107
c.m.Lock()
108108
c.write.allowed = seq
109-
c.c.Signal()
109+
// net.Conn says: Multiple goroutines may invoke methods on a Conn simultaneously.
110+
// Therefore there can be multiple goroutines blocked in Write, so when the window opens we should wake them all up.
111+
c.c.Broadcast()
110112
c.m.Unlock()
111113
}
112114

@@ -148,56 +150,40 @@ func (c *channel) Write(p []byte) (int, error) {
148150
c.write.current = c.write.current + uint64(toWrite)
149151

150152
// Don't block holding the metadata mutex.
151-
// Note this would allow concurrent calls to Write on the same channel
152-
// to conflict, but we regard that as user error.
153153
c.m.Unlock()
154-
155-
// need to write the header and the payload together
156-
c.multiplexer.writeMutex.Lock()
157-
f := NewData(c.ID, uint32(toWrite))
158-
c.multiplexer.appendEvent(&event{eventType: eventSend, frame: f})
159-
err1 := f.Write(c.multiplexer.connW)
160-
_, err2 := c.multiplexer.connW.Write(p[0:toWrite])
161-
err3 := c.multiplexer.connW.Flush()
162-
c.multiplexer.writeMutex.Unlock()
163-
154+
err := c.multiplexer.send(NewData(c.ID, uint32(toWrite)), p[0:toWrite])
164155
c.m.Lock()
165-
if err1 != nil {
166-
return written, err1
167-
}
168-
if err2 != nil {
169-
return written, err2
170-
}
171-
if err3 != nil {
172-
return written, err3
156+
157+
if err != nil {
158+
return written, err
173159
}
174160
p = p[toWrite:]
175161
written = written + toWrite
176162
continue
177163
}
178164

179-
// Wait for the write window to be increased (or a timeout)
180-
done := make(chan struct{})
181-
timeout := make(chan time.Time)
165+
// If the client has set a deadline then create a timer:
166+
var (
167+
timer *time.Timer
168+
timeOut bool
169+
)
182170
if !c.writeDeadline.IsZero() {
183-
go func() {
184-
time.Sleep(time.Until(c.writeDeadline))
185-
close(timeout)
186-
}()
171+
timer = time.AfterFunc(time.Until(c.writeDeadline), func() {
172+
c.m.Lock()
173+
defer c.m.Unlock()
174+
timeOut = true
175+
c.c.Broadcast()
176+
})
177+
}
178+
179+
// Wait for the write window to be increased or a timeout
180+
c.c.Wait()
181+
182+
if timer != nil {
183+
timer.Stop()
187184
}
188-
go func() {
189-
c.c.Wait()
190-
close(done)
191-
}()
192-
select {
193-
case <-timeout:
194-
// clean up the goroutine
195-
c.c.Broadcast()
196-
<-done
185+
if timeOut {
197186
return written, &errTimeout{}
198-
case <-done:
199-
// The timeout will still fire in the background
200-
continue
201187
}
202188
}
203189
}
@@ -213,7 +199,7 @@ func (c *channel) Close() error {
213199
if alreadyClosed {
214200
return nil
215201
}
216-
if err := c.multiplexer.send(NewClose(c.ID)); err != nil {
202+
if err := c.multiplexer.send(NewClose(c.ID), nil); err != nil {
217203
return err
218204
}
219205
c.m.Lock()
@@ -240,7 +226,7 @@ func (c *channel) CloseWrite() error {
240226
if alreadyShutdown {
241227
return nil
242228
}
243-
if err := c.multiplexer.send(NewShutdown(c.ID)); err != nil {
229+
if err := c.multiplexer.send(NewShutdown(c.ID), nil); err != nil {
244230
return err
245231
}
246232
c.m.Lock()
@@ -428,14 +414,22 @@ func (m *multiplexer) appendEvent(e *event) {
428414
m.events = m.events.Next()
429415
}
430416

431-
func (m *multiplexer) send(f *Frame) error {
417+
// send a frame (header) plus optional payload. If this call fails then the multiplexed connection will be desynchronised.
418+
func (m *multiplexer) send(f *Frame, payload []byte) error {
432419
m.writeMutex.Lock()
433420
defer m.writeMutex.Unlock()
421+
m.appendEvent(&event{eventType: eventSend, frame: f})
422+
434423
if err := f.Write(m.connW); err != nil {
435-
return err
424+
return fmt.Errorf("writing frame %s: %w", f, err)
436425
}
437-
m.appendEvent(&event{eventType: eventSend, frame: f})
438-
return m.connW.Flush()
426+
if n, err := m.connW.Write(payload); err != nil || n != len(payload) {
427+
return fmt.Errorf("writing frame %s payload length %d: %d, %w", f, len(payload), n, err)
428+
}
429+
if err := m.connW.Flush(); err != nil {
430+
return fmt.Errorf("flushing frame %s: %w", f, err)
431+
}
432+
return nil
439433
}
440434

441435
func (m *multiplexer) findFreeChannelID() uint32 {
@@ -485,7 +479,7 @@ func (m *multiplexer) Dial(d Destination) (MultiplexedConn, error) {
485479
m.channels[id] = channel
486480
m.metadataMutex.Unlock()
487481

488-
if err := m.send(NewOpen(id, d)); err != nil {
482+
if err := m.send(NewOpen(id, d), nil); err != nil {
489483
return nil, err
490484
}
491485
if err := channel.sendWindowUpdate(); err != nil {

0 commit comments

Comments
 (0)