Skip to content

Commit e745f25

Browse files
committed
Add support for regex search to UnionSet.
1 parent c82a480 commit e745f25

File tree

4 files changed

+71
-0
lines changed

4 files changed

+71
-0
lines changed

rust/rust_fst.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ SetRegexStream* fst_set_regexsearch(Set*, Regex*);
6666
SetOpBuilder* fst_set_make_opbuilder(Set*);
6767
SetOpBuilder* fst_set_make_opbuilder_streambuilder(SetStreamBuilder*);
6868
SetOpBuilder* fst_set_make_opbuilder_levstream(SetLevStream*);
69+
SetOpBuilder* fst_set_make_opbuilder_regexstream(SetRegexStream*);
6970
SetOpBuilder* fst_set_make_opbuilder_union(SetUnion*);
7071
void fst_set_free(Set*);
7172

@@ -80,6 +81,7 @@ void fst_set_regexstream_free(SetRegexStream*);
8081

8182
void fst_set_opbuilder_push(SetOpBuilder*, Set*);
8283
void fst_set_opbuilder_push_levstream(SetOpBuilder*, SetLevStream*);
84+
void fst_set_opbuilder_push_regexstream(SetOpBuilder*, SetRegexStream*);
8385
void fst_set_opbuilder_push_streambuilder(SetOpBuilder*, SetStreamBuilder*);
8486
void fst_set_opbuilder_push_union(SetOpBuilder*, SetUnion*);
8587
void fst_set_opbuilder_free(SetOpBuilder*);

rust/src/set.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,14 @@ pub extern "C" fn fst_set_make_opbuilder_levstream(ptr: *mut SetLevStream) -> *m
154154
to_raw_ptr(ob)
155155
}
156156

