@@ -48,14 +48,19 @@ def __init__(
48
48
id2label : Optional [Dict [int , str ]] = None ,
49
49
loss_fn = None ,
50
50
labels : Optional [Sequence [str ]] = None ,
51
+ class_weights : Optional [Union [Dict [str , float ], str ]] = None ,
51
52
):
52
53
self .label_attr : Attributes = label_attr
53
54
self .label2id = label2id or {}
54
55
self .id2label = id2label or {}
55
56
self .labels = labels
57
+ self .class_weights = class_weights
58
+
56
59
super ().__init__ (nlp , name )
57
60
self .embedding = embedding
58
- self .loss_fn = loss_fn or torch .nn .CrossEntropyLoss ()
61
+
62
+ self ._loss_fn = loss_fn
63
+ self .loss_fn = None
59
64
60
65
if not hasattr (self .embedding , "output_size" ):
61
66
raise ValueError (
@@ -65,6 +70,27 @@ def __init__(
65
70
if num_classes :
66
71
self .classifier = torch .nn .Linear (embedding_size , num_classes )
67
72
73
+ def _compute_class_weights (self , freq_dict : Dict [str , int ]) -> torch .Tensor :
74
+ """
75
+ Compute class weights from frequency dictionary.
76
+ Uses inverse frequency weighting: weight = 1 / frequency
77
+ """
78
+ total_samples = sum (freq_dict .values ())
79
+
80
+ weights = torch .zeros (len (self .label2id ))
81
+
82
+ for label , freq in freq_dict .items ():
83
+ if label in self .label2id :
84
+ weight = total_samples / (len (self .label2id ) * freq )
85
+ weights [self .label2id [label ]] = weight
86
+
87
+ return weights
88
+
89
+ def _load_class_weights_from_file (self , filepath : str ) -> Dict [str , int ]:
90
+ """Load class weights from pickle file."""
91
+ with open (filepath , 'rb' ) as f :
92
+ return pickle .load (f )
93
+
68
94
def set_extensions (self ) -> None :
69
95
super ().set_extensions ()
70
96
if not Doc .has_extension (self .label_attr ):
@@ -90,6 +116,22 @@ def post_init(self, gold_data: Iterable[Doc], exclude: Set[str]):
90
116
self .classifier = torch .nn .Linear (
91
117
self .embedding .output_size , len (self .label2id )
92
118
)
119
+
120
+ weight_tensor = None
121
+ if self .class_weights is not None :
122
+ if isinstance (self .class_weights , str ):
123
+ freq_dict = self ._load_class_weights_from_file (self .class_weights )
124
+ weight_tensor = self ._compute_class_weights (freq_dict )
125
+ elif isinstance (self .class_weights , dict ):
126
+ weight_tensor = self ._compute_class_weights (self .class_weights )
127
+
128
+ print (f"Using class weights: { weight_tensor } " )
129
+
130
+ if self ._loss_fn is not None :
131
+ self .loss_fn = self ._loss_fn
132
+ else :
133
+ self .loss_fn = torch .nn .CrossEntropyLoss (weight = weight_tensor )
134
+
93
135
super ().post_init (gold_data , exclude = exclude )
94
136
95
137
def preprocess (self , doc : Doc ) -> Dict [str , Any ]:
@@ -171,4 +213,4 @@ def from_disk(cls, path, **kwargs):
171
213
obj .label_attr = data .get ("label_attr" , "label" )
172
214
obj .label2id = data .get ("label2id" , {})
173
215
obj .id2label = data .get ("id2label" , {})
174
- return obj
216
+ return obj
0 commit comments