2 changes: 2 additions & 0 deletions megatron/arguments.py
@@ -897,6 +897,8 @@ def __call__(self, parser, args, values, option_string=None):
help='Warm up mmap files.')
group.add_argument('--num-workers', type=int, default=2,
help="Dataloader number of workers.")
group.add_argument('--valid-num-workers', type=int, default=2,
help="Dataloader number of workers for validation.")
group.add_argument('--tokenizer-type', type=str,
default=None,
choices=['BertWordPieceLowerCase',
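For illustration only, here is a minimal, self-contained argparse sketch (not the actual Megatron parser; the group title and parsing call are placeholders) showing how the new `--valid-num-workers` flag sits alongside the existing `--num-workers` flag:

```python
import argparse

# Minimal stand-in for Megatron's data-loading argument group; names are illustrative.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='data')
group.add_argument('--num-workers', type=int, default=2,
                   help='Dataloader number of workers.')
group.add_argument('--valid-num-workers', type=int, default=2,
                   help='Dataloader number of workers for validation.')

# e.g. force single-process loading for validation while keeping 2 workers for training:
args = parser.parse_args(['--valid-num-workers', '0'])
print(args.num_workers, args.valid_num_workers)  # -> 2 0
```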
9 changes: 6 additions & 3 deletions megatron/data/data_samplers.py
@@ -22,7 +22,7 @@
from megatron import mpu


def build_pretraining_data_loader(dataset, consumed_samples):
def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
"""Buld dataloader given an input dataset."""

if dataset is None:
@@ -48,10 +48,13 @@ def build_pretraining_data_loader(dataset, consumed_samples):
raise Exception('{} dataloader type is not supported.'.format(
args.dataloader_type))

if num_workers is None:
num_workers = args.num_workers

# Torch dataloader.
return torch.utils.data.DataLoader(dataset,
batch_sampler=batch_sampler,
num_workers=args.num_workers,
num_workers=num_workers,
pin_memory=True)

class MegatronPretrainingSampler:
@@ -141,7 +144,7 @@ def __iter__(self):
* self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size

g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
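A stripped-down sketch of the fallback pattern this file's change introduces: the caller may override the worker count, otherwise the training-time default is used. `build_data_loader` and `default_num_workers` are illustrative names standing in for `build_pretraining_data_loader` and `args.num_workers`, and the sampler setup is omitted:

```python
import torch

def build_data_loader(dataset, batch_sampler, default_num_workers, num_workers=None):
    # Fall back to the training-time worker count unless the caller overrides it,
    # mirroring the `num_workers=None` default added to build_pretraining_data_loader.
    if num_workers is None:
        num_workers = default_num_workers

    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
                                       num_workers=num_workers,
                                       pin_memory=True)
```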
7 changes: 6 additions & 1 deletion megatron/training.py
@@ -1132,7 +1132,12 @@ def build_train_valid_test_data_iterators(

# We collapse None and empty list as both should mean we don't run validation
# args.consumed_valid_samples accumulates the sum of valid steps for every dataset, which are all equal
valid_dataloaders = [build_pretraining_data_loader(d, args.consumed_valid_samples // len(valid_ds))
#
# XXX: we get a deadlock in the dataloader on multi-dataset eval, after the first dataset,
# possibly due to this bug in pytorch https://github.com/pytorch/pytorch/pull/25158. Using
# num_workers=0 to work around it - the training can't use that since it impacts throughput
# by a few percent
valid_dataloaders = [build_pretraining_data_loader(d, args.consumed_valid_samples // len(valid_ds), num_workers=args.valid_num_workers)
for d in valid_ds] \
if valid_ds is not None else []
# We collapse None and empty list as both should mean we don't run test
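A toy sketch of the workaround described in the inline comment, using plain PyTorch rather than Megatron's evaluation loop: launching with `--valid-num-workers 0` makes the validation loader run in the main process, which avoids spawning the worker processes implicated in the multi-dataset eval hang. The dataset and batch size here are placeholders:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for one validation dataset; in Megatron this would be an entry of valid_ds.
valid_set = TensorDataset(torch.arange(16).float())

# num_workers=0 keeps all data loading in the main process (the effect of passing
# --valid-num-workers 0), at the cost of some throughput that only matters for training.
valid_loader = DataLoader(valid_set, batch_size=4, num_workers=0, pin_memory=True)

for (batch,) in valid_loader:
    pass  # evaluation step would go here
```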