3
3
from typing import Any , Dict , Iterable , Optional , Sequence , Set , Union
4
4
5
5
import torch
6
+ import torch .nn as nn
6
7
from spacy .tokens import Doc
7
- from typing_extensions import NotRequired , TypedDict
8
+ from typing_extensions import Literal , NotRequired , TypedDict
8
9
9
10
from edsnlp .core .pipeline import PipelineProtocol
10
11
from edsnlp .core .torch_component import BatchInput , TorchComponent
@@ -36,6 +37,8 @@ class TrainableDocClassifier(
36
37
TorchComponent [DocClassifierBatchOutput , DocClassifierBatchInput ],
37
38
BaseComponent ,
38
39
):
40
+ """A trainable document classifier that uses embeddings to classify documents."""
41
+
39
42
def __init__ (
40
43
self ,
41
44
nlp : Optional [PipelineProtocol ] = None ,
@@ -49,12 +52,21 @@ def __init__(
49
52
loss_fn = None ,
50
53
labels : Optional [Sequence [str ]] = None ,
51
54
class_weights : Optional [Union [Dict [str , float ], str ]] = None ,
55
+ hidden_size : Optional [int ] = None ,
56
+ activation_mode : Literal ["relu" , "gelu" , "silu" ] = "relu" ,
57
+ dropout_rate : Optional [float ] = 0.0 ,
58
+ layer_norm : Optional [bool ] = False ,
52
59
):
60
+ self .num_classes = num_classes
53
61
self .label_attr : Attributes = label_attr
54
62
self .label2id = label2id or {}
55
63
self .id2label = id2label or {}
56
64
self .labels = labels
57
65
self .class_weights = class_weights
66
+ self .hidden_size = hidden_size
67
+ self .activation_mode = activation_mode
68
+ self .dropout_rate = dropout_rate
69
+ self .layer_norm = layer_norm
58
70
59
71
super ().__init__ (nlp , name )
60
72
self .embedding = embedding
@@ -66,9 +78,23 @@ def __init__(
66
78
raise ValueError (
67
79
"The embedding component must have an 'output_size' attribute."
68
80
)
69
- embedding_size = self .embedding .output_size
70
- if num_classes :
71
- self .classifier = torch .nn .Linear (embedding_size , num_classes )
81
+ self .embedding_size = self .embedding .output_size
82
+ if self .num_classes :
83
+ self .build_classifier ()
84
+
85
def build_classifier(self):
    """Build the classification head and attach its layers to ``self``.

    Reads ``self.embedding_size``, ``self.num_classes``, ``self.hidden_size``,
    ``self.activation_mode``, ``self.layer_norm`` and ``self.dropout_rate``.

    When ``self.hidden_size`` is falsy (``None`` or ``0``), only
    ``self.classifier`` is created, mapping ``embedding_size`` to
    ``num_classes``. Otherwise the following attributes are set:
    ``hidden_layer`` (``embedding_size`` -> ``hidden_size``), ``activation``,
    ``norm`` (only if ``self.layer_norm`` is truthy), ``dropout`` and
    ``classifier`` (``hidden_size`` -> ``num_classes``).

    Raises
    ------
    KeyError
        If ``self.activation_mode`` is not one of ``"relu"``, ``"gelu"`` or
        ``"silu"`` while a hidden layer is requested.
    """
    if not self.hidden_size:
        # No hidden layer: project the embeddings straight to the logits.
        self.classifier = nn.Linear(self.embedding_size, self.num_classes)
        return

    # Map names to classes so only the selected activation is instantiated.
    activation_classes = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}

    self.hidden_layer = nn.Linear(self.embedding_size, self.hidden_size)
    self.activation = activation_classes[self.activation_mode]()
    if self.layer_norm:
        self.norm = nn.LayerNorm(self.hidden_size)

    # ``dropout_rate`` is annotated as Optional, but ``nn.Dropout`` cannot
    # take ``None``; treat a missing rate as "no dropout".
    dropout_p = 0.0 if self.dropout_rate is None else self.dropout_rate
    self.dropout = nn.Dropout(dropout_p)
    self.classifier = nn.Linear(self.hidden_size, self.num_classes)
72
98
73
99
def _compute_class_weights (self , freq_dict : Dict [str , int ]) -> torch .Tensor :
74
100
"""
@@ -112,10 +138,9 @@ def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
112
138
for i , label in enumerate (labels ):
113
139
self .label2id [label ] = i
114
140
self .id2label [i ] = label
115
- print ("num classes:" , len (self .label2id ))
116
- self .classifier = torch .nn .Linear (
117
- self .embedding .output_size , len (self .label2id )
118
- )
141
+ self .num_classes = len (self .label2id )
142
+ print ("num classes:" , self .num_classes )
143
+ self .build_classifier ()
119
144
120
145
weight_tensor = None
121
146
if self .class_weights is not None :
@@ -138,6 +163,7 @@ def preprocess(self, doc: Doc) -> Dict[str, Any]:
138
163
return {"embedding" : self .embedding .preprocess (doc )}
139
164
140
165
def preprocess_supervised (self , doc : Doc ) -> Dict [str , Any ]:
166
+ """Preprocess document with target labels for training."""
141
167
preps = self .preprocess (doc )
142
168
label = getattr (doc ._ , self .label_attr , None )
143
169
if label is None :
@@ -166,9 +192,14 @@ def forward(self, batch: DocClassifierBatchInput) -> DocClassifierBatchOutput:
166
192
if targets provided.
167
193
"""
168
194
pooled = self .embedding (batch ["embedding" ])
169
- embeddings = pooled ["embeddings" ]
170
-
171
- logits = self .classifier (embeddings )
195
+ x = pooled ["embeddings" ]
196
+ if self .hidden_size :
197
+ x = self .hidden_layer (x )
198
+ x = self .activation (x )
199
+ if self .layer_norm :
200
+ x = self .norm (x )
201
+ x = self .dropout (x )
202
+ logits = self .classifier (x )
172
203
173
204
output : DocClassifierBatchOutput = {}
174
205
if "targets" in batch :
@@ -181,6 +212,7 @@ def forward(self, batch: DocClassifierBatchInput) -> DocClassifierBatchOutput:
181
212
return output
182
213
183
214
def postprocess (self , docs , results , input ):
215
+ """Postprocess predictions by assigning labels to documents."""
184
216
labels = results ["labels" ]
185
217
if isinstance (labels , torch .Tensor ):
186
218
labels = labels .tolist ()
0 commit comments