157+
/// C ABI entry point: creates a fresh `set::OpBuilder` whose first input is
/// the stream obtained from the given regex search stream.
///
/// Returns the new builder as a raw pointer (via `to_raw_ptr`).
#[no_mangle]
158+
pub extern "C" fn fst_set_make_opbuilder_regexstream(ptr: *mut SetRegexStream) -> *mut set::OpBuilder<'static> {
159+
// NOTE(review): `val_from_ptr!` presumably takes the value out of `ptr`, so the
// caller should not reuse or free the regex stream afterwards -- confirm against the macro.
let srs = val_from_ptr!(ptr);
160+
let mut ob = set::OpBuilder::new();
161+
// `into_stream` consumes the regex stream value; the builder receives the resulting stream.
ob.push(srs.into_stream());
162+
to_raw_ptr(ob)
163+
}
164+
157165
#[no_mangle]
158166
pub extern "C" fn fst_set_make_opbuilder_streambuilder(ptr: *mut set::StreamBuilder<'static>) -> *mut set::OpBuilder<'static> {
159167
let sb = val_from_ptr!(ptr);
@@ -184,6 +192,13 @@ pub extern "C" fn fst_set_opbuilder_push_levstream(ptr: *mut set::OpBuilder<'sta
184192
ob.push(sls.into_stream());
185193
}
186194

195+
/// C ABI entry point: pushes the stream obtained from a regex search stream
/// onto an existing `set::OpBuilder`.
///
/// `ptr` is the builder to extend; `srs_ptr` is the regex stream to add.
#[no_mangle]
196+
pub extern "C" fn fst_set_opbuilder_push_regexstream(ptr: *mut set::OpBuilder<'static>, srs_ptr: *mut SetRegexStream) {
197+
// NOTE(review): `val_from_ptr!` presumably takes the value out of `srs_ptr`, so the
// caller should not reuse or free the regex stream afterwards -- confirm against the macro.
let srs = val_from_ptr!(srs_ptr);
198+
// The builder is only borrowed mutably here; nothing in this function releases it.
let ob = mutref_from_ptr!(ptr);
199+
ob.push(srs.into_stream());
200+
}
201+
187202
#[no_mangle]
188203
pub extern "C" fn fst_set_opbuilder_push_streambuilder(ptr: *mut set::OpBuilder<'static>, sb_ptr: *mut set::StreamBuilder<'static>) {
189204
let sb = val_from_ptr!(sb_ptr);

rust_fst/set.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class OpBuilderInputType(Enum):
6161
STREAM_BUILDER = 2
6262
UNION = 3
6363
SEARCH = 4
64+
SEARCH_RE = 5
6465

6566

6667
def _build_levsearch(fst, term, max_dist):
@@ -72,19 +73,28 @@ def _build_levsearch(fst, term, max_dist):
7273
return lib.fst_set_levsearch(fst._ptr, lev_ptr)
7374

7475

76+
def _build_research(fst, pattern):
# Build a regex search stream over `fst` for `pattern`.
# The pattern is UTF-8 encoded into a C `char[]` and passed to
# `lib.fst_regex_new` together with the set's context; the resulting regex
# pointer is then handed to `lib.fst_set_regexsearch` along with the set pointer.
# NOTE(review): `checked_call` presumably raises if the regex cannot be
# created (e.g. unsupported syntax) -- confirm against its definition.
77+
re_ptr = checked_call(
78+
lib.fst_regex_new, fst._ctx,
79+
ffi.new("char[]", pattern.encode('utf8')))
80+
# Returns the raw stream pointer produced by the native library.
return lib.fst_set_regexsearch(fst._ptr, re_ptr)
81+
82+
7583
class OpBuilder(object):
7684

7785
_BUILDERS = {
7886
OpBuilderInputType.SET: lib.fst_set_make_opbuilder,
7987
OpBuilderInputType.STREAM_BUILDER: lib.fst_set_make_opbuilder_streambuilder,
8088
OpBuilderInputType.UNION: lib.fst_set_make_opbuilder_union,
8189
OpBuilderInputType.SEARCH: lib.fst_set_make_opbuilder_levstream,
90+
OpBuilderInputType.SEARCH_RE: lib.fst_set_make_opbuilder_regexstream,
8291
}
8392
_PUSHERS = {
8493
OpBuilderInputType.SET: lib.fst_set_opbuilder_push,
8594
OpBuilderInputType.STREAM_BUILDER: lib.fst_set_opbuilder_push_streambuilder,
8695
OpBuilderInputType.UNION: lib.fst_set_opbuilder_push_union,
8796
OpBuilderInputType.SEARCH: lib.fst_set_opbuilder_push_levstream,
97+
OpBuilderInputType.SEARCH_RE: lib.fst_set_opbuilder_push_regexstream,
8898
}
8999

90100
@classmethod
@@ -94,6 +104,13 @@ def from_search(cls, fst, term, max_dist):
94104
input_type=OpBuilderInputType.SEARCH)
95105
return opbuilder
96106

107+
@classmethod
108+
def from_search_re(cls, fst, pattern):
# Alternate constructor: create an OpBuilder whose initial input is a
# regex search stream over `fst` for `pattern`.
109+
stream_ptr = _build_research(fst, pattern)
110+
# SEARCH_RE selects the regex-stream builder function from `_BUILDERS`.
opbuilder = OpBuilder(stream_ptr,
111+
input_type=OpBuilderInputType.SEARCH_RE)
112+
return opbuilder
113+
97114
@classmethod
98115
def from_slice(cls, set_ptr, s):
99116
sb = StreamBuilder.from_slice(set_ptr, s)
@@ -535,6 +552,38 @@ def search(self, term, max_dist):
535552
opbuilder.push(_build_levsearch(fst, term, max_dist))
536553
return opbuilder.union()
537554

555+
def search_re(self, pattern):
556+
""" Search all sets in the union with a regular expression.
557+
558+
Note that the regular expression syntax is not Python's, but the one
559+
supported by the `regex` Rust crate, which is almost identical
560+
to the syntax of the RE2 engine.
561+
562+
For documentation of the syntax, see:
563+
http://doc.rust-lang.org/regex/regex/index.html#syntax
564+
565+
Due to limitations of the underlying FST, only a subset of this syntax
566+
is supported. Most notably absent are:
567+
568+
* Lazy quantifiers (``r'*?'``, ``r'+?'``)
569+
* Word boundaries (``r'\\b'``)
570+
* Other zero-width assertions (``r'^'``, ``r'$'``)
571+
572+
For background on these limitations, consult the documentation of
573+
the Rust crate: http://burntsushi.net/rustdoc/fst/struct.Regex.html
574+
575+
:param pattern: A regular expression
576+
:returns: An iterator over all matching keys in the set
577+
:rtype: :py:class:`KeyStreamIterator`
:raises ValueError: If the union contains one set or fewer
578+
"""
579+
if len(self.sets) <= 1:
580+
raise ValueError(
581+
"Must have more than one set to operate on.")
582+
# Seed the builder with a regex search over the first set ...
opbuilder = OpBuilder.from_search_re(self.sets[0], pattern)
583+
# ... then add one regex search stream per remaining set.
for fst in self.sets[1:]:
584+
opbuilder.push(_build_research(fst, pattern))
585+
# The union of all per-set search streams is what gets returned.
return opbuilder.union()
586+
538587
def symmetric_difference(self, *others):
539588
""" Get an iterator over the keys in the symmetric difference of this
540589
set and others.

tests/test_set.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,11 @@ def test_unionset_search(fst_unionset):
200200
assert matches == ["bap", "bar", "baz"]
201201

202202

203+
def test_unionset_search_re(fst_unionset):
# Regex search over the union fixture: the pattern r'ba.*' is expected to
# yield exactly "bap", "bar" and "baz", in that order.
204+
matches = list(fst_unionset.search_re(r'ba.*'))
205+
assert matches == ["bap", "bar", "baz"]
206+
207+
203208
def test_unionset_symmetric_difference():
204209
a = Set.from_iter(["bar", "foo"])
205210
b = Set.from_iter(["baz", "foo"])

0 commit comments

Comments
 (0)