Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit e502764

Browse files
committed
[Bug Fix] trainer.update(1) should be used after loss.mean() is called
1 parent 61ec270 commit e502764

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

scripts/sentiment_analysis/sentiment_analysis_cnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def train(net, train_data, test_data, dev_data=None):
159159
L = loss(output, label).mean()
160160
L.backward()
161161
# Update parameter.
162-
trainer.step(args.batch_size)
162+
trainer.step(1)
163163
log_interval_L += L.asscalar()
164164
epoch_L += L.asscalar()
165165
if (i + 1) % args.log_interval == 0:

0 commit comments

Comments
 (0)