Skip to content

Commit c82a480

Browse files
committed
Add support for Levenshtein fuzzy search to UnionSet.
1 parent 358d5a8 commit c82a480

File tree

4 files changed

+57
-0
lines changed

4 files changed

+57
-0
lines changed

rust/rust_fst.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ SetLevStream* fst_set_levsearch(Set*, Levenshtein*);
6565
SetRegexStream* fst_set_regexsearch(Set*, Regex*);
6666
SetOpBuilder* fst_set_make_opbuilder(Set*);
6767
SetOpBuilder* fst_set_make_opbuilder_streambuilder(SetStreamBuilder*);
68+
SetOpBuilder* fst_set_make_opbuilder_levstream(SetLevStream*);
6869
SetOpBuilder* fst_set_make_opbuilder_union(SetUnion*);
6970
void fst_set_free(Set*);
7071

@@ -78,6 +79,7 @@ char* fst_set_regexstream_next(SetRegexStream*);
7879
void fst_set_regexstream_free(SetRegexStream*);
7980

8081
void fst_set_opbuilder_push(SetOpBuilder*, Set*);
82+
void fst_set_opbuilder_push_levstream(SetOpBuilder*, SetLevStream*);
8183
void fst_set_opbuilder_push_streambuilder(SetOpBuilder*, SetStreamBuilder*);
8284
void fst_set_opbuilder_push_union(SetOpBuilder*, SetUnion*);
8385
void fst_set_opbuilder_free(SetOpBuilder*);

rust/src/set.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,14 @@ pub extern "C" fn fst_set_make_opbuilder(ptr: *mut Set) -> *mut set::OpBuilder<'
146146
}
147147
make_free_fn!(fst_set_opbuilder_free, *mut set::OpBuilder);
148148

149+
#[no_mangle]
150+
pub extern "C" fn fst_set_make_opbuilder_levstream(ptr: *mut SetLevStream) -> *mut set::OpBuilder<'static> {
151+
let sls = val_from_ptr!(ptr);
152+
let mut ob = set::OpBuilder::new();
153+
ob.push(sls.into_stream());
154+
to_raw_ptr(ob)
155+
}
156+
149157
#[no_mangle]
150158
pub extern "C" fn fst_set_make_opbuilder_streambuilder(ptr: *mut set::StreamBuilder<'static>) -> *mut set::OpBuilder<'static> {
151159
let sb = val_from_ptr!(ptr);
@@ -169,6 +177,13 @@ pub extern "C" fn fst_set_opbuilder_push(ptr: *mut set::OpBuilder, set_ptr: *mut
169177
ob.push(set);
170178
}
171179

180+
#[no_mangle]
181+
pub extern "C" fn fst_set_opbuilder_push_levstream(ptr: *mut set::OpBuilder<'static>, sls_ptr: *mut SetLevStream) {
182+
let sls = val_from_ptr!(sls_ptr);
183+
let ob = mutref_from_ptr!(ptr);
184+
ob.push(sls.into_stream());
185+
}
186+
172187
#[no_mangle]
173188
pub extern "C" fn fst_set_opbuilder_push_streambuilder(ptr: *mut set::OpBuilder<'static>, sb_ptr: *mut set::StreamBuilder<'static>) {
174189
let sb = val_from_ptr!(sb_ptr);

rust_fst/set.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,16 @@ class OpBuilderInputType(Enum):
6060
SET = 1
6161
STREAM_BUILDER = 2
6262
UNION = 3
63+
SEARCH = 4
64+
65+
66+
def _build_levsearch(fst, term, max_dist):
67+
lev_ptr = checked_call(
68+
lib.fst_levenshtein_new,
69+
fst._ctx,
70+
ffi.new("char[]", term.encode('utf8')),
71+
max_dist)
72+
return lib.fst_set_levsearch(fst._ptr, lev_ptr)
6373

6474

6575
class OpBuilder(object):
@@ -68,13 +78,22 @@ class OpBuilder(object):
6878
OpBuilderInputType.SET: lib.fst_set_make_opbuilder,
6979
OpBuilderInputType.STREAM_BUILDER: lib.fst_set_make_opbuilder_streambuilder,
7080
OpBuilderInputType.UNION: lib.fst_set_make_opbuilder_union,
81+
OpBuilderInputType.SEARCH: lib.fst_set_make_opbuilder_levstream,
7182
}
7283
_PUSHERS = {
7384
OpBuilderInputType.SET: lib.fst_set_opbuilder_push,
7485
OpBuilderInputType.STREAM_BUILDER: lib.fst_set_opbuilder_push_streambuilder,
7586
OpBuilderInputType.UNION: lib.fst_set_opbuilder_push_union,
87+
OpBuilderInputType.SEARCH: lib.fst_set_opbuilder_push_levstream,
7688
}
7789

90+
@classmethod
91+
def from_search(cls, fst, term, max_dist):
92+
stream_ptr = _build_levsearch(fst, term, max_dist)
93+
opbuilder = OpBuilder(stream_ptr,
94+
input_type=OpBuilderInputType.SEARCH)
95+
return opbuilder
96+
7897
@classmethod
7998
def from_slice(cls, set_ptr, s):
8099
sb = StreamBuilder.from_slice(set_ptr, s)
@@ -500,6 +519,22 @@ def intersection(self, *others):
500519
"""
501520
return self._make_opbuilder(*others).intersection()
502521

522+
def search(self, term, max_dist):
523+
""" Search the set with a Levenshtein automaton.
524+
525+
:param term: The search term
526+
:param max_dist: The maximum edit distance for search results
527+
:returns: Iterator over matching values in the set
528+
:rtype: :py:class:`KeyStreamIterator`
529+
"""
530+
if len(self.sets) <= 1:
531+
raise ValueError(
532+
"Must have more than one set to operate on.")
533+
opbuilder = OpBuilder.from_search(self.sets[0], term, max_dist)
534+
for fst in self.sets[1:]:
535+
opbuilder.push(_build_levsearch(fst, term, max_dist))
536+
return opbuilder.union()
537+
503538
def symmetric_difference(self, *others):
504539
""" Get an iterator over the keys in the symmetric difference of this
505540
set and others.

tests/test_set.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,11 @@ def test_unionset_range(fst_unionset):
195195
fst_unionset['c']
196196

197197

198+
def test_unionset_search(fst_unionset):
199+
matches = list(fst_unionset.search("bam", 1))
200+
assert matches == ["bap", "bar", "baz"]
201+
202+
198203
def test_unionset_symmetric_difference():
199204
a = Set.from_iter(["bar", "foo"])
200205
b = Set.from_iter(["baz", "foo"])

0 commit comments

Comments
 (0)