diff --git a/hashmail_server.go b/hashmail_server.go index 0c03276..bbb3e46 100644 --- a/hashmail_server.go +++ b/hashmail_server.go @@ -39,6 +39,12 @@ const ( // reads for it to be considered for pruning. Otherwise, memory will grow // unbounded. streamTTL = 24 * time.Hour + + // streamAcquireTimeout determines how long we wait for a read/write + // stream to become available before reporting it as occupied. Context + // cancellation is still honoured immediately, so callers can shorten + // the wait. + streamAcquireTimeout = 250 * time.Millisecond ) // streamID is the identifier of a stream. @@ -317,7 +323,14 @@ func (s *stream) RequestReadStream(ctx context.Context) (*readStream, error) { case r := <-s.readStreamChan: s.status.streamTaken(true) return r, nil - default: + + case <-s.quit: + return nil, fmt.Errorf("stream shutting down") + + case <-ctx.Done(): + return nil, ctx.Err() + + case <-time.After(streamAcquireTimeout): return nil, fmt.Errorf("read stream occupied") } } @@ -332,7 +345,14 @@ func (s *stream) RequestWriteStream(ctx context.Context) (*writeStream, error) { case w := <-s.writeStreamChan: s.status.streamTaken(false) return w, nil - default: + + case <-s.quit: + return nil, fmt.Errorf("stream shutting down") + + case <-ctx.Done(): + return nil, ctx.Err() + + case <-time.After(streamAcquireTimeout): return nil, fmt.Errorf("write stream occupied") } } diff --git a/hashmail_server_test.go b/hashmail_server_test.go index d4a4848..a86aabf 100644 --- a/hashmail_server_test.go +++ b/hashmail_server_test.go @@ -188,6 +188,10 @@ func setupAperture(t *testing.T) { errChan := make(chan error) shutdown := make(chan struct{}) require.NoError(t, aperture.Start(errChan, shutdown)) + t.Cleanup(func() { + close(shutdown) + require.NoError(t, aperture.Stop()) + }) // Any error while starting? select { @@ -508,7 +512,10 @@ func recvFromStream(client hashmailrpc.HashMailClient) error { }() select { - case <-time.After(time.Second): + // Wait a little longer than the server's stream-acquire timeout so we + // only trip this path when the server truly failed to hand over the + // stream (instead of beating it to the punch). + case <-time.After(2 * streamAcquireTimeout): return fmt.Errorf("timed out waiting to receive from receive " + "stream")