
Commit 15344d2

LucasDedieu authored and committed
[WIP] add TrainableDocClassifier
1 parent 8e9ed84 commit 15344d2

File tree

13 files changed (+379, -7 lines)

edsnlp/metrics/doc_classif.py

Lines changed: 106 additions & 0 deletions
from typing import Any, Dict, Iterable, Optional, Tuple, Union

from spacy.tokens import Doc
from spacy.training import Example

from edsnlp import registry
from edsnlp.metrics import make_examples


def doc_classification_metric(
    examples: Union[Tuple[Iterable[Doc], Iterable[Doc]], Iterable[Example]],
    label_attr: str = "label",
    micro_key: str = "micro",
    filter_expr: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Scores document-level classification (accuracy, precision, recall, F1).

    Parameters
    ----------
    examples : Examples
        The examples to score, either a tuple of (golds, preds) or a list of
        spacy.training.Example objects
    label_attr : str
        The `Doc._` attribute containing the label
    micro_key : str
        The key used to store the micro-averaged results
    filter_expr : str
        The filter expression used to filter the documents

    Returns
    -------
    Dict[str, Any]
    """
    examples = make_examples(examples)
    if filter_expr is not None:
        filter_fn = eval(f"lambda doc: {filter_expr}")
        examples = [eg for eg in examples if filter_fn(eg.reference)]

    pred_labels = []
    gold_labels = []
    for eg in examples:
        pred = getattr(eg.predicted._, label_attr, None)
        gold = getattr(eg.reference._, label_attr, None)
        pred_labels.append(pred)
        gold_labels.append(gold)

    labels = set(gold_labels) | set(pred_labels)
    results = {}
    for label in labels:
        pred_set = [i for i, p in enumerate(pred_labels) if p == label]
        gold_set = [i for i, g in enumerate(gold_labels) if g == label]
        tp = len(set(pred_set) & set(gold_set))
        num_pred = len(pred_set)
        num_gold = len(gold_set)
        results[label] = {
            "f": 2 * tp / max(1, num_pred + num_gold),
            "p": 1 if tp == num_pred else (tp / num_pred) if num_pred else 0.0,
            "r": 1 if tp == num_gold else (tp / num_gold) if num_gold else 0.0,
            "tp": tp,
            "support": num_gold,
            "positives": num_pred,
        }

    tp = sum(1 for p, g in zip(pred_labels, gold_labels) if p == g)
    num_pred = len(pred_labels)
    num_gold = len(gold_labels)
    results[micro_key] = {
        "accuracy": tp / num_gold if num_gold else 0.0,
        "f": 2 * tp / max(1, num_pred + num_gold),
        "p": tp / num_pred if num_pred else 0.0,
        "r": tp / num_gold if num_gold else 0.0,
        "tp": tp,
        "support": num_gold,
        "positives": num_pred,
    }
    return results


@registry.metrics.register("eds.doc_classification")
class DocClassificationMetric:
    def __init__(
        self,
        label_attr: str = "label",
        micro_key: str = "micro",
        filter_expr: Optional[str] = None,
    ):
        self.label_attr = label_attr
        self.micro_key = micro_key
        self.filter_expr = filter_expr

    def __call__(self, *examples):
        return doc_classification_metric(
            examples,
            label_attr=self.label_attr,
            micro_key=self.micro_key,
            filter_expr=self.filter_expr,
        )


__all__ = [
    "doc_classification_metric",
    "DocClassificationMetric",
]
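
A minimal usage sketch (not part of the commit) scoring two toy documents; it assumes make_examples accepts a (golds, preds) tuple as the docstring above states, and the label values are purely illustrative.

# Illustrative only: register the "label" extension, build toy gold/pred docs, score them.
import spacy
from spacy.tokens import Doc

from edsnlp.metrics.doc_classif import doc_classification_metric

if not Doc.has_extension("label"):
    Doc.set_extension("label", default=None)

nlp = spacy.blank("fr")
golds = [nlp("doc one"), nlp("doc two")]
preds = [nlp("doc one"), nlp("doc two")]
golds[0]._.label, golds[1]._.label = "admission", "discharge"
preds[0]._.label, preds[1]._.label = "admission", "admission"

scores = doc_classification_metric((golds, preds), label_attr="label")
# scores["micro"]["accuracy"] == 0.5; scores["admission"] holds that label's p/r/f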

edsnlp/pipes/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -82,4 +82,6 @@
 from .trainable.embeddings.span_pooler.factory import create_component as span_pooler
 from .trainable.embeddings.transformer.factory import create_component as transformer
 from .trainable.embeddings.text_cnn.factory import create_component as text_cnn
+from .trainable.embeddings.doc_pooler.factory import create_component as doc_pooler
+from .trainable.doc_classifier.factory import create_component as doc_classifier
 from .misc.split import Split as split
edsnlp/pipes/trainable/doc_classifier/__init__.py

Lines changed: 1 addition & 0 deletions

from .factory import create_component
edsnlp/pipes/trainable/doc_classifier/doc_classifier.py

Lines changed: 125 additions & 0 deletions

import os
import pickle
from typing import Any, Dict, Iterable, Optional, Sequence, Set, Union

import torch
from spacy.tokens import Doc
from typing_extensions import NotRequired, TypedDict

from edsnlp.core.pipeline import PipelineProtocol
from edsnlp.core.torch_component import BatchInput, TorchComponent
from edsnlp.pipes.base import BaseComponent
from edsnlp.pipes.trainable.embeddings.typing import (
    WordContextualizerComponent,
    WordEmbeddingComponent,
)
from edsnlp.utils.bindings import Attributes

DocClassifierBatchInput = TypedDict(
    "DocClassifierBatchInput",
    {
        "embedding": BatchInput,
        "targets": NotRequired[torch.Tensor],
    },
)

DocClassifierBatchOutput = TypedDict(
    "DocClassifierBatchOutput",
    {
        "loss": Optional[torch.Tensor],
        "labels": Optional[torch.Tensor],
    },
)


class TrainableDocClassifier(
    TorchComponent[DocClassifierBatchOutput, DocClassifierBatchInput],
    BaseComponent,
):
    """
    Document-level classifier: applies a linear layer to a pooled document
    embedding and writes the predicted class index to `doc._.{label_attr}`.
    """

    def __init__(
        self,
        nlp: Optional[PipelineProtocol] = None,
        name: str = "doc_classifier",
        *,
        embedding: Union[WordEmbeddingComponent, WordContextualizerComponent],
        num_classes: int,
        label_attr: str = "label",
        loss_fn=None,
    ):
        self.label_attr: Attributes = label_attr
        super().__init__(nlp, name)
        self.embedding = embedding
        self.loss_fn = loss_fn or torch.nn.CrossEntropyLoss()

        if not hasattr(self.embedding, "output_size"):
            raise ValueError(
                "The embedding component must have an 'output_size' attribute."
            )
        embedding_size = self.embedding.output_size
        self.classifier = torch.nn.Linear(embedding_size, num_classes)

    def set_extensions(self) -> None:
        super().set_extensions()
        if not Doc.has_extension(self.label_attr):
            Doc.set_extension(self.label_attr, default={})

    def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
        super().post_init(gold_data, exclude=exclude)

    def preprocess(self, doc: Doc) -> Dict[str, Any]:
        return {"embedding": self.embedding.preprocess(doc)}

    def preprocess_supervised(self, doc: Doc) -> Dict[str, Any]:
        preps = self.preprocess(doc)
        label = getattr(doc._, self.label_attr, None)
        if label is None:
            raise ValueError(
                f"Document does not have a gold label in 'doc._.{self.label_attr}'"
            )
        return {
            **preps,
            "targets": torch.tensor(label, dtype=torch.long),
        }

    def collate(self, batch: Dict[str, Sequence[Any]]) -> DocClassifierBatchInput:
        embeddings = self.embedding.collate(batch["embedding"])
        batch_input: DocClassifierBatchInput = {"embedding": embeddings}
        if "targets" in batch:
            batch_input["targets"] = torch.stack(batch["targets"])
        return batch_input

    def forward(self, batch: DocClassifierBatchInput) -> DocClassifierBatchOutput:
        pooled = self.embedding(batch["embedding"])
        embeddings = pooled["embeddings"]

        logits = self.classifier(embeddings)

        output: DocClassifierBatchOutput = {}
        if "targets" in batch:
            loss = self.loss_fn(logits, batch["targets"])
            output["loss"] = loss
            output["labels"] = None
        else:
            output["loss"] = None
            output["labels"] = torch.argmax(logits, dim=-1)
        return output

    def postprocess(self, docs, results, input):
        labels = results["labels"]
        if isinstance(labels, torch.Tensor):
            labels = labels.tolist()
        for doc, label in zip(docs, labels):
            setattr(doc._, self.label_attr, label)
        return docs

    def to_disk(self, path, *, exclude=set()):
        repr_id = object.__repr__(self)
        if repr_id in exclude:
            return
        exclude.add(repr_id)
        os.makedirs(path, exist_ok=True)
        data_path = path / "label_attr.pkl"
        with open(data_path, "wb") as f:
            pickle.dump({"label_attr": self.label_attr}, f)
        return super().to_disk(path, exclude=exclude)
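
For reference, a minimal standalone sketch (not part of the commit) of the tensor shapes the forward pass above works with: pooled embeddings of shape (batch, embedding_size), logits of shape (batch, num_classes), and integer targets of shape (batch,), as expected by torch.nn.CrossEntropyLoss. The sizes are illustrative.

import torch

batch_size, embedding_size, num_classes = 4, 768, 3
classifier = torch.nn.Linear(embedding_size, num_classes)
loss_fn = torch.nn.CrossEntropyLoss()

pooled = torch.randn(batch_size, embedding_size)  # output of the doc pooler
targets = torch.tensor([0, 2, 1, 0])              # stacked per-doc labels (collate)
logits = classifier(pooled)                       # (batch, num_classes)
loss = loss_fn(logits, targets)                   # training branch of forward()
labels = logits.argmax(dim=-1)                    # inference branch of forward()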
edsnlp/pipes/trainable/doc_classifier/factory.py

Lines changed: 9 additions & 0 deletions

from edsnlp import registry

from .doc_classifier import TrainableDocClassifier

create_component = registry.factory.register(
    "eds.doc_classifier",
    assigns=["doc._.predicted_class"],
    deprecated=[],
)(TrainableDocClassifier)
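
A hedged end-to-end sketch (not part of the commit) of how the two new factories could be composed with the existing eds.transformer embedding; the model name and the window/stride values are placeholders.

import edsnlp
import edsnlp.pipes as eds

nlp = edsnlp.blank("eds")
nlp.add_pipe(
    eds.doc_classifier(
        embedding=eds.doc_pooler(
            embedding=eds.transformer(
                model="prajjwal1/bert-tiny",  # placeholder model
                window=128,
                stride=96,
            ),
            pooling_mode="mean",
        ),
        num_classes=3,
        label_attr="label",
    ),
    name="doc_classifier",
)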

edsnlp/pipes/trainable/embeddings/doc_pooler/__init__.py

Whitespace-only changes.
edsnlp/pipes/trainable/embeddings/doc_pooler/doc_pooler.py

Lines changed: 105 additions & 0 deletions

from typing import Any, Dict, Optional

import torch
from spacy.tokens import Doc
from typing_extensions import Literal, TypedDict

from edsnlp.core.pipeline import Pipeline
from edsnlp.core.torch_component import BatchInput
from edsnlp.pipes.base import BaseComponent
from edsnlp.pipes.trainable.embeddings.typing import WordEmbeddingComponent

DocPoolerBatchInput = TypedDict(
    "DocPoolerBatchInput",
    {
        "embedding": BatchInput,
        "mask": torch.Tensor,  # shape: (batch_size, seq_len)
        "stats": Dict[str, Any],
    },
)

DocPoolerBatchOutput = TypedDict(
    "DocPoolerBatchOutput",
    {
        "embeddings": torch.Tensor,  # shape: (batch_size, embedding_dim)
    },
)


class DocPooler(WordEmbeddingComponent, BaseComponent):
    """
    Pools word embeddings over the entire document to produce
    a single embedding per doc.

    Parameters
    ----------
    nlp : Pipeline
        The pipeline object
    name : str
        Name of the component
    embedding : WordEmbeddingComponent
        The word embedding component
    pooling_mode : Literal["max", "sum", "mean", "cls"]
        How word embeddings are aggregated into a single embedding per document.
    hidden_size : Optional[int]
        The size of the hidden layer. If None, no projection is done.
    """

    def __init__(
        self,
        nlp: Optional[Pipeline] = None,
        name: str = "document_pooler",
        *,
        embedding: WordEmbeddingComponent,
        pooling_mode: Literal["max", "sum", "mean", "cls"] = "mean",
        hidden_size: Optional[int] = None,
    ):
        super().__init__(nlp, name)
        self.embedding = embedding
        self.pooling_mode = pooling_mode
        self.output_size = embedding.output_size if hidden_size is None else hidden_size
        self.projector = (
            torch.nn.Linear(self.embedding.output_size, hidden_size)
            if hidden_size is not None
            else torch.nn.Identity()
        )

    def feed_forward(self, doc_embed: torch.Tensor) -> torch.Tensor:
        return self.projector(doc_embed)

    def preprocess(self, doc: Doc, **kwargs) -> Dict[str, Any]:
        embedding_out = self.embedding.preprocess(doc, **kwargs)
        return {
            "embedding": embedding_out,
            "stats": {"doc_length": len(doc)},
        }

    def collate(self, batch: Dict[str, Any]) -> DocPoolerBatchInput:
        embedding_batch = self.embedding.collate(batch["embedding"])
        stats = batch["stats"]
        return {
            "embedding": embedding_batch,
            "stats": {
                "doc_length": sum(stats["doc_length"])
            },  # sum(...) aggregates the per-doc counts into one count for the batch
        }

    def forward(self, batch: DocPoolerBatchInput) -> DocPoolerBatchOutput:
        embeds = self.embedding(batch["embedding"])["embeddings"]
        device = embeds.device

        if self.pooling_mode == "mean":
            pooled = embeds.mean(dim=1)
        elif self.pooling_mode == "max":
            pooled = embeds.max(dim=1).values
        elif self.pooling_mode == "sum":
            pooled = embeds.sum(dim=1)
        elif self.pooling_mode == "cls":
            pooled = self.embedding(batch["embedding"])["cls"].to(device)
        else:
            raise ValueError(f"Unknown pooling mode: {self.pooling_mode}")

        pooled = self.feed_forward(pooled)
        return {"embeddings": pooled}
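
A small self-contained sketch (not from the commit) of what the three reduction modes compute on a toy batch of word embeddings; shapes and values are illustrative.

import torch

embeds = torch.tensor([[[1.0, 4.0],
                        [2.0, 5.0],
                        [3.0, 6.0]]])   # (batch=1, seq_len=3, dim=2)

mean_pooled = embeds.mean(dim=1)        # tensor([[2., 5.]])
max_pooled = embeds.max(dim=1).values   # tensor([[3., 6.]])
sum_pooled = embeds.sum(dim=1)          # tensor([[6., 15.]])
# "cls" instead reuses the transformer's first-wordpiece embedding (see the transformer diff below)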
edsnlp/pipes/trainable/embeddings/doc_pooler/factory.py

Lines changed: 9 additions & 0 deletions

from edsnlp import registry

from .doc_pooler import DocPooler

create_component = registry.factory.register(
    "eds.doc_pooler",
    assigns=[],
    deprecated=[],
)(DocPooler)

edsnlp/pipes/trainable/embeddings/transformer/transformer.py

Lines changed: 2 additions & 1 deletion
@@ -505,7 +505,7 @@ def forward(self, batch: TransformerBatchInput) -> TransformerBatchOutput:
                 if "out of memory" in str(e) and trial_idx <= 2:
                     print(
                         f"Out of memory: tried to fit {max_windows} "
-                        f"in {free_mem / (1024 ** 3)} (try n°{trial_idx}/2)"
+                        f"in {free_mem / (1024**3)} (try n°{trial_idx}/2)"
                     )
                     torch.cuda.empty_cache()
                     self._mem_per_unit = (free_mem / max_windows) * 1.5
@@ -535,6 +535,7 @@ def forward(self, batch: TransformerBatchInput) -> TransformerBatchOutput:
         word_embeddings[batch["empty_word_indices"]] = self.empty_word_embedding
         return {
             "embeddings": word_embeddings.refold("context", "word"),
+            "cls": wordpiece_embeddings[:, 0, :],
         }

     @staticmethod
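
The new "cls" output exposes the embedding of each context's first wordpiece. A hedged sketch (not from the commit) of why index 0 is the [CLS] position for BERT-like tokenizers, which is the assumption behind wordpiece_embeddings[:, 0, :]; the model name is a placeholder.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder BERT-like model
ids = tok("Patient admitted for chest pain.")["input_ids"]
assert ids[0] == tok.cls_token_id  # position 0 holds [CLS], so [:, 0, :] selects its embedding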
