@@ -123,15 +123,18 @@ def preprocess(
123
123
begins = []
124
124
ends = []
125
125
126
- contexts_to_idx = {span : i for i , span in enumerate (contexts )}
126
+ contexts_to_idx = {}
127
+ for ctx in contexts :
128
+ contexts_to_idx [ctx ] = len (contexts_to_idx )
129
+ dedup_contexts = sorted (contexts_to_idx , key = contexts_to_idx .get )
127
130
assert not pre_aligned or len (spans ) == len (contexts ), (
128
131
"When `pre_aligned` is True, the number of spans and contexts must be the "
129
132
"same."
130
133
)
131
134
aligned_contexts = (
132
- [[c ] for c in contexts ]
135
+ [[c ] for c in dedup_contexts ]
133
136
if pre_aligned
134
- else align_spans (contexts , spans , sort_by_overlap = True )
137
+ else align_spans (dedup_contexts , spans , sort_by_overlap = True )
135
138
)
136
139
for i , (span , ctx ) in enumerate (zip (spans , aligned_contexts )):
137
140
if len (ctx ) == 0 or ctx [0 ].start > span .start or ctx [0 ].end < span .end :
@@ -143,12 +146,16 @@ def preprocess(
143
146
sequence_idx .append (contexts_to_idx [ctx [0 ]])
144
147
begins .append (span .start - start )
145
148
ends .append (span .end - start )
149
+ assert begins [- 1 ] >= 0 , f"Begin offset is negative: { span .text } "
150
+ assert ends [- 1 ] <= len (ctx [0 ]), f"End offset is out of bounds: { span .text } "
146
151
return {
147
152
"begins" : begins ,
148
153
"ends" : ends ,
149
154
"sequence_idx" : sequence_idx ,
150
- "num_sequences" : len (contexts ),
151
- "embedding" : self .embedding .preprocess (doc , contexts = contexts , ** kwargs ),
155
+ "num_sequences" : len (dedup_contexts ),
156
+ "embedding" : self .embedding .preprocess (
157
+ doc , contexts = dedup_contexts , ** kwargs
158
+ ),
152
159
"stats" : {"spans" : len (begins )},
153
160
}
154
161
0 commit comments