Skip to content

Commit 358d5a8

Browse files
committed
Add support for set operations with UnionSet:
* difference, intersection, symmetric_difference, union
1 parent f0ba265 commit 358d5a8

File tree

4 files changed

+111
-0
lines changed

4 files changed

+111
-0
lines changed

rust/rust_fst.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ SetLevStream* fst_set_levsearch(Set*, Levenshtein*);
6565
SetRegexStream* fst_set_regexsearch(Set*, Regex*);
6666
SetOpBuilder* fst_set_make_opbuilder(Set*);
6767
SetOpBuilder* fst_set_make_opbuilder_streambuilder(SetStreamBuilder*);
68+
SetOpBuilder* fst_set_make_opbuilder_union(SetUnion*);
6869
void fst_set_free(Set*);
6970

7071
char* fst_set_stream_next(SetStream*);
@@ -78,6 +79,7 @@ void fst_set_regexstream_free(SetRegexStream*);
7879

7980
void fst_set_opbuilder_push(SetOpBuilder*, Set*);
8081
void fst_set_opbuilder_push_streambuilder(SetOpBuilder*, SetStreamBuilder*);
82+
void fst_set_opbuilder_push_union(SetOpBuilder*, SetUnion*);
8183
void fst_set_opbuilder_free(SetOpBuilder*);
8284
SetUnion* fst_set_opbuilder_union(SetOpBuilder*);
8385
SetIntersection* fst_set_opbuilder_intersection(SetOpBuilder*);

rust/src/set.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,14 @@ pub extern "C" fn fst_set_make_opbuilder_streambuilder(ptr: *mut set::StreamBuil
154154
to_raw_ptr(ob)
155155
}
156156

