Skip to content

Commit f0ba265

Browse files
committed
Add UnionSet class:
* Supports range queries across multiple sets:

      a = Set.from_iter(["bar", "foo"])
      b = Set.from_iter(["baz", "foo"])
      list(UnionSet(a, b)['ba':'bb'])
      ['bar', 'baz']

* Add StreamBuilder, to correspond with the Rust library's abstraction
* Update OpBuilder to support constructing operations against multiple underlying types (Set and StreamBuilder for now)
1 parent e6991cb commit f0ba265

File tree

5 files changed

+204
-18
lines changed

5 files changed

+204
-18
lines changed

rust/rust_fst.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ SetStream* fst_set_stream(Set*);
6464
SetLevStream* fst_set_levsearch(Set*, Levenshtein*);
6565
SetRegexStream* fst_set_regexsearch(Set*, Regex*);
6666
SetOpBuilder* fst_set_make_opbuilder(Set*);
67+
SetOpBuilder* fst_set_make_opbuilder_streambuilder(SetStreamBuilder*);
6768
void fst_set_free(Set*);
6869

6970
char* fst_set_stream_next(SetStream*);
@@ -76,6 +77,7 @@ char* fst_set_regexstream_next(SetRegexStream*);
7677
void fst_set_regexstream_free(SetRegexStream*);
7778

7879
void fst_set_opbuilder_push(SetOpBuilder*, Set*);
80+
void fst_set_opbuilder_push_streambuilder(SetOpBuilder*, SetStreamBuilder*);
7981
void fst_set_opbuilder_free(SetOpBuilder*);
8082
SetUnion* fst_set_opbuilder_union(SetOpBuilder*);
8183
SetIntersection* fst_set_opbuilder_intersection(SetOpBuilder*);
@@ -97,6 +99,8 @@ void fst_set_symmetricdifference_free(SetSymmetricDifference*);
9799

98100
SetStreamBuilder* fst_set_streambuilder_new(Set*);
99101
SetStreamBuilder* fst_set_streambuilder_add_ge(SetStreamBuilder*, char*);
102+
SetStreamBuilder* fst_set_streambuilder_add_gt(SetStreamBuilder*, char*);
103+
SetStreamBuilder* fst_set_streambuilder_add_le(SetStreamBuilder*, char*);
100104
SetStreamBuilder* fst_set_streambuilder_add_lt(SetStreamBuilder*, char*);
101105
SetStream* fst_set_streambuilder_finish(SetStreamBuilder*);
102106

rust/src/set.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,28 @@ pub extern "C" fn fst_set_make_opbuilder(ptr: *mut Set) -> *mut set::OpBuilder<'
146146
}
147147
make_free_fn!(fst_set_opbuilder_free, *mut set::OpBuilder);
148148

