Skip to content

Commit ef6268e

Browse files
Avoid peeking on unbuffered stream
1 parent dff53d0 commit ef6268e

File tree

3 files changed

+35
-7
lines changed

3 files changed

+35
-7
lines changed

bson/buffered_value_reader.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,17 @@ type bufferedValueReader struct {
2121

2222
var _ valueReaderByteSrc = (*bufferedValueReader)(nil)
2323

24+
// Read reads up to len(p) bytes from the in-memory buffer, advancing the offset
25+
// by the number of bytes read.
26+
func (b *bufferedValueReader) Read(p []byte) (int, error) {
27+
if b.offset >= int64(len(b.buf)) {
28+
return 0, io.EOF
29+
}
30+
n := copy(p, b.buf[b.offset:])
31+
b.offset += int64(n)
32+
return n, nil
33+
}
34+
2435
// ReadByte returns the single byte at buf[offset] and advances offset by 1.
2536
func (b *bufferedValueReader) ReadByte() (byte, error) {
2637
if b.offset >= int64(len(b.buf)) {

bson/streaming_value_reader.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ type streamingValueReader struct {
2323

2424
var _ valueReaderByteSrc = (*streamingValueReader)(nil)
2525

26+
// Read reads up to len(p) bytes from the underlying bufio.Reader, advancing
27+
// the offset by the number of bytes read.
28+
func (s *streamingValueReader) Read(p []byte) (int, error) {
29+
n, err := s.br.Read(p)
30+
s.offset += int64(n)
31+
return n, err
32+
}
33+
2634
// ReadByte returns the single byte at buf[offset] and advances offset by 1.
2735
func (s *streamingValueReader) ReadByte() (byte, error) {
2836
c, err := s.br.ReadByte()

bson/value_reader.go

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ var vrPool = sync.Pool{
3535
}
3636

3737
type valueReaderByteSrc interface {
38+
io.Reader
3839
io.ByteReader
3940

4041
// Peek returns the next n bytes without advancing the cursor. It must return
@@ -258,18 +259,26 @@ func (vr *valueReader) readBytes(length int32) ([]byte, error) {
258259
return nil, fmt.Errorf("invalid length: %d", length)
259260
}
260261

261-
// Peek the next length bytes.
262-
buf, err := vr.src.peek(int(length))
263-
if err != nil {
264-
return nil, err
262+
// If we can peek and discard the bytes, we can avoid an allocation.
263+
if buf, err := vr.src.peek(int(length)); err == nil {
264+
_, _ = vr.src.discard(int(length)) // Discard the bytes from the source.
265+
return buf, nil
265266
}
266267

267-
// Advance the cursor past those bytes.
268-
if _, err := vr.src.discard(int(length)); err != nil {
268+
data := make([]byte, length)
269+
reader, ok := vr.src.(io.Reader)
270+
if !ok {
271+
return nil, fmt.Errorf("source does not implement io.Reader: %T", vr.src)
272+
}
273+
274+
if _, err := io.ReadFull(reader, data); err != nil {
275+
if errors.Is(err, io.ErrUnexpectedEOF) {
276+
err = io.EOF // Convert io.ErrUnexpectedEOF to io.EOF for consistency.
277+
}
269278
return nil, err
270279
}
271280

272-
return buf, nil
281+
return data, nil
273282
}
274283

275284
func (vr *valueReader) typeError(t Type) error {

0 commit comments

Comments
 (0)