Skip to content

Commit fc56243

Browse files
authored
feat: add Scanner implementation with Iter and Iter2 methods (#861)
* feat: add Scanner implementation with Iter and Iter2 methods * feat: implement Scanner with Iter and Iter2 methods; add corresponding tests in helper.go
1 parent f39a872 commit fc56243

File tree

2 files changed

+204
-0
lines changed

2 files changed

+204
-0
lines changed

helper.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package rueidis
33
import (
44
"context"
55
"errors"
6+
"iter"
67
"time"
78

89
intl "github.com/redis/rueidis/internal/cmds"
@@ -279,3 +280,43 @@ func arrayToKV(m map[string]RedisMessage, arr []RedisMessage, keys []string) map
279280
// ErrMSetNXNotSet is used in the MSetNX helper when the underlying MSETNX response is 0.
280281
// Ref: https://redis.io/commands/msetnx/
281282
var ErrMSetNXNotSet = errors.New("MSETNX: no key was set")
283+
284+
type Scanner struct {
285+
next func(cursor uint64) (ScanEntry, error)
286+
err error
287+
}
288+
289+
func NewScanner(next func(cursor uint64) (ScanEntry, error)) *Scanner {
290+
return &Scanner{next: next}
291+
}
292+
293+
func (s *Scanner) scan() iter.Seq[[]string] {
294+
return func(yield func([]string) bool) {
295+
var e ScanEntry
296+
for e, s.err = s.next(0); s.err == nil && yield(e.Elements) && e.Cursor != 0; {
297+
e, s.err = s.next(e.Cursor)
298+
}
299+
}
300+
}
301+
302+
func (s *Scanner) Iter() iter.Seq[string] {
303+
return func(yield func(string) bool) {
304+
for vs := range s.scan() {
305+
for i := 0; i < len(vs) && yield(vs[i]); i++ {
306+
}
307+
}
308+
}
309+
}
310+
311+
func (s *Scanner) Iter2() iter.Seq2[string, string] {
312+
return func(yield func(string, string) bool) {
313+
for vs := range s.scan() {
314+
for i := 0; i+1 < len(vs) && yield(vs[i], vs[i+1]); i += 2 {
315+
}
316+
}
317+
}
318+
}
319+
320+
func (s *Scanner) Err() error {
321+
return s.err
322+
}

helper_test.go

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package rueidis
22

33
import (
44
"context"
5+
"errors"
56
"reflect"
67
"strconv"
78
"testing"
@@ -1461,3 +1462,165 @@ func TestDecodeSliceOfJSON(t *testing.T) {
14611462
}
14621463
})
14631464
}
1465+
1466+
func TestScannerIter(t *testing.T) {
1467+
tests := []struct {
1468+
name string
1469+
entries []ScanEntry
1470+
err error
1471+
expected []string
1472+
wantErr bool
1473+
}{
1474+
{
1475+
name: "single page",
1476+
entries: []ScanEntry{
1477+
{Elements: []string{"key1", "key2", "key3"}, Cursor: 0},
1478+
},
1479+
expected: []string{"key1", "key2", "key3"},
1480+
},
1481+
{
1482+
name: "multiple pages",
1483+
entries: []ScanEntry{
1484+
{Elements: []string{"key1", "key2"}, Cursor: 10},
1485+
{Elements: []string{"key3", "key4"}, Cursor: 0},
1486+
},
1487+
expected: []string{"key1", "key2", "key3", "key4"},
1488+
},
1489+
{
1490+
name: "empty result",
1491+
entries: []ScanEntry{
1492+
{Elements: []string{}, Cursor: 0},
1493+
},
1494+
expected: []string{},
1495+
},
1496+
{
1497+
name: "error case",
1498+
err: errors.New("scan error"),
1499+
wantErr: true,
1500+
},
1501+
}
1502+
1503+
for _, tt := range tests {
1504+
t.Run(tt.name, func(t *testing.T) {
1505+
callCount := 0
1506+
scanner := NewScanner(func(cursor uint64) (ScanEntry, error) {
1507+
if tt.err != nil {
1508+
return ScanEntry{}, tt.err
1509+
}
1510+
if callCount >= len(tt.entries) {
1511+
return ScanEntry{}, errors.New("unexpected call")
1512+
}
1513+
entry := tt.entries[callCount]
1514+
callCount++
1515+
return entry, nil
1516+
})
1517+
1518+
var result []string
1519+
for element := range scanner.Iter() {
1520+
result = append(result, element)
1521+
}
1522+
1523+
if tt.wantErr {
1524+
if scanner.Err() == nil {
1525+
t.Error("expected error but got none")
1526+
}
1527+
} else {
1528+
if scanner.Err() != nil {
1529+
t.Errorf("unexpected error: %v", scanner.Err())
1530+
}
1531+
if (len(result) != 0 || len(tt.expected) != 0) && !reflect.DeepEqual(result, tt.expected) {
1532+
t.Errorf("got %v, want %v", result, tt.expected)
1533+
}
1534+
}
1535+
})
1536+
}
1537+
}
1538+
1539+
func TestScannerIter2(t *testing.T) {
1540+
tests := []struct {
1541+
name string
1542+
entries []ScanEntry
1543+
err error
1544+
expectedKeys []string
1545+
expectedVals []string
1546+
wantErr bool
1547+
}{
1548+
{
1549+
name: "single page pairs",
1550+
entries: []ScanEntry{
1551+
{Elements: []string{"field1", "value1", "field2", "value2"}, Cursor: 0},
1552+
},
1553+
expectedKeys: []string{"field1", "field2"},
1554+
expectedVals: []string{"value1", "value2"},
1555+
},
1556+
{
1557+
name: "multiple pages pairs",
1558+
entries: []ScanEntry{
1559+
{Elements: []string{"field1", "value1"}, Cursor: 10},
1560+
{Elements: []string{"field2", "value2", "field3", "value3"}, Cursor: 0},
1561+
},
1562+
expectedKeys: []string{"field1", "field2", "field3"},
1563+
expectedVals: []string{"value1", "value2", "value3"},
1564+
},
1565+
{
1566+
name: "odd number of elements",
1567+
entries: []ScanEntry{
1568+
{Elements: []string{"field1", "value1", "field2"}, Cursor: 0},
1569+
},
1570+
expectedKeys: []string{"field1"},
1571+
expectedVals: []string{"value1"},
1572+
},
1573+
{
1574+
name: "empty result",
1575+
entries: []ScanEntry{
1576+
{Elements: []string{}, Cursor: 0},
1577+
},
1578+
expectedKeys: []string{},
1579+
expectedVals: []string{},
1580+
},
1581+
{
1582+
name: "error case",
1583+
err: errors.New("scan error"),
1584+
wantErr: true,
1585+
},
1586+
}
1587+
1588+
for _, tt := range tests {
1589+
t.Run(tt.name, func(t *testing.T) {
1590+
callCount := 0
1591+
scanner := NewScanner(func(cursor uint64) (ScanEntry, error) {
1592+
if tt.err != nil {
1593+
return ScanEntry{}, tt.err
1594+
}
1595+
if callCount >= len(tt.entries) {
1596+
return ScanEntry{}, errors.New("unexpected call")
1597+
}
1598+
entry := tt.entries[callCount]
1599+
callCount++
1600+
return entry, nil
1601+
})
1602+
1603+
var resultKeys, resultVals []string
1604+
for key, val := range scanner.Iter2() {
1605+
resultKeys = append(resultKeys, key)
1606+
resultVals = append(resultVals, val)
1607+
}
1608+
1609+
if tt.wantErr {
1610+
if scanner.Err() == nil {
1611+
t.Error("expected error but got none")
1612+
}
1613+
} else {
1614+
if scanner.Err() != nil {
1615+
t.Errorf("unexpected error: %v", scanner.Err())
1616+
}
1617+
if (len(resultKeys) != 0 || len(tt.expectedKeys) != 0) && !reflect.DeepEqual(resultKeys, tt.expectedKeys) {
1618+
t.Errorf("keys: got %v, want %v", resultKeys, tt.expectedKeys)
1619+
}
1620+
if (len(resultVals) != 0 || len(tt.expectedVals) != 0) && !reflect.DeepEqual(resultVals, tt.expectedVals) {
1621+
t.Errorf("values: got %v, want %v", resultVals, tt.expectedVals)
1622+
}
1623+
}
1624+
})
1625+
}
1626+
}

0 commit comments

Comments
 (0)