diff --git a/megatron/arguments.py b/megatron/arguments.py
index 2211e17dc..869ebd5c1 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -925,8 +925,8 @@ def __call__(self, parser, args, values, option_string=None):
                             'specific positions. This option tries to un-bias the loss by reweighting loss on specific '
                             'positions based on how frequently we train on that position.'
                             'This is mostly used for prefix_lm training')
-    group.add_argument("--noise_density", type=float, default=None, help="Span corruption noise density")
-    group.add_argument("--mean_noise_span_length", type=int, default=None, help="Span corruption mean noise span length")
+    group.add_argument("--noise-density", type=float, default=None, help="Span corruption noise density")
+    group.add_argument("--mean-noise-span-length", type=int, default=None, help="Span corruption mean noise span length")

     return parser

diff --git a/megatron/data/mlm_dataset.py b/megatron/data/mlm_dataset.py
index 024b4c345..87028d31e 100644
--- a/megatron/data/mlm_dataset.py
+++ b/megatron/data/mlm_dataset.py
@@ -3,7 +3,7 @@
 import numpy as np
 import torch

-from megatron import print_rank_0, get_tokenizer
+from megatron import print_rank_0, get_tokenizer, get_args
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, get_split_by_range_
 from megatron.data.dataset_utils import get_train_valid_test_split_, get_indexed_dataset_
@@ -297,13 +297,14 @@ def __init__(
         # according to `noise_density` and `mean_noise_span_length`. We can also define the label length accordingly.
         number_of_raw_tokens, inputs_length, targets_length, num_noise_spans = compute_input_and_target_lengths(
             # +1 is used so that we can compute the as autoregressive systems require us to add one more token.
-            sequence_length=self.sequence_length + 1,
+            sequence_length=self.sequence_length,
             noise_density=self.noise_density,
             mean_noise_span_length=self.mean_noise_span_length
         )
-        self.number_of_raw_tokens = number_of_raw_tokens
         self.inputs_length = inputs_length
-        self.targets_length = targets_length
+        # Keep one extra token at the end so the labels can be the targets shifted by one position (autoregressive loss).
+        self.number_of_raw_tokens = number_of_raw_tokens + 1
+        self.targets_length = targets_length + 1
         self.num_noise_spans = num_noise_spans

         # Build the samples mapping.
@@ -320,13 +321,24 @@ def __init__(

         # Vocab stuff.
         tokenizer = get_tokenizer()
-        self.sep_id = tokenizer.sep
+        # TODO @thomasw21 find if overloading eod is acceptable.
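+        # NOTE: the tokenizer used here (e.g. GPT2's) does not necessarily expose a dedicated <sep>
+        # token, so the end-of-document token is reused as the input/target separator for now.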
+ # self.sep_id = tokenizer.sep + self.sep_id = tokenizer.eod self.sentinel_token_ids = tokenizer.additional_special_tokens_ids + assert self.sep_id is not None, "MLM dataset requires tokenizer to have a token" assert len(self.sentinel_token_ids) > 0, "Provide the argument --vocab-extra-ids 100 to the script" assert len(self.sentinel_token_ids) >= self.num_noise_spans, "Not enough sentinel tokens, please add more" + args = get_args() + if hasattr(args, "encoder_seq_length") and args.encoder_seq_length is not None: + # T5 style + assert self.inputs_length == args.encoder_seq_length + assert self.targets_length == args.decoder_seq_length + 1 + else: + assert self.inputs_length + self.targets_length == args.seq_length + def __len__(self): - return len(self.samples_mapping) + return len(self._gpt_dataset) def __getitem__(self, idx): if isinstance(idx, slice): diff --git a/megatron/model/__init__.py b/megatron/model/__init__.py index baf54e455..0662d93bb 100644 --- a/megatron/model/__init__.py +++ b/megatron/model/__init__.py @@ -18,6 +18,7 @@ from .distributed import DistributedDataParallel from .bert_model import BertModel from .gpt_model import GPTModel, GPTModelPipe +from .shared_t5_model import SharedT5ModelPipe from .t5_model import T5Model from .language_model import get_language_model from .module import Float16Module diff --git a/megatron/model/gpt_model.py b/megatron/model/gpt_model.py index 31d33a91b..880fea7d1 100644 --- a/megatron/model/gpt_model.py +++ b/megatron/model/gpt_model.py @@ -21,7 +21,7 @@ from megatron import get_args from megatron import mpu from megatron.enums import AttnMaskType -from .module import MegatronModule, fp32_to_float16 +from .module import MegatronModule, fp32_to_16bit from .language_model import parallel_lm_logits from .language_model import get_language_model @@ -213,9 +213,9 @@ def __init__( def _to_float16(inputs): if args.fp16: - return fp32_to_float16(inputs, lambda v: v.half()) + return fp32_to_16bit(inputs, lambda v: v.half()) elif args.bf16: - return fp32_to_float16(inputs, lambda v: v.bfloat16()) + return fp32_to_16bit(inputs, lambda v: v.bfloat16()) else: return inputs diff --git a/megatron/model/module.py b/megatron/model/module.py index df92d95a9..3841dd140 100644 --- a/megatron/model/module.py +++ b/megatron/model/module.py @@ -122,7 +122,7 @@ def conversion_helper(val, conversion): return rtn -def fp32_to_float16(val, float16_convertor): +def fp32_to_16bit(val, float16_convertor): """Convert fp32 `val` to fp16/bf16""" def half_conversion(val): val_typecheck = val @@ -168,7 +168,7 @@ def float16_convertor(val): def forward(self, *inputs, **kwargs): if mpu.is_pipeline_first_stage(): - inputs = fp32_to_float16(inputs, self.float16_convertor) + inputs = fp32_to_16bit(inputs, self.float16_convertor) outputs = self.module(*inputs, **kwargs) if mpu.is_pipeline_last_stage(): outputs = float16_to_fp32(outputs) diff --git a/megatron/model/shared_t5_model.py b/megatron/model/shared_t5_model.py new file mode 100644 index 000000000..d944cb746 --- /dev/null +++ b/megatron/model/shared_t5_model.py @@ -0,0 +1,180 @@ +import torch +from deepspeed import PipelineModule +from deepspeed.runtime.pipe import TiedLayerSpec, LayerSpec +from torch.nn import LayerNorm + +from megatron.enums import AttnMaskType, LayerType + +from megatron.model.transformer import ParallelTransformerLayerPipe + +from megatron.model.language_model import EmbeddingPipe, parallel_lm_logits + +from megatron.model.utils import init_method_normal, scaled_init_method_normal + +from megatron 
import get_args, mpu + +from megatron.model.module import MegatronModule, fp32_to_16bit, float16_to_fp32 + +def cross_entropy(output, labels): + labels, loss_mask = labels[0], labels[1] + + losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels) + + expected_number_of_tokens = loss_mask.sum() + + loss_mask = loss_mask.view(-1) + loss = torch.sum(losses.view(-1) * loss_mask) / expected_number_of_tokens + return loss + +class SharedT5ModelPipe(PipelineModule, MegatronModule): + """Share encoder decoder language model.""" + + def __init__( + self, + num_tokentypes=0, + parallel_output=True, + ): + args = get_args() + self.parallel_output = parallel_output + + init_method = init_method_normal(args.init_method_std) + + self.specs = [] + + def _to_16bit(inputs): + if args.fp16: + return fp32_to_16bit(inputs, lambda v: v.half()) + elif args.bf16: + return fp32_to_16bit(inputs, lambda v: v.bfloat16()) + else: + return inputs + + self.specs.append(lambda inputss: tuple(_to_16bit(inputs) for inputs in inputss)) + + # Embedding layer + self.specs.append(TiedLayerSpec('embed', + EmbeddingPipe, + args.hidden_size, + args.padded_vocab_size, + args.hidden_dropout, + forward_fn=lambda module, input_and_target: (module(input_and_target[:3]), module(input_and_target[3:])), + init_method=init_method, + num_tokentypes=num_tokentypes, + tied_weight_attr='word_embeddings_weight')) + + assert hasattr(args, 'attn_mask'), "Deepspeed integration should have attention mask s" + # Drop everything beside tokens + # self.specs.append(lambda inputs, targets: (inputs[0], targets[0])) + if args.fp32_residual_connection: + self.specs.append(lambda input_and_target: (input_and_target[0].transpose(0, 1).contiguous().float(), input_and_target[1].transpose(0, 1).contiguous().float())) + else: + self.specs.append(lambda input_and_target: (input_and_target[0].transpose(0, 1).contiguous(), input_and_target[1].transpose(0, 1).contiguous())) + + ### ----- Encoder ----- + for layer_idx in range(args.num_layers): + self.specs.append( + TiedLayerSpec( + f"block_{layer_idx}", + ParallelTransformerLayerPipe, + init_method=init_method, + forward_fn=lambda module, input_and_target: (module(input_and_target[0]), input_and_target[1]), + output_layer_init_method=scaled_init_method_normal(args.init_method_std, + args.num_layers), + layer_type=LayerType.encoder, + layer_number=layer_idx, + self_attn_mask_type=AttnMaskType.causal, + tied_weight_attr=None, + tied_weight_attrs=["self_attention", "mlp"] + )) + + # Final layernorm after encoder layers + self.specs.append( + LayerSpec( + LayerNorm, + args.hidden_size, + forward_fn=lambda module, input_and_target: (module(input_and_target[0]), input_and_target[1]), + eps=args.layernorm_epsilon + )) + + # Decoder + for layer_idx in range(args.num_layers): + self.specs.append( + TiedLayerSpec( + f"block_{layer_idx}", + ParallelTransformerLayerPipe, + init_method=init_method, + forward_fn=lambda module, encoded_and_target: (encoded_and_target[0], module(encoded_and_target[1], encoder_output=encoded_and_target[0])), + output_layer_init_method=scaled_init_method_normal(args.init_method_std, + args.num_layers), + layer_number=layer_idx, + layer_type=LayerType.decoder, + self_attn_mask_type=AttnMaskType.padding, + tied_weight_attr=None, + tied_weight_attrs=["self_attention", "mlp"] + ) + ) + + # Drop encoded tokens + self.specs.append(lambda encoded_and_target: encoded_and_target[1]) + + # Final layernorm after decoder layers + self.specs.append( + LayerSpec( + LayerNorm, + 
args.hidden_size, + eps=args.layernorm_epsilon + )) + + # Undo data format change + self.specs.append(lambda x: x.transpose(0, 1).contiguous()) + + def _logits_helper(embedding, lm_output): + """A wrapper to massage inputs/outputs from pipeline. """ + return parallel_lm_logits( + lm_output, + embedding.word_embeddings_weight, + self.parallel_output) + + self.specs.append( + TiedLayerSpec('embed', + EmbeddingPipe, + args.hidden_size, + args.padded_vocab_size, + args.hidden_dropout, + init_method=init_method, + num_tokentypes=num_tokentypes, + forward_fn=_logits_helper, + tied_weight_attr='word_embeddings_weight') + ) + + # Convert to fp32 if needed + if args.fp16 or args.bf16: + self.specs.append(float16_to_fp32) + + if args.checkpoint_activations: + interval = args.checkpoint_num_layers + else: + interval = 0 + + from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology + topo = PipeModelDataParallelTopology(num_pp=mpu.get_pipeline_model_parallel_world_size(), + num_mp=mpu.get_tensor_model_parallel_world_size(), + num_dp=mpu.get_data_parallel_world_size()) + + # here one can extend the regex to include more layers to be counted towards partitioning, + # e.g. 'type:transformer|embedding' will add up all the transformer blocks and also the first + # and last embedding layers and then partition that transformers+2 layers - so to get a good + # balance you may want to use less transformer layers + # + # caveat emptor: the current implementation of PP fails unless each stage has at least one + # transformer layer + if args.pp_partition_method is not None: + partition_method = args.pp_partition_method + else: + partition_method = 'type:transformer' + + super().__init__(layers=self.specs, + loss_fn=cross_entropy, + topology=topo, + activation_checkpoint_interval=interval, + partition_method=partition_method) diff --git a/megatron/text_generation_utils.py b/megatron/text_generation_utils.py index 7a98b5d35..bd0ec59d8 100644 --- a/megatron/text_generation_utils.py +++ b/megatron/text_generation_utils.py @@ -26,7 +26,7 @@ from megatron import get_args from megatron import get_tokenizer from megatron import mpu -from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model +from megatron.utils import get_attention_masks_and_position_ids, unwrap_model from megatron.p2p_communication import recv_forward, send_forward # These are needed to unwrap the model, would be nice to put these in megatron.utils if possible? @@ -42,7 +42,7 @@ def get_batch(context_tokens): # Move to GPU. tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().cuda() # Get the attention mask and position ids. - attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + attention_mask, _, position_ids = get_attention_masks_and_position_ids( tokens, tokenizer.eod, args.reset_position_ids, diff --git a/megatron/utils.py b/megatron/utils.py index 98d2f611c..29a4e9795 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -151,7 +151,8 @@ def check_adlr_autoresume_termination(iteration, model, sys.exit(0) -def get_ltor_masks_and_position_ids( + +def get_attention_masks_and_position_ids( data, eod_token, reset_position_ids, @@ -159,6 +160,7 @@ def get_ltor_masks_and_position_ids( eod_mask_loss, prefix_indices, loss_on_targets_only, + ltor=True, ): """ Build masks and position id for left to right model. 
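+    When `ltor` is False, the lower-triangular (causal) restriction is skipped and a fully visible
+    attention mask is returned instead; this is what the encoder inputs of the shared T5 model use.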
@@ -177,9 +179,10 @@ def get_ltor_masks_and_position_ids( att_mask_batch = micro_batch_size else: att_mask_batch = 1 - attention_mask = torch.tril(torch.ones( - (att_mask_batch, seq_length, seq_length), device=data.device)).view( - att_mask_batch, 1, seq_length, seq_length) + attention_mask = torch.ones((att_mask_batch, seq_length, seq_length), device=data.device) + if ltor: + attention_mask = torch.tril(attention_mask) + attention_mask = attention_mask.view(att_mask_batch, 1, seq_length, seq_length) # Loss mask. loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 04f1b3b57..c1e9a11f0 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -25,7 +25,7 @@ from megatron.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group from megatron.model import GPTModel, GPTModelPipe from megatron.training import pretrain -from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices +from megatron.utils import get_attention_masks_and_position_ids, get_prefix_indices from megatron.utils import average_losses_across_data_parallel_group import deepspeed @@ -110,7 +110,7 @@ def get_batch(data_iterator): tokens = tokens_[:, :-1].contiguous() # Get the masks and postition ids. - attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + attention_mask, loss_mask, position_ids = get_attention_masks_and_position_ids( tokens, tokenizer.eod, args.reset_position_ids, @@ -141,7 +141,7 @@ def get_batch_pipe(data): tokens = tokens_[:, :-1].contiguous() # Get the masks and position ids. - attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + attention_mask, loss_mask, position_ids = get_attention_masks_and_position_ids( tokens, tokenizer.eod, args.reset_position_ids, diff --git a/pretrain_prefix_lm.py b/pretrain_prefix_lm.py index 391186e75..0e9ed019a 100644 --- a/pretrain_prefix_lm.py +++ b/pretrain_prefix_lm.py @@ -25,7 +25,7 @@ from megatron.data.gpt_dataset import build_train_valid_test_datasets, build_dataset_group from megatron.model import GPTModel, GPTModelPipe from megatron.training import pretrain -from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices, reweight_loss_mask_ +from megatron.utils import get_attention_masks_and_position_ids, get_prefix_indices, reweight_loss_mask_ from megatron.utils import average_losses_across_data_parallel_group import deepspeed @@ -97,7 +97,7 @@ def get_batch(data_iterator): ) # Get the masks and postition ids. - attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + attention_mask, loss_mask, position_ids = get_attention_masks_and_position_ids( tokens, tokenizer.eod, args.reset_position_ids, @@ -131,6 +131,7 @@ def get_batch_pipe(data): tokens = tokens_[:, :-1].contiguous() # Prefix + # TODO @thomasw21 actually since this step is random, we need to make sure that random state are synchronized. Otherwise we need to broadcast after this step. prefix_indices = get_prefix_indices( tokens, tokenizer.eod, @@ -139,7 +140,7 @@ def get_batch_pipe(data): ) # Get the masks and position ids. 
- attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + attention_mask, loss_mask, position_ids = get_attention_masks_and_position_ids( tokens, tokenizer.eod, args.reset_position_ids, diff --git a/pretrain_shared_t5_with_mlm.py b/pretrain_shared_t5_with_mlm.py new file mode 100644 index 000000000..c98ae8081 --- /dev/null +++ b/pretrain_shared_t5_with_mlm.py @@ -0,0 +1,167 @@ +import torch +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_tokenizer +from megatron import mpu +from megatron.data.mlm_dataset import build_train_valid_test_datasets, build_dataset_group +from megatron.model import SharedT5ModelPipe +from megatron.training import pretrain +from megatron.utils import get_attention_masks_and_position_ids + +import deepspeed +from deepspeed.runtime.utils import see_memory_usage + +try: + from torch.distributed.elastic.multiprocessing.errors import record +except ImportError: + # noop + def record(fn): + return fn + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building GPT model ...') + see_memory_usage(f"Before Building Model", force=True) + + args = get_args() + + with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), + remote_device=None if args.remote_device == 'none' else args.remote_device, + config_dict_or_path=args.deepspeed_config, + enabled=args.zero_stage == 3, + mpu=mpu): + if args.deepspeed: + # TODO @thomasw21: fix this for PP > 1 (the issue is that you're passing two values that require grad) + assert mpu.get_pipeline_model_parallel_world_size() == 1, "PP > 1 is not supported yet" + + # TODO @thomasw21 hack to bypass a specific check + args.attn_mask = None + + model = SharedT5ModelPipe( + num_tokentypes=0, + parallel_output=True + ) + # This is a hack to give us a reference to get_batch_pipe from within training.py + # We need to call model.set_batch_fn after deepspeed.initialize + model._megatron_batch_fn = get_batch_pipe + else: + assert False, "Require deepspeed to be running" + see_memory_usage(f"After Building Model", force=True) + return model + + +def get_batch_pipe(data): + """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = ["input_tokens", "target_tokens"] + datatype = torch.int64 + + # Broadcast data. + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + input_tokens = data_b["input_tokens"].long() + target_tokens = data_b["target_tokens"].long()[:,:-1].contiguous() + label_tokens = data_b["target_tokens"].long()[:,1:].contiguous() + + # Get the masks and position ids. 
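+    # The encoder inputs get a fully visible (bidirectional) mask via `ltor=False`, while the decoder
+    # targets below keep the usual causal mask (`ltor=True`) together with the loss mask.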
+ input_attention_mask, _, input_position_ids = get_attention_masks_and_position_ids( + input_tokens, + tokenizer.eod, + reset_position_ids=False, # TODO @thomasw21 doesn't work out of the box + reset_attention_mask=False, # TODO @thomasw21 doesn't work out of the box + eod_mask_loss=False, # TODO @thomasw21 doesn't work out of the box + prefix_indices=None, + loss_on_targets_only=False, + ltor=False + ) + target_attention_mask, target_loss_mask, target_position_ids = get_attention_masks_and_position_ids( + target_tokens, + tokenizer.eod, + reset_position_ids=False, # TODO @thomasw21 doesn't work out of the box + reset_attention_mask=False, # TODO @thomasw21 doesn't work out of the box + eod_mask_loss=False, # TODO @thomasw21 doesn't work out of the box + prefix_indices=None, + loss_on_targets_only=args.loss_on_targets_only, + ltor=True + ) + + return (input_tokens, input_position_ids, input_attention_mask, target_tokens, target_position_ids, target_attention_mask), (label_tokens, target_loss_mask) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + + args = get_args() + train_ds, valid_ds, test_ds = None, None, None + + print_rank_0('> building train, validation, and test datasets for GPT ...') + # Option 1 of data loading using --data-path + + if args.data_path: + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + sequence_length=args.encoder_seq_length + args.decoder_seq_length, + noise_density=args.noise_density, + mean_noise_span_length=args.mean_noise_span_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup) + ) + # Option 2 of data loading using --(train|valid|test)-weighted-split-paths + elif args.train_weighted_split_paths: + assigned_train_valid_test = [] + if args.train_weighted_split_paths is not None: + train_ds = [] + assigned_train_valid_test.append("train") + if args.valid_weighted_split_paths is not None: + valid_ds = [] + assigned_train_valid_test.append("valid") + if args.test_weighted_split_paths is not None: + test_ds = [] + assigned_train_valid_test.append("test") + + for s in assigned_train_valid_test: + data_groups = zip(eval(f"args.{s}_weighted_split_paths"), + eval(f"args.{s}_weighted_split_weights"), + eval(f"args.{s}_weighted_split_splits"), + eval(f"args.{s}_weighted_split_names")) + for paths, weights, splits, name in data_groups: + d = build_dataset_group( + dataset_group_name=name, + paths=paths, + weights=weights, + splits=splits, + data_impl=args.data_impl, + train_valid_test_num_samples=train_val_test_num_samples, + seq_length=args.encoder_seq_length + args.decoder_seq_length, + noise_density=args.noise_density, + mean_noise_span_length=args.mean_noise_span_length, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + train_valid_test=s + ) + eval(f"{s}_ds").append(d) + else: + raise NotImplementedError("No dataloading argument passed") + + print_rank_0("> finished creating MLM datasets ...") + return train_ds, valid_ds, test_ds + +@record +def main(): + pretrain( + train_valid_test_datasets_provider, + model_provider, + # TODO @thomasw21: make it work without DS. 
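+        # Only the DeepSpeed code path is supported for now (see model_provider), so no Megatron-style
+        # forward_step_func is provided here.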
+ forward_step_func=None, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) + +if __name__ == "__main__": + main() diff --git a/scripts/test_multiple_dataset_sampling/test_sampling.py b/scripts/test_multiple_dataset_sampling/test_sampling.py index 2d5326c8c..8bed75c2a 100644 --- a/scripts/test_multiple_dataset_sampling/test_sampling.py +++ b/scripts/test_multiple_dataset_sampling/test_sampling.py @@ -25,7 +25,7 @@ from megatron.data.gpt_dataset import build_train_valid_test_datasets from megatron.model import GPTModel, GPTModelPipe from megatron.training import pretrain -from megatron.utils import get_ltor_masks_and_position_ids +from megatron.utils import get_attention_masks_and_position_ids from megatron.utils import average_losses_across_data_parallel_group import deepspeed @@ -117,7 +117,7 @@ def get_batch(data_iterator): tokens = tokens_[:, :-1].contiguous() # Get the masks and postition ids. - attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + attention_mask, loss_mask, position_ids = get_attention_masks_and_position_ids( tokens, tokenizer.eod, args.reset_position_ids, @@ -144,7 +144,7 @@ def get_batch_pipe(data): tokens = tokens_[:, :-1].contiguous() # Get the masks and postition ids. - attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + attention_mask, loss_mask, position_ids = get_attention_masks_and_position_ids( tokens, tokenizer.eod, args.reset_position_ids, diff --git a/tasks/zeroshot_gpt/evaluate.py b/tasks/zeroshot_gpt/evaluate.py index 090533c24..b17c76848 100644 --- a/tasks/zeroshot_gpt/evaluate.py +++ b/tasks/zeroshot_gpt/evaluate.py @@ -26,7 +26,7 @@ from megatron.checkpointing import load_checkpoint from megatron.model.gpt_model import GPTModel from megatron.training import get_model -from megatron.utils import get_ltor_masks_and_position_ids, unwrap_model +from megatron.utils import get_attention_masks_and_position_ids, unwrap_model from megatron.p2p_communication import recv_forward, send_forward from tasks.finetune_utils import build_data_loader @@ -72,7 +72,7 @@ def process_batch(batch): tokens = tokens_[:, :-1].contiguous() # Get the masks and position ids. 
- attention_mask, _, position_ids = get_ltor_masks_and_position_ids( + attention_mask, _, position_ids = get_attention_masks_and_position_ids( tokens, tokenizer.eod, args.reset_position_ids, diff --git a/tests/test_tensor_parallel.py b/tests/test_tensor_parallel.py index 25921c12a..a0d257404 100644 --- a/tests/test_tensor_parallel.py +++ b/tests/test_tensor_parallel.py @@ -18,7 +18,7 @@ from multiprocessing import Pool from megatron.checkpointing import save_checkpoint -from megatron.utils import get_ltor_masks_and_position_ids +from megatron.utils import get_attention_masks_and_position_ids @require_deepspeed @require_torch_multi_gpu @@ -98,7 +98,7 @@ def infer_model(args): def create_model_inputs(tokens): args = get_args() - attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + attention_mask, loss_mask, position_ids = get_attention_masks_and_position_ids( tokens, tokenizer.eod, args.reset_position_ids, diff --git a/tests/test_training.py b/tests/test_training.py index c77cb9af2..1d4e7efc9 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -592,3 +592,137 @@ def test_skip_train_iteration(self): train_iterations = range(1,10) for i in train_iterations: self.assertTrue(f"iteration {i:8d}/" in cs.out) + + def test_pretrain_shared_t5_mlm(self): + # all in one test + src_dir = self.src_dir + data_dir = f"{self.data_dir}/gpt2" + output_dir = self.get_auto_remove_tmp_dir() # "./xxx", after=False) + logs_dir = f"{output_dir}/logs" + Path(logs_dir).mkdir(parents=True, exist_ok=True) + + pp_size, tp_size, dp_size = get_3d_dimensions() + num_gpus = pp_size * tp_size * dp_size + + # TODO @thomasw21 fix once t5 supports pipeline parallelism + dp_size *= pp_size + pp_size = 1 + + n_samples = 200 # about 37 iterations + exit_interval = 20 # some samples in the first half and then some more in the 2nd half after resume + noise_density=0.15 + mean_noise_span_length=3 + encoder_seq_length = 512 + decoder_seq_length = 114 # imposed by `noise_density=0.15` and `input_sequence_length = 512` + + + args = f""" + --tensor-model-parallel-size {tp_size} + --pipeline-model-parallel-size {pp_size} + --distributed-backend nccl + + --num-layers 2 + --hidden-size 64 + --num-attention-heads 2 + --decoder-seq-length {decoder_seq_length} + --encoder-seq-length {encoder_seq_length} + --max-position-embeddings 1024 + --micro-batch-size 1 + --rampup-batch-size 2 2 {n_samples} + --global-batch-size 16 + --train-samples {n_samples} + + --optimizer adam + --adam-beta1 0.9 + --adam-beta2 0.95 + --adam-eps 1e-8 + --lr 1e-4 + --lr-warmup-samples 5 + --clip-grad 1.0 + --weight-decay 1e-1 + --fp16 + + --log-interval 5 + --save-interval 10 + --eval-interval 10 + --eval-iters 5 + --checkpoint-activations + --exit-interval {exit_interval} + + --tokenizer-type PretrainedFromHF + --tokenizer-name-or-path gpt2 + --vocab-extra-ids 100 + --log-path {logs_dir} + --save {output_dir}/checkpoints + --load {output_dir}/checkpoints + --data-path {data_dir}/meg-gpt2-openwebtext_text_document + --noise-density {noise_density} + --mean-noise-span-length {mean_noise_span_length} + --tensorboard-dir {output_dir}/tensorboard + --tensorboard-queue-size 5 + --log-timers-to-tensorboard + --log-batch-size-to-tensorboard + --log-validation-ppl-to-tensorboard + + --log-level debug + """.split() + + ds_args = f""" + --deepspeed + --deepspeed_config {self.test_file_dir_str}/ds_config.json + --zero-stage 1 + --deepspeed-activation-checkpointing + """.split() + + script = [f"{src_dir}/pretrain_shared_t5_with_mlm.py"] + 
launcher = get_launcher(num_gpus)
+
+        cmd = launcher + script + args + ds_args
+        # keep for quick debug
+        # print(" ".join([f"\nPYTHONPATH={self.src_dir_str}"] +cmd)); die
+
+        # 1. test training from scratch (no checkpoint)
+        with CaptureStdout() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+
+        # test deepspeed is running
+        self.assertIn("DeepSpeed info", cs.out)
+
+        # test reports
+        self.assertIn("consumed samples", cs.out)
+
+        # test there should be no checkpoint this round
+        self.assertIn(f"Unable to find latest file at {output_dir}/checkpoints/latest", cs.out)
+
+        # test checkpoint saving
+        self.assertIn("successfully saved checkpoint at iteration", cs.out)
+
+        # test tensorboard
+        tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
+        self.assertEqual(len(tensorboard_files), 1, "tensorboard files")
+
+        # 2. test training from checkpoint: resume
+        # now do it again, this time resuming from the checkpoint
+        with CaptureStdout() as cs:
+            execute_subprocess_async(cmd, env=self.get_env())
+
+        # test checkpoint loading
+        self.assertIn(f"successfully loaded checkpoint from {output_dir}/checkpoints", cs.out)
+
+        # test reports
+        self.assertIn("consumed samples", cs.out)
+
+        # test checkpoint saving
+        self.assertIn("successfully saved checkpoint at iteration", cs.out)
+
+        # test tensorboard (1 file from the first run, plus 1 now)
+        tensorboard_files = glob.glob(f"{output_dir}/tensorboard/events*")
+        self.assertEqual(len(tensorboard_files), 2, "tensorboard files")
+
+    def test_convert_from_gpt2_to_shared_t5_lm(self):
+        # Test structure:
+        # - run `pretrain_gpt.py` in order to obtain GPT checkpoint
+        # - run conversion script in order to obtain a shared t5 checkpoint
+        # - run `pretrain_shared_t5_with_mlm.py` from that checkpoint
+        raise NotImplementedError("TODO @thomasw21 write script in order to convert gpt2 checkpoint")
+
diff --git a/tools/convert_checkpoint/convert_gpt_to_shared_t5.py b/tools/convert_checkpoint/convert_gpt_to_shared_t5.py
new file mode 100644
index 000000000..9030cf273
--- /dev/null
+++ b/tools/convert_checkpoint/convert_gpt_to_shared_t5.py
@@ -0,0 +1,60 @@
+import argparse
+import re
+from functools import partial
+from multiprocessing import Pool
+from pathlib import Path
+
+import torch
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--gpt-checkpoint-path", type=Path, required=True)
+    parser.add_argument("--output-shared-t5-path", type=Path, required=True)
+    # `store_true` flags must not also set `type`: argparse raises a TypeError otherwise.
+    parser.add_argument("--only-weights", action="store_true")
+    parser.add_argument("--num-proc", type=int, default=1)
+    return parser.parse_args()
+
+def get_shared_t5_file(gpt_path_file: Path, output_shared_t5_path: Path) -> Path:
+    """Given a GPT checkpoint file path, get the equivalent shared T5 checkpoint path"""
+    raise NotImplementedError()
+
+def get_shared_t5_weight_name(gpt_weight_name: str) -> str:
+    """Given a GPT checkpoint weight name, get the equivalent shared T5 checkpoint weight name"""
+    raise NotImplementedError()
+
+def map_gpt_weights_to_shared_t5_weights(filename: Path, output_shared_t5_path: Path):
+    gpt_weights = torch.load(filename)
+
+    shared_t5_filename = get_shared_t5_file(filename, output_shared_t5_path=output_shared_t5_path)
+    shared_t5_weights = {}
+    for name, weight in gpt_weights.items():
+        shared_t5_weight_name = get_shared_t5_weight_name(name)
+        shared_t5_weights[shared_t5_weight_name] = weight
+
+    torch.save(shared_t5_weights, shared_t5_filename)
+
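+# NOTE: this matches DeepSpeed's pipeline-parallel layer shards, which are named like
+# `layer_01-model_00-model_states.pt`; other files in the checkpoint directory are skipped.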
+IS_WEIGHT_REGEX=re.compile(r"layer_[\d]{2}-model_[\d]{2}-model_states.pt") +def is_weight_file(filename: Path) -> bool: + if filename.is_dir(): + return False + + basename = filename.name + return IS_WEIGHT_REGEX.match(basename) is not None + +def main(): + args = get_args() + + weight_files = [filename for filename in args.gpt_checkpoint_path.iterdir() if is_weight_file(filename)] + if args.num_proc == 1: + for weight_file in weight_files: + map_gpt_weights_to_shared_t5_weights(weight_file, output_shared_t5_path=args.output_shared_t5_path) + else: + with Pool(args.num_proc) as pool: + pool.map( + partial(map_gpt_weights_to_shared_t5_weights, output_shared_t5_path=args.output_shared_t5_path), + weight_files + ) + +if __name__ == "__main__": + main() \ No newline at end of file