@@ -48,17 +48,17 @@ def __init__(
48
48
id2label : Optional [Dict [int , str ]] = None ,
49
49
loss_fn = None ,
50
50
labels : Optional [Sequence [str ]] = None ,
51
- class_weights : Optional [Union [Dict [str , float ], str ]] = None ,
51
+ class_weights : Optional [Union [Dict [str , float ], str ]] = None ,
52
52
):
53
53
self .label_attr : Attributes = label_attr
54
54
self .label2id = label2id or {}
55
55
self .id2label = id2label or {}
56
56
self .labels = labels
57
- self .class_weights = class_weights
58
-
57
+ self .class_weights = class_weights
58
+
59
59
super ().__init__ (nlp , name )
60
60
self .embedding = embedding
61
-
61
+
62
62
self ._loss_fn = loss_fn
63
63
self .loss_fn = None
64
64
@@ -76,19 +76,19 @@ def _compute_class_weights(self, freq_dict: Dict[str, int]) -> torch.Tensor:
76
76
Uses inverse frequency weighting: weight = 1 / frequency
77
77
"""
78
78
total_samples = sum (freq_dict .values ())
79
-
79
+
80
80
weights = torch .zeros (len (self .label2id ))
81
-
81
+
82
82
for label , freq in freq_dict .items ():
83
83
if label in self .label2id :
84
84
weight = total_samples / (len (self .label2id ) * freq )
85
85
weights [self .label2id [label ]] = weight
86
-
86
+
87
87
return weights
88
88
89
89
def _load_class_weights_from_file(self, filepath: str) -> Dict[str, int]:
    """Load a per-label mapping from a pickle file.

    Despite the method name, the caller in this file passes the result to
    ``_compute_class_weights`` as a frequency dictionary, so the pickled
    object is expected to map label names to counts.

    Parameters
    ----------
    filepath : str
        Path of the pickle file to read.

    Returns
    -------
    Dict[str, int]
        The unpickled mapping, returned as-is.

    Raises
    ------
    TypeError
        If the unpickled object is not a mapping.
    """
    # Local import keeps this method self-contained.
    from collections.abc import Mapping

    # SECURITY: unpickling can execute arbitrary code. Only point
    # `class_weights` at files that come from a trusted source.
    with open(filepath, "rb") as f:
        loaded = pickle.load(f)

    # Fail here with a clear message rather than later, when the caller
    # tries to use `.values()` / `.items()` on an unexpected object.
    if not isinstance(loaded, Mapping):
        raise TypeError(
            f"Expected a mapping of label -> frequency in {filepath!r}, "
            f"got {type(loaded).__name__}"
        )
    return loaded
93
93
94
94
def set_extensions (self ) -> None :
@@ -116,22 +116,22 @@ def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
116
116
self .classifier = torch .nn .Linear (
117
117
self .embedding .output_size , len (self .label2id )
118
118
)
119
-
119
+
120
120
weight_tensor = None
121
121
if self .class_weights is not None :
122
122
if isinstance (self .class_weights , str ):
123
123
freq_dict = self ._load_class_weights_from_file (self .class_weights )
124
124
weight_tensor = self ._compute_class_weights (freq_dict )
125
125
elif isinstance (self .class_weights , dict ):
126
126
weight_tensor = self ._compute_class_weights (self .class_weights )
127
-
127
+
128
128
print (f"Using class weights: { weight_tensor } " )
129
-
129
+
130
130
if self ._loss_fn is not None :
131
131
self .loss_fn = self ._loss_fn
132
132
else :
133
133
self .loss_fn = torch .nn .CrossEntropyLoss (weight = weight_tensor )
134
-
134
+
135
135
super ().post_init (gold_data , exclude = exclude )
136
136
137
137
def preprocess (self , doc : Doc ) -> Dict [str , Any ]:
@@ -161,6 +161,10 @@ def collate(self, batch: Dict[str, Sequence[Any]]) -> DocClassifierBatchInput:
161
161
return batch_input
162
162
163
163
def forward (self , batch : DocClassifierBatchInput ) -> DocClassifierBatchOutput :
164
+ """
165
+ Forward pass: compute embeddings, classify, and calculate loss
166
+ if targets provided.
167
+ """
164
168
pooled = self .embedding (batch ["embedding" ])
165
169
embeddings = pooled ["embeddings" ]
166
170
@@ -187,6 +191,7 @@ def postprocess(self, docs, results, input):
187
191
return docs
188
192
189
193
def to_disk (self , path , * , exclude = set ()):
194
+ """Save classifier state to disk."""
190
195
repr_id = object .__repr__ (self )
191
196
if repr_id in exclude :
192
197
return
@@ -206,11 +211,12 @@ def to_disk(self, path, *, exclude=set()):
206
211
207
212
@classmethod
def from_disk(cls, path, **kwargs):
    """Rebuild a classifier from a directory written to disk.

    Reads ``label_attr.pkl`` under ``path``, lets the parent class build
    the object from the same ``path``, then sets the label attributes on
    it from the pickled dictionary.

    Parameters
    ----------
    path
        Directory containing ``label_attr.pkl``; must support the ``/``
        operator (presumably a ``pathlib.Path`` - confirm with callers).
    **kwargs
        Forwarded unchanged to the parent ``from_disk``.

    Returns
    -------
    The object returned by the parent ``from_disk``, with ``label_attr``,
    ``label2id`` and ``id2label`` set.
    """
    # The pickle is read first, so a missing or unreadable file raises
    # before the parent loader runs.
    # NOTE(review): pickle.load runs arbitrary code; only load trusted paths.
    with open(path / "label_attr.pkl", "rb") as handle:
        saved = pickle.load(handle)

    instance = super().from_disk(path, **kwargs)

    # (attribute name, fallback used when the key is absent from the pickle)
    for attr, fallback in (
        ("label_attr", "label"),
        ("label2id", {}),
        ("id2label", {}),
    ):
        setattr(instance, attr, saved.get(attr, fallback))
    return instance
0 commit comments