Skip to content

Commit ab9a503

Browse files
JiadeXin2021 authored and chensuyue committed
fix a bug of main.py (#518)
* fix a bug of main.py
* fix bugs of non-distributed training

(cherry picked from commit de69c2d)
1 parent 9ccee3e commit ab9a503

File tree

3 files changed

+26
-18
lines changed

3 files changed

+26
-18
lines changed

examples/pytorch/eager/image_recognition/imagenet/cpu/prune/main.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
help='number per node for distributed training')
7171
parser.add_argument('--seed', default=None, type=int,
7272
help='seed for initializing training. ')
73-
parser.add_argument('--keep-batch-size', dest='keep-batch-size',i
73+
parser.add_argument('--keep-batch-size', dest='keep_batch_size',
7474
action='store_true',
7575
help='keep the batch size rather than scale lr')
7676

@@ -98,6 +98,7 @@ def main_worker(args):
9898

9999
if args.distributed:
100100
hvd.init()
101+
print(hvd.size(), args.world_size, args.num_per_node)
101102
assert(hvd.size() == args.world_size * args.num_per_node)
102103

103104
# create model
@@ -278,7 +279,11 @@ def train(train_loader, model, criterion, optimizer, epoch, args, op):
278279
batch_time.update(time.time() - end)
279280
end = time.time()
280281

281-
if i % args.print_freq == 0 and hvd.rank() == 0:
282+
283+
if (i % args.print_freq == 0
284+
and (not args.distributed
285+
or (args.distributed and hvd.rank() == 0
286+
))):
282287
progress.print(i)
283288

284289
if args.iteration > 0 and i > args.iteration:
@@ -315,10 +320,13 @@ def validate(val_loader, model, criterion, args):
315320
top1.update(acc1[0], input.size(0))
316321
top5.update(acc5[0], input.size(0))
317322

318-
if i % args.print_freq == 0 and hvd.rank() == 0:
323+
if (i % args.print_freq == 0
324+
and (not args.distributed
325+
or (args.distributed and hvd.rank() == 0
326+
))):
319327
progress.print(i)
320328

321-
if hvd.rank() == 0:
329+
if not args.distributed or (args.distributed and hvd.rank() == 0):
322330
# TODO: this should also be done with the ProgressMeter
323331
print('Accuracy: {top1:.5f} Accuracy@5 {top5:.5f}'
324332
.format(top1=(top1.avg / 100), top5=(top5.avg / 100)))
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
torch
22
torchvision
3+
horovod
Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
1-
horovodrun -np 2
2-
python -u main.py \
3-
/path/to/imagenet/ \
4-
--topology resnet18 \
5-
--prune \
6-
--config conf.yaml \
7-
--pretrained \
8-
--output-model model_final.pth \
9-
--world-size 1 \
10-
--num-per-node 2 \
11-
--batch-size 256 \
12-
--keep-batch-size \
13-
--lr 0.001 \
14-
--iteration 30 \
1+
horovodrun -np 2 python -u main.py \
2+
/path/to/imagenet/ \
3+
--topology resnet18 \
4+
--prune \
5+
--config conf.yaml \
6+
--pretrained \
7+
--output-model model_final.pth \
8+
--world-size 1 \
9+
--num-per-node 2 \
10+
--batch-size 256 \
11+
--keep-batch-size \
12+
--lr 0.001 \
13+
--iteration 30 \

0 commit comments

Comments
 (0)