1
1
from contextlib import contextmanager
2
+ from enum import Enum
2
3
3
4
from .common import KeyStreamIterator
4
5
from .lib import ffi , lib , checked_call
@@ -55,14 +56,40 @@ def get_set(self):
55
56
return Set (None , _pointer = self ._set_ptr )
56
57
57
58
59
+ class OpBuilderInputType (Enum ):
60
+ SET = 1
61
+ STREAM_BUILDER = 2
62
+
63
+
58
64
class OpBuilder (object ):
59
- def __init__ (self , set_ptr ):
65
+
66
+ _BUILDERS = {
67
+ OpBuilderInputType .SET : lib .fst_set_make_opbuilder ,
68
+ OpBuilderInputType .STREAM_BUILDER : lib .fst_set_make_opbuilder_streambuilder ,
69
+ }
70
+ _PUSHERS = {
71
+ OpBuilderInputType .SET : lib .fst_set_opbuilder_push ,
72
+ OpBuilderInputType .STREAM_BUILDER : lib .fst_set_opbuilder_push_streambuilder ,
73
+ }
74
+
75
+ @classmethod
76
+ def from_slice (cls , set_ptr , s ):
77
+ sb = StreamBuilder .from_slice (set_ptr , s )
78
+ opbuilder = OpBuilder (sb ._ptr ,
79
+ input_type = OpBuilderInputType .STREAM_BUILDER )
80
+ return opbuilder
81
+
82
+ def __init__ (self , ptr , input_type = OpBuilderInputType .SET ):
83
+ if input_type not in self ._BUILDERS :
84
+ raise ValueError (
85
+ "input_type must be a member of OpBuilderInputType." )
86
+ self ._input_type = input_type
60
87
# NOTE: No need for `ffi.gc`, since the struct will be free'd
61
88
# once we call union/intersection/difference
62
- self ._ptr = lib . fst_set_make_opbuilder ( set_ptr )
89
+ self ._ptr = OpBuilder . _BUILDERS [ self . _input_type ]( ptr )
63
90
64
- def push (self , set_ptr ):
65
- lib . fst_set_opbuilder_push (self ._ptr , set_ptr )
91
+ def push (self , ptr ):
92
+ OpBuilder . _PUSHERS [ self . _input_type ] (self ._ptr , ptr )
66
93
67
94
def union (self ):
68
95
stream_ptr = lib .fst_set_opbuilder_union (self ._ptr )
@@ -86,6 +113,44 @@ def symmetric_difference(self):
86
113
lib .fst_set_symmetricdifference_free )
87
114
88
115
116
+ class StreamBuilder (object ):
117
+
118
+ @classmethod
119
+ def from_slice (cls , set_ptr , slice_bounds ):
120
+ sb = StreamBuilder (set_ptr )
121
+ if slice_bounds .start :
122
+ sb .ge (slice_bounds .start )
123
+ if slice_bounds .stop :
124
+ sb .lt (slice_bounds .stop )
125
+ return sb
126
+
127
+ def __init__ (self , set_ptr ):
128
+ # NOTE: No need for `ffi.gc`, since the struct will be free'd
129
+ # once we call union/intersection/difference
130
+ self ._ptr = lib .fst_set_streambuilder_new (set_ptr )
131
+
132
+ def finish (self ):
133
+ stream_ptr = lib .fst_set_streambuilder_finish (self ._ptr )
134
+ return KeyStreamIterator (stream_ptr , lib .fst_set_stream_next ,
135
+ lib .fst_set_stream_free )
136
+
137
+ def ge (self , bound ):
138
+ c_start = ffi .new ("char[]" , bound .encode ('utf8' ))
139
+ self ._ptr = lib .fst_set_streambuilder_add_ge (self ._ptr , c_start )
140
+
141
+ def gt (self , bound ):
142
+ c_start = ffi .new ("char[]" , bound .encode ('utf8' ))
143
+ self ._ptr = lib .fst_set_streambuilder_add_gt (self ._ptr , c_start )
144
+
145
+ def le (self , bound ):
146
+ c_end = ffi .new ("char[]" , bound .encode ('utf8' ))
147
+ self ._ptr = lib .fst_set_streambuilder_add_le (self ._ptr , c_end )
148
+
149
+ def lt (self , bound ):
150
+ c_end = ffi .new ("char[]" , bound .encode ('utf8' ))
151
+ self ._ptr = lib .fst_set_streambuilder_add_lt (self ._ptr , c_end )
152
+
153
+
89
154
class Set (object ):
90
155
""" An immutable ordered string set backed by a finite state transducer.
91
156
@@ -203,19 +268,11 @@ def __getitem__(self, s):
203
268
if s .start and s .stop and s .start > s .stop :
204
269
raise ValueError (
205
270
"Start key must be lexicographically smaller than stop." )
206
- sb_ptr = lib .fst_set_streambuilder_new (self ._ptr )
207
- if s .start :
208
- c_start = ffi .new ("char[]" , s .start .encode ('utf8' ))
209
- sb_ptr = lib .fst_set_streambuilder_add_ge (sb_ptr , c_start )
210
- if s .stop :
211
- c_stop = ffi .new ("char[]" , s .stop .encode ('utf8' ))
212
- sb_ptr = lib .fst_set_streambuilder_add_lt (sb_ptr , c_stop )
213
- stream_ptr = lib .fst_set_streambuilder_finish (sb_ptr )
214
- return KeyStreamIterator (stream_ptr , lib .fst_set_stream_next ,
215
- lib .fst_set_stream_free )
271
+ sb = StreamBuilder .from_slice (self ._ptr , s )
272
+ return sb .finish ()
216
273
217
274
def _make_opbuilder (self , * others ):
218
- opbuilder = OpBuilder (self ._ptr )
275
+ opbuilder = OpBuilder (self ._ptr , input_type = OpBuilderInputType . SET )
219
276
for oth in others :
220
277
opbuilder .push (oth ._ptr )
221
278
return opbuilder
@@ -333,3 +390,65 @@ def search(self, term, max_dist):
333
390
return KeyStreamIterator (stream_ptr , lib .fst_set_levstream_next ,
334
391
lib .fst_set_levstream_free , lev_ptr ,
335
392
lib .fst_levenshtein_free )
393
+
394
+
395
+ class UnionSet (object ):
396
+ """ A collection of Set objects that offer efficient operations across all
397
+ members.
398
+ """
399
+ def __init__ (self , * sets ):
400
+ self .sets = list (sets )
401
+
402
+ def __contains__ (self , val ):
403
+ """ Check if the set contains the value. """
404
+ return any ([
405
+ lib .fst_set_contains (fst ._ptr ,
406
+ ffi .new ("char[]" ,
407
+ val .encode ('utf8' )))
408
+ for fst in self .sets
409
+ ])
410
+
411
+ def __getitem__ (self , s ):
412
+ """ Get an iterator over a range of set contents.
413
+
414
+ Start and stop indices of the slice must be unicode strings.
415
+
416
+ .. important::
417
+ Slicing follows the semantics for numerical indices, i.e. the
418
+ `stop` value is **exclusive**. For example, given the set
419
+ `s = Set.from_iter(["bar", "baz", "foo", "moo"])`, `s['b': 'f']`
420
+ will only return `"bar"` and `"baz"`.
421
+
422
+ :param s: A slice that specifies the range of the set to retrieve
423
+ :type s: :py:class:`slice`
424
+ """
425
+ if not isinstance (s , slice ):
426
+ raise ValueError (
427
+ "Value must be a string slice (e.g. `['foo':]`)" )
428
+ if s .start and s .stop and s .start > s .stop :
429
+ raise ValueError (
430
+ "Start key must be lexicographically smaller than stop." )
431
+ if len (self .sets ) <= 1 :
432
+ raise ValueError (
433
+ "Must have more than one set to operate on." )
434
+
435
+ opbuilder = OpBuilder .from_slice (self .sets [0 ]._ptr , s )
436
+ streams = []
437
+ for fst in self .sets [1 :]:
438
+ sb = StreamBuilder .from_slice (fst ._ptr , s )
439
+ streams .append (sb )
440
+ for sb in streams :
441
+ opbuilder .push (sb ._ptr )
442
+ return opbuilder .union ()
443
+
444
+ def __iter__ (self ):
445
+ """ Get an iterator over all keys in all sets in lexicographical order.
446
+ """
447
+ if len (self .sets ) <= 1 :
448
+ raise ValueError (
449
+ "Must have more than one set to operate on." )
450
+ opbuilder = OpBuilder (self .sets [0 ]._ptr ,
451
+ input_type = OpBuilderInputType .SET )
452
+ for fst in self .sets [1 :]:
453
+ opbuilder .push (fst ._ptr )
454
+ return opbuilder .union ()
0 commit comments