Skip to content

Commit a938317

Browse files
committed
feat: add class_weight handling
1 parent: 2d3a2aa · commit: a938317

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

edsnlp/pipes/trainable/doc_classifier/doc_classifier.py

Lines changed: 19 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -48,17 +48,17 @@ def __init__(
4848
id2label: Optional[Dict[int, str]] = None,
4949
loss_fn=None,
5050
labels: Optional[Sequence[str]] = None,
51-
class_weights: Optional[Union[Dict[str, float], str]] = None,
51+
class_weights: Optional[Union[Dict[str, float], str]] = None,
5252
):
5353
self.label_attr: Attributes = label_attr
5454
self.label2id = label2id or {}
5555
self.id2label = id2label or {}
5656
self.labels = labels
57-
self.class_weights = class_weights
58-
57+
self.class_weights = class_weights
58+
5959
super().__init__(nlp, name)
6060
self.embedding = embedding
61-
61+
6262
self._loss_fn = loss_fn
6363
self.loss_fn = None
6464

@@ -76,19 +76,19 @@ def _compute_class_weights(self, freq_dict: Dict[str, int]) -> torch.Tensor:
7676
Uses inverse frequency weighting: weight = 1 / frequency
7777
"""
7878
total_samples = sum(freq_dict.values())
79-
79+
8080
weights = torch.zeros(len(self.label2id))
81-
81+
8282
for label, freq in freq_dict.items():
8383
if label in self.label2id:
8484
weight = total_samples / (len(self.label2id) * freq)
8585
weights[self.label2id[label]] = weight
86-
86+
8787
return weights
8888

8989
def _load_class_weights_from_file(self, filepath: str) -> Dict[str, int]:
9090
"""Load class weights from pickle file."""
91-
with open(filepath, 'rb') as f:
91+
with open(filepath, "rb") as f:
9292
return pickle.load(f)
9393

9494
def set_extensions(self) -> None:
@@ -116,22 +116,22 @@ def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
116116
self.classifier = torch.nn.Linear(
117117
self.embedding.output_size, len(self.label2id)
118118
)
119-
119+
120120
weight_tensor = None
121121
if self.class_weights is not None:
122122
if isinstance(self.class_weights, str):
123123
freq_dict = self._load_class_weights_from_file(self.class_weights)
124124
weight_tensor = self._compute_class_weights(freq_dict)
125125
elif isinstance(self.class_weights, dict):
126126
weight_tensor = self._compute_class_weights(self.class_weights)
127-
127+
128128
print(f"Using class weights: {weight_tensor}")
129-
129+
130130
if self._loss_fn is not None:
131131
self.loss_fn = self._loss_fn
132132
else:
133133
self.loss_fn = torch.nn.CrossEntropyLoss(weight=weight_tensor)
134-
134+
135135
super().post_init(gold_data, exclude=exclude)
136136

137137
def preprocess(self, doc: Doc) -> Dict[str, Any]:
@@ -161,6 +161,10 @@ def collate(self, batch: Dict[str, Sequence[Any]]) -> DocClassifierBatchInput:
161161
return batch_input
162162

163163
def forward(self, batch: DocClassifierBatchInput) -> DocClassifierBatchOutput:
164+
"""
165+
Forward pass: compute embeddings, classify, and calculate loss
166+
if targets provided.
167+
"""
164168
pooled = self.embedding(batch["embedding"])
165169
embeddings = pooled["embeddings"]
166170

@@ -187,6 +191,7 @@ def postprocess(self, docs, results, input):
187191
return docs
188192

189193
def to_disk(self, path, *, exclude=set()):
194+
"""Save classifier state to disk."""
190195
repr_id = object.__repr__(self)
191196
if repr_id in exclude:
192197
return
@@ -206,11 +211,12 @@ def to_disk(self, path, *, exclude=set()):
206211

207212
@classmethod
208213
def from_disk(cls, path, **kwargs):
214+
"""Load classifier from disk."""
209215
data_path = path / "label_attr.pkl"
210216
with open(data_path, "rb") as f:
211217
data = pickle.load(f)
212218
obj = super().from_disk(path, **kwargs)
213219
obj.label_attr = data.get("label_attr", "label")
214220
obj.label2id = data.get("label2id", {})
215221
obj.id2label = data.get("id2label", {})
216-
return obj
222+
return obj

0 commit comments

Comments (0)