@@ -198,7 +198,6 @@ def __init__(
198
198
context_getter : Optional [SpanGetterArg ] = None ,
199
199
values : Optional [Dict [str , List [Any ]]] = None ,
200
200
keep_none : bool = False ,
201
- weights : Dict [str , list ] = {},
202
201
):
203
202
attributes : Attributes
204
203
if attributes is None and qualifiers is None :
@@ -231,9 +230,6 @@ def __init__(
231
230
(k if k .startswith ("_." ) else f"_.{ k } " , v , [])
232
231
for k , v in attributes .items ()
233
232
]
234
- self .weights = {
235
- k : torch .tensor (v , device = self .device ) for k , v in weights .items ()
236
- } # FIXME? mettre dans un autre endroit ?
237
233
238
234
super ().__init__ (nlp , name , span_getter = span_getter )
239
235
self .embedding = embedding
@@ -524,7 +520,11 @@ def collate(self, batch: Dict[str, Sequence[Any]]) -> SpanClassifierBatchInput:
524
520
return collated
525
521
526
522
# noinspection SpellCheckingInspection
527
- def forward (self , batch : SpanClassifierBatchInput ) -> BatchOutput :
523
+ def forward (
524
+ self ,
525
+ batch : SpanClassifierBatchInput ,
526
+ weights : Dict [str , list ] = {},
527
+ ) -> BatchOutput :
528
528
"""
529
529
Apply the span classifier module to the document embeddings and given spans to:
530
530
- compute the loss
@@ -551,12 +551,16 @@ def forward(self, batch: SpanClassifierBatchInput) -> BatchOutput:
551
551
pred = []
552
552
losses = None
553
553
554
+ weights = {
555
+ k : torch .tensor (v , device = self .device ) for k , v in weights .items ()
556
+ } # FIXME? mettre dans un autre endroit ?
557
+
554
558
# For each group, for instance:
555
559
# - `event=start` and `event=stop`
556
560
# - `negated=False` and `negated=True`
557
561
for group_idx , bindings_indexer in enumerate (self .bindings_indexers ):
558
562
if "targets" in batch :
559
- weight = self . weights .get (self .bindings [group_idx ][0 ])
563
+ weight = weights .get (self .bindings [group_idx ][0 ])
560
564
mask = torch .all (batch ["targets" ][:, group_idx ] != - 100 , axis = 1 )
561
565
losses .append (
562
566
F .cross_entropy (
0 commit comments