149+
#[no_mangle]
150+
pub extern "C" fn fst_set_make_opbuilder_streambuilder(ptr: *mut set::StreamBuilder<'static>) -> *mut set::OpBuilder<'static> {
151+
let sb = val_from_ptr!(ptr);
152+
let mut ob = set::OpBuilder::new();
153+
ob.push(sb.into_stream());
154+
to_raw_ptr(ob)
155+
}
156+
149157
#[no_mangle]
150158
pub extern "C" fn fst_set_opbuilder_push(ptr: *mut set::OpBuilder, set_ptr: *mut Set) {
151159
let set = ref_from_ptr!(set_ptr);
152160
let ob = mutref_from_ptr!(ptr);
153161
ob.push(set);
154162
}
155163

164+
#[no_mangle]
165+
pub extern "C" fn fst_set_opbuilder_push_streambuilder(ptr: *mut set::OpBuilder<'static>, sb_ptr: *mut set::StreamBuilder<'static>) {
166+
let sb = val_from_ptr!(sb_ptr);
167+
let ob = mutref_from_ptr!(ptr);
168+
ob.push(sb.into_stream());
169+
}
170+
156171
#[no_mangle]
157172
pub extern "C" fn fst_set_opbuilder_union(ptr: *mut set::OpBuilder)
158173
-> *mut set::Union {
@@ -205,6 +220,22 @@ pub extern "C" fn fst_set_streambuilder_add_ge(ptr: *mut set::StreamBuilder<'sta
205220
to_raw_ptr(sb.ge(cstr_to_str(c_bound)))
206221
}
207222

223+
#[no_mangle]
224+
pub extern "C" fn fst_set_streambuilder_add_gt(ptr: *mut set::StreamBuilder<'static>,
225+
c_bound: *mut libc::c_char)
226+
-> *mut set::StreamBuilder<'static> {
227+
let sb = val_from_ptr!(ptr);
228+
to_raw_ptr(sb.gt(cstr_to_str(c_bound)))
229+
}
230+
231+
#[no_mangle]
232+
pub extern "C" fn fst_set_streambuilder_add_le(ptr: *mut set::StreamBuilder<'static>,
233+
c_bound: *mut libc::c_char)
234+
-> *mut set::StreamBuilder<'static> {
235+
let sb = val_from_ptr!(ptr);
236+
to_raw_ptr(sb.le(cstr_to_str(c_bound)))
237+
}
238+
208239
#[no_mangle]
209240
pub extern "C" fn fst_set_streambuilder_add_lt(ptr: *mut set::StreamBuilder<'static>,
210241
c_bound: *mut libc::c_char)

rust_fst/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .set import Set
1+
from .set import Set, UnionSet
22
from .map import Map
33

4-
__all__ = ["Set", "Map"]
4+
__all__ = ["Set", "UnionSet", "Map"]

rust_fst/set.py

Lines changed: 134 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from contextlib import contextmanager
2+
from enum import Enum
23

34
from .common import KeyStreamIterator
45
from .lib import ffi, lib, checked_call
@@ -55,14 +56,40 @@ def get_set(self):
5556
return Set(None, _pointer=self._set_ptr)
5657

5758

59+
class OpBuilderInputType(Enum):
60+
SET = 1
61+
STREAM_BUILDER = 2
62+
63+
5864
class OpBuilder(object):
59-
def __init__(self, set_ptr):
65+
66+
_BUILDERS = {
67+
OpBuilderInputType.SET: lib.fst_set_make_opbuilder,
68+
OpBuilderInputType.STREAM_BUILDER: lib.fst_set_make_opbuilder_streambuilder,
69+
}
70+
_PUSHERS = {
71+
OpBuilderInputType.SET: lib.fst_set_opbuilder_push,
72+
OpBuilderInputType.STREAM_BUILDER: lib.fst_set_opbuilder_push_streambuilder,
73+
}
74+
75+
@classmethod
76+
def from_slice(cls, set_ptr, s):
77+
sb = StreamBuilder.from_slice(set_ptr, s)
78+
opbuilder = OpBuilder(sb._ptr,
79+
input_type=OpBuilderInputType.STREAM_BUILDER)
80+
return opbuilder
81+
82+
def __init__(self, ptr, input_type=OpBuilderInputType.SET):
83+
if input_type not in self._BUILDERS:
84+
raise ValueError(
85+
"input_type must be a member of OpBuilderInputType.")
86+
self._input_type = input_type
6087
# NOTE: No need for `ffi.gc`, since the struct will be free'd
6188
# once we call union/intersection/difference
62-
self._ptr = lib.fst_set_make_opbuilder(set_ptr)
89+
self._ptr = OpBuilder._BUILDERS[self._input_type](ptr)
6390

64-
def push(self, set_ptr):
65-
lib.fst_set_opbuilder_push(self._ptr, set_ptr)
91+
def push(self, ptr):
92+
OpBuilder._PUSHERS[self._input_type](self._ptr, ptr)
6693

6794
def union(self):
6895
stream_ptr = lib.fst_set_opbuilder_union(self._ptr)
@@ -86,6 +113,44 @@ def symmetric_difference(self):
86113
lib.fst_set_symmetricdifference_free)
87114

88115

