Skip to content

Commit 27e063b

Browse files
Return leftovers for peek and discard
1 parent 7bf909c commit 27e063b

File tree

2 files changed

+159
-2
lines changed

2 files changed

+159
-2
lines changed

bson/buffered_value_reader.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ func (b *bufferedValueReader) ReadByte() (byte, error) {
4545
func (b *bufferedValueReader) peek(n int) ([]byte, error) {
4646
// Ensure we don't read past the end of the buffer.
4747
if int64(n)+b.offset > int64(len(b.buf)) {
48-
return nil, io.EOF
48+
return b.buf[b.offset:], io.EOF
4949
}
5050

5151
// Return the next n bytes without advancing the offset
@@ -56,7 +56,11 @@ func (b *bufferedValueReader) peek(n int) ([]byte, error) {
5656
func (b *bufferedValueReader) discard(n int) (int, error) {
5757
// Ensure we don't read past the end of the buffer.
5858
if int64(n)+b.offset > int64(len(b.buf)) {
59-
return 0, io.EOF
59+
// If we have exceeded the buffer length, discard only up to the end.
60+
left := len(b.buf) - int(b.offset)
61+
b.offset = int64(len(b.buf))
62+
63+
return left, io.EOF
6064
}
6165

6266
// Advance the read position

bson/buffered_value_reader_test.go

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
// Copyright (C) MongoDB, Inc. 2025-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package bson
8+
9+
import (
10+
"bytes"
11+
"io"
12+
"testing"
13+
14+
"go.mongodb.org/mongo-driver/v2/internal/assert"
15+
"go.mongodb.org/mongo-driver/v2/internal/require"
16+
)
17+
18+
func TestBufferedvalueReader_discard(t *testing.T) {
19+
tests := []struct {
20+
name string
21+
buf []byte
22+
n int
23+
want int
24+
wantOffset int64
25+
wantErr error
26+
}{
27+
{
28+
name: "nothing",
29+
buf: bytes.Repeat([]byte("a"), 1024),
30+
n: 0,
31+
want: 0,
32+
wantOffset: 0,
33+
wantErr: nil,
34+
},
35+
{
36+
name: "amount less than buffer size",
37+
buf: bytes.Repeat([]byte("a"), 1024),
38+
n: 100,
39+
want: 100,
40+
wantOffset: 100,
41+
wantErr: nil,
42+
},
43+
{
44+
name: "amount greater than buffer size",
45+
buf: bytes.Repeat([]byte("a"), 1024),
46+
n: 10000,
47+
want: 1024,
48+
wantOffset: 1024,
49+
wantErr: io.EOF,
50+
},
51+
{
52+
name: "exact buffer size",
53+
buf: bytes.Repeat([]byte("a"), 1024),
54+
n: 1024,
55+
want: 1024,
56+
wantOffset: 1024,
57+
wantErr: nil,
58+
},
59+
{
60+
name: "from empty buffer",
61+
buf: []byte{},
62+
n: 10,
63+
want: 0,
64+
wantOffset: 0,
65+
wantErr: io.EOF,
66+
},
67+
}
68+
69+
for _, tt := range tests {
70+
t.Run(tt.name, func(t *testing.T) {
71+
reader := &bufferedValueReader{buf: tt.buf, offset: 0}
72+
n, err := reader.discard(tt.n)
73+
if tt.wantErr != nil {
74+
assert.ErrorIs(t, err, tt.wantErr, "Expected error %v, got %v", tt.wantErr, err)
75+
} else {
76+
require.NoError(t, err, "Expected no error when discarding %d bytes", tt.n)
77+
}
78+
79+
assert.Equal(t, tt.want, n, "Expected to discard %d bytes, got %d", tt.want, n)
80+
assert.Equal(t, tt.wantOffset, reader.offset, "Expected offset to be %d, got %d", tt.wantOffset, reader.offset)
81+
})
82+
}
83+
}
84+
85+
func TestBufferedvalueReader_peek(t *testing.T) {
86+
tests := []struct {
87+
name string
88+
buf []byte
89+
n int
90+
offset int64
91+
want []byte
92+
wantErr error
93+
}{
94+
{
95+
name: "nothing",
96+
buf: bytes.Repeat([]byte("a"), 1024),
97+
n: 0,
98+
want: []byte{},
99+
wantErr: nil,
100+
},
101+
{
102+
name: "amount less than buffer size",
103+
buf: bytes.Repeat([]byte("a"), 1024),
104+
n: 100,
105+
want: bytes.Repeat([]byte("a"), 100),
106+
wantErr: nil,
107+
},
108+
{
109+
name: "amount greater than buffer size",
110+
buf: bytes.Repeat([]byte("a"), 1024),
111+
n: 10000,
112+
want: bytes.Repeat([]byte("a"), 1024),
113+
wantErr: io.EOF,
114+
},
115+
{
116+
name: "exact buffer size",
117+
buf: bytes.Repeat([]byte("a"), 1024),
118+
n: 1024,
119+
want: bytes.Repeat([]byte("a"), 1024),
120+
wantErr: nil,
121+
},
122+
{
123+
name: "from empty buffer",
124+
buf: []byte{},
125+
n: 10,
126+
want: []byte{},
127+
wantErr: io.EOF,
128+
},
129+
{
130+
name: "peek with offset",
131+
buf: append(bytes.Repeat([]byte("a"), 100), bytes.Repeat([]byte("b"), 100)...),
132+
offset: 100,
133+
n: 100,
134+
want: bytes.Repeat([]byte("b"), 100),
135+
wantErr: nil,
136+
},
137+
}
138+
139+
for _, tt := range tests {
140+
t.Run(tt.name, func(t *testing.T) {
141+
reader := &bufferedValueReader{buf: tt.buf, offset: tt.offset}
142+
n, err := reader.peek(tt.n)
143+
if tt.wantErr != nil {
144+
assert.ErrorIs(t, err, tt.wantErr, "Expected error %v, got %v", tt.wantErr, err)
145+
} else {
146+
require.NoError(t, err, "Expected no error when peeking %d bytes", tt.n)
147+
}
148+
149+
assert.Equal(t, tt.want, n, "Expected to peek %d bytes, got %d", len(tt.want), len(n))
150+
assert.Equal(t, tt.offset, reader.offset, "Expected offset to be %d, got %d", tt.offset, reader.offset)
151+
})
152+
}
153+
}

0 commit comments

Comments
 (0)