Skip to content

Commit 853f980

Browse files
committed
fix?
1 parent 174f86a commit 853f980

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

edsnlp/pipes/trainable/span_classifier/span_classifier.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,6 @@ def __init__(
198198
context_getter: Optional[SpanGetterArg] = None,
199199
values: Optional[Dict[str, List[Any]]] = None,
200200
keep_none: bool = False,
201-
weights: Dict[str, list] = {},
202201
):
203202
attributes: Attributes
204203
if attributes is None and qualifiers is None:
@@ -231,9 +230,6 @@ def __init__(
231230
(k if k.startswith("_.") else f"_.{k}", v, [])
232231
for k, v in attributes.items()
233232
]
234-
self.weights = {
235-
k: torch.tensor(v, device=self.device) for k, v in weights.items()
236-
} # FIXME? mettre dans un autre endroit ?
237233

238234
super().__init__(nlp, name, span_getter=span_getter)
239235
self.embedding = embedding
@@ -524,7 +520,11 @@ def collate(self, batch: Dict[str, Sequence[Any]]) -> SpanClassifierBatchInput:
524520
return collated
525521

526522
# noinspection SpellCheckingInspection
527-
def forward(self, batch: SpanClassifierBatchInput) -> BatchOutput:
523+
def forward(
524+
self,
525+
batch: SpanClassifierBatchInput,
526+
weights: Dict[str, list] = {},
527+
) -> BatchOutput:
528528
"""
529529
Apply the span classifier module to the document embeddings and given spans to:
530530
- compute the loss
@@ -551,12 +551,16 @@ def forward(self, batch: SpanClassifierBatchInput) -> BatchOutput:
551551
pred = []
552552
losses = None
553553

554+
weights = {
555+
k: torch.tensor(v, device=self.device) for k, v in weights.items()
556+
} # FIXME? mettre dans un autre endroit ?
557+
554558
# For each group, for instance:
555559
# - `event=start` and `event=stop`
556560
# - `negated=False` and `negated=True`
557561
for group_idx, bindings_indexer in enumerate(self.bindings_indexers):
558562
if "targets" in batch:
559-
weight = self.weights.get(self.bindings[group_idx][0])
563+
weight = weights.get(self.bindings[group_idx][0])
560564
mask = torch.all(batch["targets"][:, group_idx] != -100, axis=1)
561565
losses.append(
562566
F.cross_entropy(

0 commit comments

Comments
 (0)