Skip to content

Commit 795852c

Browse files
committed
fix: dedup contexts in eds.span_pooler
1 parent ca645a5 commit 795852c

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

edsnlp/pipes/trainable/embeddings/span_pooler/span_pooler.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,18 @@ def preprocess(
123123
begins = []
124124
ends = []
125125

126-
contexts_to_idx = {span: i for i, span in enumerate(contexts)}
126+
contexts_to_idx = {}
127+
for ctx in contexts:
128+
contexts_to_idx[ctx] = len(contexts_to_idx)
129+
dedup_contexts = sorted(contexts_to_idx, key=contexts_to_idx.get)
127130
assert not pre_aligned or len(spans) == len(contexts), (
128131
"When `pre_aligned` is True, the number of spans and contexts must be the "
129132
"same."
130133
)
131134
aligned_contexts = (
132-
[[c] for c in contexts]
135+
[[c] for c in dedup_contexts]
133136
if pre_aligned
134-
else align_spans(contexts, spans, sort_by_overlap=True)
137+
else align_spans(dedup_contexts, spans, sort_by_overlap=True)
135138
)
136139
for i, (span, ctx) in enumerate(zip(spans, aligned_contexts)):
137140
if len(ctx) == 0 or ctx[0].start > span.start or ctx[0].end < span.end:
@@ -143,12 +146,16 @@ def preprocess(
143146
sequence_idx.append(contexts_to_idx[ctx[0]])
144147
begins.append(span.start - start)
145148
ends.append(span.end - start)
149+
assert begins[-1] >= 0, f"Begin offset is negative: {span.text}"
150+
assert ends[-1] <= len(ctx[0]), f"End offset is out of bounds: {span.text}"
146151
return {
147152
"begins": begins,
148153
"ends": ends,
149154
"sequence_idx": sequence_idx,
150-
"num_sequences": len(contexts),
151-
"embedding": self.embedding.preprocess(doc, contexts=contexts, **kwargs),
155+
"num_sequences": len(dedup_contexts),
156+
"embedding": self.embedding.preprocess(
157+
doc, contexts=dedup_contexts, **kwargs
158+
),
152159
"stats": {"spans": len(begins)},
153160
}
154161

0 commit comments

Comments
 (0)