
Commit be824e1

move classif head model logic to doc_classifier instead of doc_pooler

1 parent 780c584 commit be824e1
2 files changed: +57 −29 lines changed

edsnlp/pipes/trainable/doc_classifier/doc_classifier.py
43 additions & 11 deletions

@@ -3,8 +3,9 @@
 from typing import Any, Dict, Iterable, Optional, Sequence, Set, Union
 
 import torch
+import torch.nn as nn
 from spacy.tokens import Doc
-from typing_extensions import NotRequired, TypedDict
+from typing_extensions import Literal, NotRequired, TypedDict
 
 from edsnlp.core.pipeline import PipelineProtocol
 from edsnlp.core.torch_component import BatchInput, TorchComponent
@@ -36,6 +37,8 @@ class TrainableDocClassifier(
     TorchComponent[DocClassifierBatchOutput, DocClassifierBatchInput],
     BaseComponent,
 ):
+    """A trainable document classifier that uses embeddings to classify documents."""
+
     def __init__(
         self,
         nlp: Optional[PipelineProtocol] = None,
@@ -49,12 +52,21 @@ def __init__(
         loss_fn=None,
         labels: Optional[Sequence[str]] = None,
         class_weights: Optional[Union[Dict[str, float], str]] = None,
+        hidden_size: Optional[int] = None,
+        activation_mode: Literal["relu", "gelu", "silu"] = "relu",
+        dropout_rate: Optional[float] = 0.0,
+        layer_norm: Optional[bool] = False,
     ):
+        self.num_classes = num_classes
         self.label_attr: Attributes = label_attr
         self.label2id = label2id or {}
         self.id2label = id2label or {}
         self.labels = labels
         self.class_weights = class_weights
+        self.hidden_size = hidden_size
+        self.activation_mode = activation_mode
+        self.dropout_rate = dropout_rate
+        self.layer_norm = layer_norm
 
         super().__init__(nlp, name)
         self.embedding = embedding
@@ -66,9 +78,23 @@ def __init__(
             raise ValueError(
                 "The embedding component must have an 'output_size' attribute."
             )
-        embedding_size = self.embedding.output_size
-        if num_classes:
-            self.classifier = torch.nn.Linear(embedding_size, num_classes)
+        self.embedding_size = self.embedding.output_size
+        if self.num_classes:
+            self.build_classifier()
+
+    def build_classifier(self):
+        """Build classification head"""
+        if self.hidden_size:
+            self.hidden_layer = torch.nn.Linear(self.embedding_size, self.hidden_size)
+            self.activation = {"relu": nn.ReLU(), "gelu": nn.GELU(), "silu": nn.SiLU()}[
+                self.activation_mode
+            ]
+            if self.layer_norm:
+                self.norm = nn.LayerNorm(self.hidden_size)
+            self.dropout = nn.Dropout(self.dropout_rate)
+            self.classifier = torch.nn.Linear(self.hidden_size, self.num_classes)
+        else:
+            self.classifier = torch.nn.Linear(self.embedding_size, self.num_classes)
 
     def _compute_class_weights(self, freq_dict: Dict[str, int]) -> torch.Tensor:
         """
@@ -112,10 +138,9 @@ def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
             for i, label in enumerate(labels):
                 self.label2id[label] = i
                 self.id2label[i] = label
-            print("num classes:", len(self.label2id))
-            self.classifier = torch.nn.Linear(
-                self.embedding.output_size, len(self.label2id)
-            )
+            self.num_classes = len(self.label2id)
+            print("num classes:", self.num_classes)
+            self.build_classifier()
 
         weight_tensor = None
         if self.class_weights is not None:
@@ -138,6 +163,7 @@ def preprocess(self, doc: Doc) -> Dict[str, Any]:
         return {"embedding": self.embedding.preprocess(doc)}
 
     def preprocess_supervised(self, doc: Doc) -> Dict[str, Any]:
+        """Preprocess document with target labels for training."""
        preps = self.preprocess(doc)
        label = getattr(doc._, self.label_attr, None)
        if label is None:
@@ -166,9 +192,14 @@ def forward(self, batch: DocClassifierBatchInput) -> DocClassifierBatchOutput:
             if targets provided.
         """
         pooled = self.embedding(batch["embedding"])
-        embeddings = pooled["embeddings"]
-
-        logits = self.classifier(embeddings)
+        x = pooled["embeddings"]
+        if self.hidden_size:
+            x = self.hidden_layer(x)
+            x = self.activation(x)
+            if self.layer_norm:
+                x = self.norm(x)
+            x = self.dropout(x)
+        logits = self.classifier(x)
 
         output: DocClassifierBatchOutput = {}
         if "targets" in batch:
@@ -181,6 +212,7 @@ def forward(self, batch: DocClassifierBatchInput) -> DocClassifierBatchOutput:
         return output
 
     def postprocess(self, docs, results, input):
+        """Postprocess predictions by assigning labels to documents."""
         labels = results["labels"]
         if isinstance(labels, torch.Tensor):
             labels = labels.tolist()
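
For reference, here is a minimal standalone sketch of the classification head this commit moves into doc_classifier (optional hidden projection, then activation, optional LayerNorm, dropout, and a final linear layer). The module name, the example dimensions, and the dummy input are illustrative assumptions, not part of the commit:

import torch
import torch.nn as nn


class ClassificationHead(nn.Module):
    # Hypothetical standalone module mirroring build_classifier() / forward() above.
    def __init__(self, embedding_size, num_classes, hidden_size=None,
                 activation_mode="relu", dropout_rate=0.0, layer_norm=False):
        super().__init__()
        if hidden_size:
            self.hidden_layer = nn.Linear(embedding_size, hidden_size)
            self.activation = {"relu": nn.ReLU(), "gelu": nn.GELU(), "silu": nn.SiLU()}[activation_mode]
            self.norm = nn.LayerNorm(hidden_size) if layer_norm else None
            self.dropout = nn.Dropout(dropout_rate)
            self.classifier = nn.Linear(hidden_size, num_classes)
        else:
            self.hidden_layer = None
            self.classifier = nn.Linear(embedding_size, num_classes)

    def forward(self, x):
        # x: pooled document embeddings of shape (batch, embedding_size)
        if self.hidden_layer is not None:
            x = self.hidden_layer(x)
            x = self.activation(x)
            if self.norm is not None:
                x = self.norm(x)
            x = self.dropout(x)
        return self.classifier(x)  # logits of shape (batch, num_classes)


# Illustrative usage with assumed sizes: 768-d embeddings, 5 classes
head = ClassificationHead(768, 5, hidden_size=256, activation_mode="gelu",
                          dropout_rate=0.1, layer_norm=True)
logits = head(torch.randn(4, 768))  # -> shape (4, 5)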

edsnlp/pipes/trainable/embeddings/doc_pooler/doc_pooler.py
14 additions & 18 deletions

@@ -52,20 +52,11 @@ def __init__(
         *,
         embedding: WordEmbeddingComponent,
         pooling_mode: Literal["max", "sum", "mean", "cls"] = "mean",
-        hidden_size: Optional[int] = None,
     ):
         super().__init__(nlp, name)
         self.embedding = embedding
         self.pooling_mode = pooling_mode
-        self.output_size = embedding.output_size if hidden_size is None else hidden_size
-        self.projector = (
-            torch.nn.Linear(self.embedding.output_size, hidden_size)
-            if hidden_size is not None
-            else torch.nn.Identity()
-        )
-
-    def feed_forward(self, doc_embed: torch.Tensor) -> torch.Tensor:
-        return self.projector(doc_embed)
+        self.output_size = embedding.output_size
 
     def preprocess(self, doc: Doc, **kwargs) -> Dict[str, Any]:
         embedding_out = self.embedding.preprocess(doc, **kwargs)
@@ -85,21 +76,26 @@ def collate(self, batch: Dict[str, Any]) -> DocPoolerBatchInput:
         }
 
     def forward(self, batch: DocPoolerBatchInput) -> DocPoolerBatchOutput:
-        device = next(self.parameters()).device
-
         embeds = self.embedding(batch["embedding"])["embeddings"]
         device = embeds.device
 
+        if self.pooling_mode == "cls":
+            pooled = self.embedding(batch["embedding"])["cls"].to(device)
+            return {"embeddings": pooled}
+
+        mask = embeds.mask
+        mask_expanded = mask.unsqueeze(-1)
+        masked_embeds = embeds * mask_expanded
+        sum_embeds = masked_embeds.sum(dim=1)
         if self.pooling_mode == "mean":
-            pooled = embeds.mean(dim=1)
+            valid_counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
+            pooled = sum_embeds / valid_counts
        elif self.pooling_mode == "max":
-            pooled = embeds.max(dim=1).values
+            masked_embeds = embeds.masked_fill(~mask_expanded, float("-inf"))
+            pooled, _ = masked_embeds.max(dim=1)
        elif self.pooling_mode == "sum":
-            pooled = embeds.sum(dim=1) / embeds.size(1)
-        elif self.pooling_mode == "cls":
-            pooled = self.embedding(batch["embedding"])["cls"].to(device)
+            pooled = sum_embeds
        else:
            raise ValueError(f"Unknown pooling mode: {self.pooling_mode}")
 
-        pooled = self.feed_forward(pooled)
        return {"embeddings": pooled}