157+
#[no_mangle]
158+
pub extern "C" fn fst_set_make_opbuilder_union(ptr: *mut set::Union<'static>) -> *mut set::OpBuilder<'static> {
159+
let union = val_from_ptr!(ptr);
160+
let mut ob = set::OpBuilder::new();
161+
ob.push(union.into_stream());
162+
to_raw_ptr(ob)
163+
}
164+
157165
#[no_mangle]
158166
pub extern "C" fn fst_set_opbuilder_push(ptr: *mut set::OpBuilder, set_ptr: *mut Set) {
159167
let set = ref_from_ptr!(set_ptr);
@@ -168,6 +176,13 @@ pub extern "C" fn fst_set_opbuilder_push_streambuilder(ptr: *mut set::OpBuilder<
168176
ob.push(sb.into_stream());
169177
}
170178

179+
#[no_mangle]
180+
pub extern "C" fn fst_set_opbuilder_push_union(ptr: *mut set::OpBuilder<'static>, union_ptr: *mut set::Union<'static>) {
181+
let union = val_from_ptr!(union_ptr);
182+
let ob = mutref_from_ptr!(ptr);
183+
ob.push(union.into_stream());
184+
}
185+
171186
#[no_mangle]
172187
pub extern "C" fn fst_set_opbuilder_union(ptr: *mut set::OpBuilder)
173188
-> *mut set::Union {

rust_fst/set.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,20 @@ def get_set(self):
5959
class OpBuilderInputType(Enum):
6060
SET = 1
6161
STREAM_BUILDER = 2
62+
UNION = 3
6263

6364

6465
class OpBuilder(object):
6566

6667
_BUILDERS = {
6768
OpBuilderInputType.SET: lib.fst_set_make_opbuilder,
6869
OpBuilderInputType.STREAM_BUILDER: lib.fst_set_make_opbuilder_streambuilder,
70+
OpBuilderInputType.UNION: lib.fst_set_make_opbuilder_union,
6971
}
7072
_PUSHERS = {
7173
OpBuilderInputType.SET: lib.fst_set_opbuilder_push,
7274
OpBuilderInputType.STREAM_BUILDER: lib.fst_set_opbuilder_push_streambuilder,
75+
OpBuilderInputType.UNION: lib.fst_set_opbuilder_push_union,
7376
}
7477

7578
@classmethod
@@ -452,3 +455,66 @@ def __iter__(self):
452455
for fst in self.sets[1:]:
453456
opbuilder.push(fst._ptr)
454457
return opbuilder.union()
458+
459+
def _make_opbuilder(self, *others):
460+
others = list(others)
461+
if len(self.sets) <= 1:
462+
raise ValueError(
463+
"Must have more than one set to operate on.")
464+
if not others:
465+
raise ValueError(
466+
"Must have at least one set to compare against.")
467+
our_opbuilder = OpBuilder(self.sets[0]._ptr,
468+
input_type=OpBuilderInputType.SET)
469+
for fst in self.sets[1:]:
470+
our_opbuilder.push(fst._ptr)
471+
our_stream = lib.fst_set_opbuilder_union(our_opbuilder._ptr)
472+
473+
their_opbuilder = OpBuilder(others.pop()._ptr,
474+
input_type=OpBuilderInputType.SET)
475+
for fst in others:
476+
their_opbuilder.push(fst._ptr)
477+
their_stream = lib.fst_set_opbuilder_union(their_opbuilder._ptr)
478+
479+
opbuilder = OpBuilder(our_stream, input_type=OpBuilderInputType.UNION)
480+
opbuilder.push(their_stream)
481+
return opbuilder
482+
483+
def difference(self, *others):
484+
""" Get an iterator over the keys in the difference of this set and
485+
others.
486+
487+
:param others: List of :py:class:`Set` objects
488+
:returns: Iterator over all keys that exists in this set, but in
489+
none of the other sets, in lexicographical order
490+
"""
491+
return self._make_opbuilder(*others).difference()
492+
493+
def intersection(self, *others):
494+
""" Get an iterator over the keys in the intersection of this set and
495+
others.
496+
497+
:param others: List of :py:class:`Set` objects
498+
:returns: Iterator over all keys that exists in all of the passed
499+
sets in lexicographical order
500+
"""
501+
return self._make_opbuilder(*others).intersection()
502+
503+
def symmetric_difference(self, *others):
504+
""" Get an iterator over the keys in the symmetric difference of this
505+
set and others.
506+
507+
:param others: List of :py:class:`Set` objects
508+
:returns: Iterator over all keys that exists in only one of the
509+
sets in lexicographical order
510+
"""
511+
return self._make_opbuilder(*others).symmetric_difference()
512+
513+
def union(self, *others):
514+
""" Get an iterator over the keys in the union of this set and others.
515+
516+
:param others: List of :py:class:`Set` objects
517+
:returns: Iterator over all keys in all sets in lexicographical
518+
order
519+
"""
520+
return self._make_opbuilder(*others).union()

tests/test_set.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,20 @@ def test_unionset_contains(fst_unionset):
166166
assert key in fst_unionset
167167

168168

169+
def test_unionset_difference():
170+
a = Set.from_iter(["bar", "foo"])
171+
b = Set.from_iter(["baz", "foo"])
172+
c = Set.from_iter(["bonk", "foo"])
173+
assert list(UnionSet(a, b).difference(c)) == ["bar", "baz"]
174+
175+
176+
def test_unionset_intersection():
177+
a = Set.from_iter(["bar", "foo"])
178+
b = Set.from_iter(["baz", "foo"])
179+
c = Set.from_iter(["bonk", "foo"])
180+
assert list(UnionSet(a, b).intersection(c)) == ["foo"]
181+
182+
169183
def test_unionset_iter(fst_unionset):
170184
stored_keys = list(fst_unionset)
171185
assert stored_keys == sorted(set(TEST_KEYS+TEST_KEYS2))
@@ -179,3 +193,17 @@ def test_unionset_range(fst_unionset):
179193
fst_unionset['c':'a']
180194
with pytest.raises(ValueError):
181195
fst_unionset['c']
196+
197+
198+
def test_unionset_symmetric_difference():
199+
a = Set.from_iter(["bar", "foo"])
200+
b = Set.from_iter(["baz", "foo"])
201+
c = Set.from_iter(["bonk", "foo"])
202+
assert list(UnionSet(a, b).symmetric_difference(c)) == ["bar", "baz", "bonk"]
203+
204+
205+
def test_unionset_union():
206+
a = Set.from_iter(["bar", "foo"])
207+
b = Set.from_iter(["baz", "foo"])
208+
c = Set.from_iter(["bonk", "foo"])
209+
assert list(UnionSet(a, b).union(c)) == ["bar", "baz", "bonk", "foo"]

0 commit comments

Comments
 (0)