Commit d2e1f39 (parent 89c1f77)

fix: update ner_crf & span_classifier params in place in post_init to avoid optimizer issues
File tree: 5 files changed (+13 −7 lines)

changelog.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -23,6 +23,7 @@
 - We now support `words[-10:10]` syntax in trainable span classifier `context_getter` parameter
 - :ambulance: Until now, `post_init` was applied **after** the instantiation of the optimizer: if the model discovered new labels, and therefore changed its parameter tensors to reflect that, these new tensors were not taken into account by the optimizer, which could lead to subpar performance. Now, `post_init` is applied **before** the optimizer is instantiated, so that the optimizer can correctly handle the new tensors.
 - Added missing entry points for readers and writers in the registry, including `write_parquet` and support for `polars` in `pyproject.toml`. Now all implemented readers and writers are correctly registered as entry points.
+- Parameters are now updated *in place* when `post_init` is run in `eds.ner_crf` and `eds.span_classifier`, and are therefore correctly taken into account by the optimizer.

 ### Changed
```
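To make the failure mode concrete, here is a minimal PyTorch sketch (illustrative only, not code from this repository) of why rebinding a parameter attribute breaks a previously built optimizer, while mutating `.data` does not:

```python
import torch

linear = torch.nn.Linear(4, 2)
optim = torch.optim.AdamW(linear.parameters(), lr=1e-3)

# The optimizer stores references to the Parameter *objects* that existed
# when it was built. Rebinding the attribute, e.g.
#     linear.weight = torch.nn.Parameter(torch.zeros(3, 4))
# creates a new object the optimizer never sees, so the new weights would
# silently stop receiving updates.

# Mutating `.data` instead swaps the underlying storage while keeping the
# same Parameter object, so the optimizer's reference stays valid:
resized = torch.nn.Linear(4, 3)  # e.g. one new label was discovered
linear.weight.data = resized.weight.data
linear.bias.data = resized.bias.data
linear.out_features = resized.out_features

assert linear.weight is optim.param_groups[0]["params"][0]  # same object
```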

edsnlp/pipes/trainable/ner_crf/ner_crf.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -362,9 +362,9 @@ def update_labels(self, labels: Sequence[str]):
         new_linear = torch.nn.Linear(self.embedding.output_size, len(labels) * 5)
         new_linear.weight.data[new_index] = self.linear.weight.data[old_index]
         new_linear.bias.data[new_index] = self.linear.bias.data[old_index]
-        self.linear.weight = new_linear.weight
+        self.linear.weight.data = new_linear.weight.data
         self.linear.out_features = new_linear.out_features
-        self.linear.bias = new_linear.bias
+        self.linear.bias.data = new_linear.bias.data

         # Update initialization arguments
         self.labels = labels
```
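For context, the surrounding `update_labels` logic copies the learned rows into a freshly sized layer before swapping storages in place. A self-contained sketch of the pattern (simplified to one output row per label; the helper and its names are illustrative, not the pipe's actual API):

```python
import torch

def resize_label_projection(linear: torch.nn.Linear, old_labels, new_labels):
    # Build a projection sized for the new label set.
    new_linear = torch.nn.Linear(linear.in_features, len(new_labels))
    # Copy learned rows for labels present in both sets.
    common = [lab for lab in old_labels if lab in new_labels]
    old_index = [old_labels.index(lab) for lab in common]
    new_index = [new_labels.index(lab) for lab in common]
    new_linear.weight.data[new_index] = linear.weight.data[old_index]
    new_linear.bias.data[new_index] = linear.bias.data[old_index]
    # Swap storages in place: the Parameter objects, and hence any
    # optimizer references to them, are preserved.
    linear.weight.data = new_linear.weight.data
    linear.bias.data = new_linear.bias.data
    linear.out_features = new_linear.out_features
```

In the actual pipe, each label owns five output rows (the CRF tag space), hence the `len(labels) * 5` sizing in the diff above; the sketch collapses this to one row per label for brevity.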

edsnlp/pipes/trainable/span_classifier/span_classifier.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -401,7 +401,7 @@ def update_bindings(self, bindings: List[Tuple[str, SpanFilter, List[Any]]]):
         new_index = [new_bindings_to_idx[label] for label in common]
         new_linear = torch.nn.Linear(self.classifier.in_features, len(new_bindings))
         new_linear.weight.data[new_index] = self.classifier.weight.data[old_index]
-        self.classifier.weight = new_linear.weight
+        self.classifier.weight.data = new_linear.weight.data
         self.classifier.out_features = new_bindings
         missing_bindings = set(new_bindings) - set(old_bindings)
         if missing_bindings and len(old_bindings) > 0:
@@ -412,7 +412,7 @@ def update_bindings(self, bindings: List[Tuple[str, SpanFilter, List[Any]]]):

         if hasattr(self.classifier, "bias"):
             new_linear.bias.data[new_index] = self.classifier.bias.data[old_index]
-            self.classifier.bias = new_linear.bias
+            self.classifier.bias.data = new_linear.bias.data

         def simplify_indexer(indexer):
             return (
```
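Same fix, same rationale as in `eds.ner_crf`: only the `.data` storages are swapped so the optimizer keeps valid references, and the bias update stays behind the existing `hasattr` guard since the classifier layer may have been built without a bias term.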

edsnlp/training/trainer.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -728,7 +728,8 @@ def train(
                 )
             )
         accelerator.print(
-            f"Keeping frozen {len(all_params - grad_params):} weight tensors "
+            ("! WARNING ! " if (len(all_params - grad_params) > 0) else "")
+            + f"Keeping frozen {len(all_params - grad_params):} weight tensors "
            f"({sum(p.numel() for p in all_params - grad_params):,} parameters)"
         )

```
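This new warning complements the fix above: a parameter tensor (re)created after the optimizer was built is exactly the kind of weight that would show up here as unexpectedly frozen, so a non-empty frozen set is now flagged loudly rather than reported in passing.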

tests/training/ner_qlf_same_bert_config.yml

Lines changed: 6 additions & 2 deletions
```diff
@@ -56,13 +56,17 @@ scorer:

 # 🎛️ OPTIMIZER
 optimizer:
-  "@core": optimizer
+  "@core": optimizer !draft
   optim: AdamW
   module: ${ nlp }
   groups:
     # Transformer
     - selector: "transformer"
-      exclude: true
+      lr:
+        "@schedules": linear
+        start_value: 1e-5
+        max_value: 2e-5
+        warmup_rate: 0.5
     - selector: ".*"
       lr: 1e-3

```
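The test config now actually trains the transformer group (it was previously excluded) under a linear warmup schedule, and the optimizer is declared as a `!draft`, which appears to mark the config as deliberately incomplete, to be finalized by the trainer (an assumption about the confit config system, not verified). For intuition, a rough sketch of what such a schedule could compute, assuming a ramp from `start_value` to `max_value` over the `warmup_rate` fraction of steps followed by a linear decay to zero (an assumption about `"@schedules": linear`, not verified against edsnlp):

```python
def linear_lr(step: int, total_steps: int,
              start_value: float = 1e-5, max_value: float = 2e-5,
              warmup_rate: float = 0.5) -> float:
    # Assumed behavior: ramp up during warmup, then decay linearly to 0.
    warmup_steps = max(1, int(warmup_rate * total_steps))
    if step < warmup_steps:
        return start_value + (max_value - start_value) * step / warmup_steps
    return max_value * (total_steps - step) / max(1, total_steps - warmup_steps)
```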
