
Commit 2d3a2aa

feat: add class_weight handling

Parent: bc0f05b

1 file changed: 44 additions, 2 deletions

edsnlp/pipes/trainable/doc_classifier/doc_classifier.py

```diff
@@ -48,14 +48,19 @@ def __init__(
         id2label: Optional[Dict[int, str]] = None,
         loss_fn=None,
         labels: Optional[Sequence[str]] = None,
+        class_weights: Optional[Union[Dict[str, float], str]] = None,
     ):
         self.label_attr: Attributes = label_attr
         self.label2id = label2id or {}
         self.id2label = id2label or {}
         self.labels = labels
+        self.class_weights = class_weights
+
         super().__init__(nlp, name)
         self.embedding = embedding
-        self.loss_fn = loss_fn or torch.nn.CrossEntropyLoss()
+
+        self._loss_fn = loss_fn
+        self.loss_fn = None

         if not hasattr(self.embedding, "output_size"):
             raise ValueError(
```
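This hunk defers loss construction: a user-supplied `loss_fn` is parked in `self._loss_fn`, and `self.loss_fn` stays `None` until `post_init` has resolved the class weights. A minimal standalone sketch of that pattern (the class and method names here are illustrative, not edsnlp's API):

```python
import torch
from typing import Dict, Optional, Union


class DeferredLossSketch:
    """Illustrative only: hold the user's loss aside until weights exist."""

    def __init__(
        self,
        loss_fn=None,
        class_weights: Optional[Union[Dict[str, float], str]] = None,
    ):
        self.class_weights = class_weights
        self._loss_fn = loss_fn  # user-supplied loss, if any
        self.loss_fn = None      # resolved later, once weights are known

    def resolve_loss(self, weight_tensor: Optional[torch.Tensor] = None):
        # A user-supplied loss always wins; otherwise build a weighted default.
        if self._loss_fn is not None:
            self.loss_fn = self._loss_fn
        else:
            self.loss_fn = torch.nn.CrossEntropyLoss(weight=weight_tensor)
```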
```diff
@@ -65,6 +70,27 @@ def __init__(
         if num_classes:
             self.classifier = torch.nn.Linear(embedding_size, num_classes)

+    def _compute_class_weights(self, freq_dict: Dict[str, int]) -> torch.Tensor:
+        """
+        Compute class weights from a frequency dictionary.
+        Uses balanced inverse frequency: weight = n_samples / (n_classes * freq)
+        """
+        total_samples = sum(freq_dict.values())
+
+        weights = torch.zeros(len(self.label2id))
+
+        for label, freq in freq_dict.items():
+            if label in self.label2id:
+                weight = total_samples / (len(self.label2id) * freq)
+                weights[self.label2id[label]] = weight
+
+        return weights
+
+    def _load_class_weights_from_file(self, filepath: str) -> Dict[str, int]:
+        """Load class frequencies from a pickle file."""
+        with open(filepath, 'rb') as f:
+            return pickle.load(f)
+
     def set_extensions(self) -> None:
         super().set_extensions()
         if not Doc.has_extension(self.label_attr):
```
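The formula above is sklearn-style "balanced" weighting, `weight = n_samples / (n_classes * freq)`. A self-contained check on a toy frequency dict (the labels and counts are made up):

```python
import torch

label2id = {"negative": 0, "positive": 1}  # illustrative mapping
freq = {"negative": 900, "positive": 100}  # toy class frequencies

total = sum(freq.values())  # 1000 samples
weights = torch.zeros(len(label2id))
for label, count in freq.items():
    # Rare classes receive proportionally larger weights.
    weights[label2id[label]] = total / (len(label2id) * count)

print(weights)  # tensor([0.5556, 5.0000]): the minority class weighs 9x more
```

The `str` form of `class_weights` points to a pickle file holding a dict of the same shape, which `_load_class_weights_from_file` reads and feeds through the same computation.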
```diff
@@ -90,6 +116,22 @@ def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
             self.classifier = torch.nn.Linear(
                 self.embedding.output_size, len(self.label2id)
             )
+
+        weight_tensor = None
+        if self.class_weights is not None:
+            if isinstance(self.class_weights, str):
+                freq_dict = self._load_class_weights_from_file(self.class_weights)
+                weight_tensor = self._compute_class_weights(freq_dict)
+            elif isinstance(self.class_weights, dict):
+                weight_tensor = self._compute_class_weights(self.class_weights)
+
+            print(f"Using class weights: {weight_tensor}")
+
+        if self._loss_fn is not None:
+            self.loss_fn = self._loss_fn
+        else:
+            self.loss_fn = torch.nn.CrossEntropyLoss(weight=weight_tensor)
+
         super().post_init(gold_data, exclude=exclude)

     def preprocess(self, doc: Doc) -> Dict[str, Any]:
```
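To see what the resolved `weight=` tensor changes, here is a quick standalone comparison with `torch.nn.CrossEntropyLoss` (logits, targets, and weights are made up):

```python
import torch

logits = torch.tensor([[2.0, 0.5], [0.3, 1.5]])  # batch of 2, 2 classes
targets = torch.tensor([0, 1])

unweighted = torch.nn.CrossEntropyLoss()
weighted = torch.nn.CrossEntropyLoss(weight=torch.tensor([0.56, 5.0]))

# The weighted mean re-scales each sample's loss by its target's weight,
# so errors on the rare class dominate the gradient signal.
print(unweighted(logits, targets).item())
print(weighted(logits, targets).item())
```

Note that when `class_weights` is `None`, `weight_tensor` stays `None` and `CrossEntropyLoss(weight=None)` is the unweighted default, so existing configurations behave as before.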
```diff
@@ -171,4 +213,4 @@ def from_disk(cls, path, **kwargs):
         obj.label_attr = data.get("label_attr", "label")
         obj.label2id = data.get("label2id", {})
         obj.id2label = data.get("id2label", {})
-        return obj
\ No newline at end of file
+        return obj
```