116+
class StreamBuilder(object):
117+
118+
@classmethod
119+
def from_slice(cls, set_ptr, slice_bounds):
120+
sb = StreamBuilder(set_ptr)
121+
if slice_bounds.start:
122+
sb.ge(slice_bounds.start)
123+
if slice_bounds.stop:
124+
sb.lt(slice_bounds.stop)
125+
return sb
126+
127+
def __init__(self, set_ptr):
128+
# NOTE: No need for `ffi.gc`, since the struct will be free'd
129+
# once we call union/intersection/difference
130+
self._ptr = lib.fst_set_streambuilder_new(set_ptr)
131+
132+
def finish(self):
133+
stream_ptr = lib.fst_set_streambuilder_finish(self._ptr)
134+
return KeyStreamIterator(stream_ptr, lib.fst_set_stream_next,
135+
lib.fst_set_stream_free)
136+
137+
def ge(self, bound):
138+
c_start = ffi.new("char[]", bound.encode('utf8'))
139+
self._ptr = lib.fst_set_streambuilder_add_ge(self._ptr, c_start)
140+
141+
def gt(self, bound):
142+
c_start = ffi.new("char[]", bound.encode('utf8'))
143+
self._ptr = lib.fst_set_streambuilder_add_gt(self._ptr, c_start)
144+
145+
def le(self, bound):
146+
c_end = ffi.new("char[]", bound.encode('utf8'))
147+
self._ptr = lib.fst_set_streambuilder_add_le(self._ptr, c_end)
148+
149+
def lt(self, bound):
150+
c_end = ffi.new("char[]", bound.encode('utf8'))
151+
self._ptr = lib.fst_set_streambuilder_add_lt(self._ptr, c_end)
152+
153+
89154
class Set(object):
90155
""" An immutable ordered string set backed by a finite state transducer.
91156
@@ -203,19 +268,11 @@ def __getitem__(self, s):
203268
if s.start and s.stop and s.start > s.stop:
204269
raise ValueError(
205270
"Start key must be lexicographically smaller than stop.")
206-
sb_ptr = lib.fst_set_streambuilder_new(self._ptr)
207-
if s.start:
208-
c_start = ffi.new("char[]", s.start.encode('utf8'))
209-
sb_ptr = lib.fst_set_streambuilder_add_ge(sb_ptr, c_start)
210-
if s.stop:
211-
c_stop = ffi.new("char[]", s.stop.encode('utf8'))
212-
sb_ptr = lib.fst_set_streambuilder_add_lt(sb_ptr, c_stop)
213-
stream_ptr = lib.fst_set_streambuilder_finish(sb_ptr)
214-
return KeyStreamIterator(stream_ptr, lib.fst_set_stream_next,
215-
lib.fst_set_stream_free)
271+
sb = StreamBuilder.from_slice(self._ptr, s)
272+
return sb.finish()
216273

217274
def _make_opbuilder(self, *others):
218-
opbuilder = OpBuilder(self._ptr)
275+
opbuilder = OpBuilder(self._ptr, input_type=OpBuilderInputType.SET)
219276
for oth in others:
220277
opbuilder.push(oth._ptr)
221278
return opbuilder
@@ -333,3 +390,65 @@ def search(self, term, max_dist):
333390
return KeyStreamIterator(stream_ptr, lib.fst_set_levstream_next,
334391
lib.fst_set_levstream_free, lev_ptr,
335392
lib.fst_levenshtein_free)
393+
394+
395+
class UnionSet(object):
396+
""" A collection of Set objects that offer efficient operations across all
397+
members.
398+
"""
399+
def __init__(self, *sets):
400+
self.sets = list(sets)
401+
402+
def __contains__(self, val):
403+
""" Check if the set contains the value. """
404+
return any([
405+
lib.fst_set_contains(fst._ptr,
406+
ffi.new("char[]",
407+
val.encode('utf8')))
408+
for fst in self.sets
409+
])
410+
411+
def __getitem__(self, s):
412+
""" Get an iterator over a range of set contents.
413+
414+
Start and stop indices of the slice must be unicode strings.
415+
416+
.. important::
417+
Slicing follows the semantics for numerical indices, i.e. the
418+
`stop` value is **exclusive**. For example, given the set
419+
`s = Set.from_iter(["bar", "baz", "foo", "moo"])`, `s['b': 'f']`
420+
will only return `"bar"` and `"baz"`.
421+
422+
:param s: A slice that specifies the range of the set to retrieve
423+
:type s: :py:class:`slice`
424+
"""
425+
if not isinstance(s, slice):
426+
raise ValueError(
427+
"Value must be a string slice (e.g. `['foo':]`)")
428+
if s.start and s.stop and s.start > s.stop:
429+
raise ValueError(
430+
"Start key must be lexicographically smaller than stop.")
431+
if len(self.sets) <= 1:
432+
raise ValueError(
433+
"Must have more than one set to operate on.")
434+
435+
opbuilder = OpBuilder.from_slice(self.sets[0]._ptr, s)
436+
streams = []
437+
for fst in self.sets[1:]:
438+
sb = StreamBuilder.from_slice(fst._ptr, s)
439+
streams.append(sb)
440+
for sb in streams:
441+
opbuilder.push(sb._ptr)
442+
return opbuilder.union()
443+
444+
def __iter__(self):
445+
""" Get an iterator over all keys in all sets in lexicographical order.
446+
"""
447+
if len(self.sets) <= 1:
448+
raise ValueError(
449+
"Must have more than one set to operate on.")
450+
opbuilder = OpBuilder(self.sets[0]._ptr,
451+
input_type=OpBuilderInputType.SET)
452+
for fst in self.sets[1:]:
453+
opbuilder.push(fst._ptr)
454+
return opbuilder.union()

tests/test_set.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import pytest
33

44
import rust_fst.lib as lib
5-
from rust_fst import Set
5+
from rust_fst import Set, UnionSet
66

77

88
TEST_KEYS = [u"möö", "bar", "baz", "foo"]
9+
TEST_KEYS2 = ["bing", "baz", "bap", "foo"]
910

1011

1112
def do_build(path, keys=TEST_KEYS, sorted_=True):
@@ -21,6 +22,17 @@ def fst_set(tmpdir):
2122
return Set(fst_path)
2223

2324

25+
@pytest.fixture
26+
def fst_unionset(tmpdir):
27+
fst_path1 = str(tmpdir.join('test1.fst'))
28+
fst_path2 = str(tmpdir.join('test2.fst'))
29+
do_build(fst_path1, keys=TEST_KEYS)
30+
do_build(fst_path2, keys=TEST_KEYS2)
31+
a = Set(fst_path1)
32+
b = Set(fst_path2)
33+
return UnionSet(a, b)
34+
35+
2436
def test_build(tmpdir):
2537
fst_path = tmpdir.join('test.fst')
2638
do_build(str(fst_path))
@@ -147,3 +159,23 @@ def test_range(fst_set):
147159
fst_set['c':'a']
148160
with pytest.raises(ValueError):
149161
fst_set['c']
162+
163+
164+
def test_unionset_contains(fst_unionset):
165+
for key in TEST_KEYS+TEST_KEYS2:
166+
assert key in fst_unionset
167+
168+
169+
def test_unionset_iter(fst_unionset):
170+
stored_keys = list(fst_unionset)
171+
assert stored_keys == sorted(set(TEST_KEYS+TEST_KEYS2))
172+
173+
174+
def test_unionset_range(fst_unionset):
175+
assert list(fst_unionset['f':]) == ['foo', u'möö']
176+
assert list(fst_unionset[:'m']) == ['bap', 'bar', 'baz', 'bing', 'foo']
177+
assert list(fst_unionset['baz':'m']) == ['baz', 'bing', 'foo']
178+
with pytest.raises(ValueError):
179+
fst_unionset['c':'a']
180+
with pytest.raises(ValueError):
181+
fst_unionset['c']

0 commit comments

Comments
 (0)