Skip to content

Commit 780c584

Browse files
committed
fix sum pooling
1 parent a938317 commit 780c584

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

edsnlp/pipes/trainable/embeddings/doc_pooler/doc_pooler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def forward(self, batch: DocPoolerBatchInput) -> DocPoolerBatchOutput:
9595
elif self.pooling_mode == "max":
9696
pooled = embeds.max(dim=1).values
9797
elif self.pooling_mode == "sum":
98-
pooled = embeds.sum(dim=1)
98+
pooled = embeds.sum(dim=1) / embeds.size(1)
9999
elif self.pooling_mode == "cls":
100100
pooled = self.embedding(batch["embedding"])["cls"].to(device)
101101
else:

0 commit comments

Comments
 (0)