Skip to content

Commit d9bca63

Browse files
committed
fix trainer.fill_flat_stats causing max recursion error
1 parent 2ad50a5 commit d9bca63

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

edsnlp/resources/verbs.csv.gz

196 KB
Binary file not shown.

edsnlp/training/trainer.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,26 @@ def flatten_dict(d, path=""):
8585

8686

8787
def fill_flat_stats(x, result, path=()):
88-
print("fill_flat_stats", path, type(x), x)
88+
# Only recurse into dicts and lists/tuples
8989
if result is None:
9090
result = {}
9191
if isinstance(x, dict):
9292
for k, v in x.items():
9393
fill_flat_stats(v, result, (*path, k))
9494
return result
95+
if isinstance(x, (list, tuple)):
96+
for idx, v in enumerate(x):
97+
fill_flat_stats(v, result, (*path, str(idx)))
98+
return result
99+
# Only accumulate numbers (int, float, or 0-dim tensor)
95100
if "stats" in path and "__batch_hash__" not in path[-1]:
96-
path = "/".join(path)
97-
result[path] = result.get(path, 0) + x
101+
if isinstance(x, (int, float)):
102+
path_str = "/".join(path)
103+
result[path_str] = result.get(path_str, 0) + x
104+
elif hasattr(x, "item") and callable(x.item):
105+
# For 0-dim tensors
106+
path_str = "/".join(path)
107+
result[path_str] = result.get(path_str, 0) + x.item()
98108
return result
99109

100110

@@ -646,7 +656,8 @@ def train(
646656
*(
647657
td(nlp, device).set_processing(
648658
num_cpu_workers=num_workers,
649-
process_start_method="spawn",
659+
# process_start_method="spawn",
660+
backend="simple",
650661
)
651662
for td in phase_training_data
652663
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ where = ["."]
282282
"eds.ner_overlap" = "edsnlp.metrics.ner:NerOverlapMetric"
283283
"eds.span_attributes" = "edsnlp.metrics.span_attributes:SpanAttributeMetric"
284284
"eds.dep_parsing" = "edsnlp.metrics.dep_parsing:DependencyParsingMetric"
285-
"eds.doc_classif" = "edsnlp.metrics.doc_classif:DocClassificationMetric"
285+
"eds.doc_classif" = "edsnlp.metrics.doc_classif:DocClassificationMetric"
286286

287287
# Deprecated
288288
"eds.ner_exact_metric" = "edsnlp.metrics.ner:NerExactMetric"

0 commit comments

Comments
 (0)