Commit a426e9c

feat: new attention span_pooling mode
Parent: a764e03

File tree: 13 files changed, +741 / −119 lines

13 files changed

+741
-119
lines changed

changelog.md

Lines changed: 8 additions & 0 deletions
```diff
@@ -1,5 +1,12 @@
 # Changelog
 
+## Unreleased
+
+### Added
+
+- New `attention` pooling mode in `eds.span_pooler`
+- New `word_pooling_mode=False` in `eds.transformer` to allow returning the wordpiece embeddings directly, instead of the mean-pooled word embeddings. At the moment, this only works with `eds.span_pooler`, which can pool over wordpieces or words seamlessly.
+
 ## v0.18.0 (2025-09-02)
 
 📢 EDS-NLP will drop support for Python 3.7, 3.8 and 3.9 in the next major release (v0.19.0), in October 2025. Please upgrade to Python 3.10 or later.
@@ -13,6 +20,7 @@
 - New `eds.explode` pipe that splits one document into multiple documents, one per span yielded by its `span_getter` parameter, each new document containing exactly that single span.
 - New `Training a span classifier` tutorial, and reorganized deep-learning docs
 - `ScheduledOptimizer` now warns when a parameter selector does not match any parameter.
+- New `attention` pooling mode in `eds.span_pooler`
 
 ### Fixed
 
```
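Taken together, the two new options let a span-level component consume raw wordpiece embeddings: with `word_pooling_mode=False` the transformer skips averaging wordpieces into words, and the new `attention` mode in `eds.span_pooler` can weight the individual wordpieces within each span. A minimal usage sketch (hedged: the checkpoint, the classified attribute, and the surrounding parameters are illustrative assumptions, not part of this commit):

```python
import edsnlp
import edsnlp.pipes as eds

nlp = edsnlp.blank("eds")
nlp.add_pipe(
    eds.span_classifier(
        embedding=eds.span_pooler(
            pooling_mode="attention",  # new pooling mode from this commit
            embedding=eds.transformer(
                model="prajjwal1/bert-tiny",  # illustrative checkpoint
                window=128,
                stride=96,
                word_pooling_mode=False,  # new: emit wordpiece embeddings directly
            ),
        ),
        span_getter=["ents"],
        attributes=["_.negation"],  # illustrative qualifier attribute
    ),
    name="span_classifier",
)
```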

docs/tutorials/index.md

Lines changed: 3 additions & 3 deletions
```diff
@@ -4,6 +4,7 @@ We provide step-by-step guides to get you started. We cover the following use-ca
 
 ### Base tutorials
 
+<!-- --8<-- [start:tutorials] -->
 <!-- --8<-- [start:classic-tutorials] -->
 
 === card {: href=/tutorials/spacy101 }
@@ -85,6 +86,8 @@ We provide step-by-step guides to get you started. We cover the following use-ca
 ---
 Quickly visualize the results of your pipeline as annotations or tables.
 
+<!-- --8<-- [end:classic-tutorials] -->
+
 ### Deep learning tutorials
 
 We also provide tutorials on how to train deep-learning models with EDS-NLP. These tutorials cover the training API, hyperparameter tuning, and more.
@@ -123,8 +126,5 @@ We also provide tutorials on how to train deep-learning models with EDS-NLP. The
 ---
 Learn how to tune hyperparameters of a model with `edsnlp.tune`.
 
-
 <!-- --8<-- [end:deep-learning-tutorials] -->
-
-
 <!-- --8<-- [end:tutorials] -->
```
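For context, the `--8<--` comments are pymdownx snippets section markers: the new outer `tutorials` section wraps both tutorial groups, so other documentation pages can embed the classic tutorials, the deep-learning tutorials, or the whole list with a single include.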

edsnlp/core/torch_component.py

Lines changed: 34 additions & 6 deletions
```diff
@@ -339,7 +339,14 @@ def compute_training_metrics(
         This is useful to compute averages when doing multi-gpu training or mini-batch
         accumulation since full denominators are not known during the forward pass.
         """
-        return batch_output
+        return (
+            {
+                **batch_output,
+                "loss": batch_output["loss"] / count,
+            }
+            if "loss" in batch_output
+            else batch_output
+        )
 
     def module_forward(self, *args, **kwargs):  # pragma: no cover
         """
```
```diff
@@ -348,6 +355,31 @@ def module_forward(self, *args, **kwargs): # pragma: no cover
         """
         return torch.nn.Module.__call__(self, *args, **kwargs)
 
+    def preprocess_batch(self, docs: Sequence[Doc], supervision=False, **kwargs):
+        """
+        Convenience method to preprocess a batch of documents.
+        Features corresponding to the same path are grouped together in a list,
+        under the same key.
+
+        Parameters
+        ----------
+        docs: Sequence[Doc]
+            Batch of documents
+        supervision: bool
+            Whether to extract supervision features or not
+
+        Returns
+        -------
+        Dict[str, Sequence[Any]]
+            The batch of features
+        """
+        batch = [
+            (self.preprocess_supervised(d) if supervision else self.preprocess(d))
+            for d in docs
+        ]
+        batch = decompress_dict(list(batch_compress_dict(batch)))
+        return batch
+
     def prepare_batch(
         self,
         docs: Sequence[Doc],
```
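The net effect of `preprocess_batch` is to turn a list of per-document feature dicts into one nested dict whose leaves are lists with one entry per document. A toy illustration of that grouping (plain dicts and a hand-rolled helper standing in for the real `batch_compress_dict` / `decompress_dict` utilities):

```python
def merge(dst, src):
    """Append each leaf value of `src` under the same key path in `dst`."""
    for key, value in src.items():
        if isinstance(value, dict):
            merge(dst.setdefault(key, {}), value)
        else:
            dst.setdefault(key, []).append(value)

# One feature dict per document, as returned by preprocess(doc)
docs_features = [
    {"embedding": {"input_ids": [1, 2]}, "length": 2},
    {"embedding": {"input_ids": [3]}, "length": 1},
]

batch = {}
for features in docs_features:
    merge(batch, features)

print(batch)
# {'embedding': {'input_ids': [[1, 2], [3]]}, 'length': [2, 1]}
```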
```diff
@@ -372,11 +404,7 @@ def prepare_batch(
         -------
         Dict[str, Sequence[Any]]
         """
-        batch = [
-            (self.preprocess_supervised(doc) if supervision else self.preprocess(doc))
-            for doc in docs
-        ]
-        batch = decompress_dict(list(batch_compress_dict(batch)))
+        batch = self.preprocess_batch(docs, supervision=supervision)
         batch = self.collate(batch)
         batch = self.batch_to_device(batch, device=device)
         return batch
```
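`prepare_batch` is now a thin chain over the new method. A hedged sketch of the same three stages called by hand on a standalone embedding component (checkpoint name and window settings are illustrative assumptions):

```python
import edsnlp
import edsnlp.pipes as eds

embedding = eds.transformer(model="prajjwal1/bert-tiny", window=128, stride=96)

nlp = edsnlp.blank("eds")
docs = [nlp.make_doc("Patient admis pour douleur thoracique.")]

features = embedding.preprocess_batch(docs)             # group per-doc features
batch = embedding.collate(features)                     # build nested tensors
batch = embedding.batch_to_device(batch, device="cpu")  # move to target device
# ...equivalent to: batch = embedding.prepare_batch(docs, device="cpu")
```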
