Commit cb48bd2

[valid] deadlock workaround (#282)
* [valid] deadlock workaround
* parameterize valid num workers
1 parent 908dc9c commit cb48bd2

File tree

3 files changed: +14, -4 lines changed


megatron/arguments.py

Lines changed: 2 additions & 0 deletions
@@ -897,6 +897,8 @@ def __call__(self, parser, args, values, option_string=None):
                        help='Warm up mmap files.')
     group.add_argument('--num-workers', type=int, default=2,
                        help="Dataloader number of workers.")
+    group.add_argument('--valid-num-workers', type=int, default=2,
+                       help="Dataloader number of workers for validation.")
     group.add_argument('--tokenizer-type', type=str,
                        default=None,
                        choices=['BertWordPieceLowerCase',
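
For reference, a minimal standalone sketch of how the new flag behaves once parsed. The parser below is an assumption for illustration, not the actual Megatron argument plumbing; only the two add_argument calls mirror the diff above.

import argparse

# Standalone approximation of the two worker-count flags (illustration only).
parser = argparse.ArgumentParser()
group = parser.add_argument_group('data')
group.add_argument('--num-workers', type=int, default=2,
                   help="Dataloader number of workers.")
group.add_argument('--valid-num-workers', type=int, default=2,
                   help="Dataloader number of workers for validation.")

# The deadlock workaround described in this commit would be enabled with:
#   --valid-num-workers 0
args = parser.parse_args(['--valid-num-workers', '0'])
assert args.num_workers == 2 and args.valid_num_workers == 0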

megatron/data/data_samplers.py

Lines changed: 6 additions & 3 deletions
@@ -22,7 +22,7 @@
 from megatron import mpu
 
 
-def build_pretraining_data_loader(dataset, consumed_samples):
+def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
     """Buld dataloader given an input dataset."""
 
     if dataset is None:
@@ -48,10 +48,13 @@ def build_pretraining_data_loader(dataset, consumed_samples):
         raise Exception('{} dataloader type is not supported.'.format(
             args.dataloader_type))
 
+    if num_workers is None:
+        num_workers = args.num_workers
+
     # Torch dataloader.
     return torch.utils.data.DataLoader(dataset,
                                        batch_sampler=batch_sampler,
-                                       num_workers=args.num_workers,
+                                       num_workers=num_workers,
                                        pin_memory=True)
 
 class MegatronPretrainingSampler:
@@ -141,7 +144,7 @@ def __iter__(self):
                       * self.micro_batch_size
         bucket_offset = current_epoch_samples // self.data_parallel_size
         start_idx = self.data_parallel_rank * bucket_size
-
+
         g = torch.Generator()
         g.manual_seed(self.epoch)
         random_idx = torch.randperm(bucket_size, generator=g).tolist()
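
A sketch of how the reworked signature resolves the worker count at the call site. The datasets and consumed-sample values below are placeholders; only build_pretraining_data_loader itself comes from this file.

# Placeholder call sites (train_ds / valid_ds are assumed stand-in datasets).
train_loader = build_pretraining_data_loader(train_ds, consumed_samples=0)
# num_workers omitted -> falls back to args.num_workers, same behaviour as before this commit.

valid_loader = build_pretraining_data_loader(valid_ds, consumed_samples=0, num_workers=0)
# explicit num_workers=0 -> single-process loading, which is what the validation path uses.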

megatron/training.py

Lines changed: 6 additions & 1 deletion
@@ -1132,7 +1132,12 @@ def build_train_valid_test_data_iterators(
 
     # We collapse None and empty list as both should mean we don't run validation
     # args.consumed_valid_samples accumulates the sum of valid steps for every dataset, which are all equal
-    valid_dataloaders = [build_pretraining_data_loader(d, args.consumed_valid_samples // len(valid_ds))
+    #
+    # XXX: we get a deadlock in the dataloader on multi-dataset eval, after the first dataset,
+    # possibly due to this bug in pytorch https://github.com/pytorch/pytorch/pull/25158. Using
+    # num_workers=0 to work around it - the training can't use that since it impacts throughput
+    # by a few percent
+    valid_dataloaders = [build_pretraining_data_loader(d, args.consumed_valid_samples // len(valid_ds), num_workers=args.valid_num_workers)
                          for d in valid_ds] \
                             if valid_ds is not None else []
     # We collapse None and empty list as both should mean we don't run test
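
Why num_workers=0 sidesteps the hang: with zero workers the PyTorch DataLoader produces batches in the main process, so no worker processes have to be spawned and joined for each validation dataset. Below is a minimal, self-contained PyTorch sketch of that single-process evaluation loop, using plain TensorDatasets as stand-ins for the Megatron datasets rather than the actual code path.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-ins for several validation datasets (assumption for illustration).
valid_sets = [TensorDataset(torch.arange(8, dtype=torch.float32)) for _ in range(3)]

for ds in valid_sets:
    # num_workers=0: batches come from the main process, so there are no
    # worker processes to start or shut down between datasets.
    loader = DataLoader(ds, batch_size=4, num_workers=0, pin_memory=True)
    for (batch,) in loader:
        _ = batch.sum()  # placeholder for the evaluation step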
