7 changes: 4 additions & 3 deletions finetune_t0_non_causal_decoder.py → finetune_t0.py
@@ -47,6 +47,7 @@ def model_provider(pre_process=True, post_process=True):
see_memory_usage(f"After Building Model", force=True)
return model


def get_batch_pipe(data):
"""
Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator` & in packed fashion
@@ -83,13 +84,13 @@ def get_batch_pipe(data):
)
# Only compute loss over causal target tokens, i.e. ignore input_tokens & padding
loss_on_targets_only = ~data_c["decoder_is_inputs"][:, 1:]
loss_on_non_pad_only = (tokens != tokenizer.pad)
loss_on_non_pad_only = (labels != tokenizer.pad)
loss_mask *= loss_on_targets_only * loss_on_non_pad_only

attention_mask = get_packed_attention_mask(
# Run non-causal decoder
is_causal=False,
causal_mask=~(causal_mask.bool()),
is_causal=not(args.prefixlm),
causal_mask=~(causal_mask.bool()), # Turn back into tril being ones
decoder_is_inputs=decoder_is_inputs.bool(),
segment_ids=segment_ids.long(),
)
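Aside on the `tokens` → `labels` change above: in `get_batch_pipe` the labels are the input tokens shifted by one position, so a pad check computed on `tokens` is misaligned with the positions the loss is actually taken over. A minimal sketch of the shift, with a made-up pad id and toy values:

```python
import torch

pad = 0                                             # hypothetical pad id
tokens_ = torch.tensor([[5, 6, 7, 8, pad, pad]])    # packed sample with trailing padding
tokens = tokens_[:, :-1].contiguous()               # model inputs
labels = tokens_[:, 1:].contiguous()                # next-token targets

# The loss is computed against `labels`, so padding must be detected on the
# shifted sequence; checking `tokens` would be off by one position.
loss_on_non_pad_only = labels != pad
print(loss_on_non_pad_only)  # tensor([[ True,  True,  True, False, False]])
```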
6 changes: 5 additions & 1 deletion megatron/arguments.py
@@ -476,7 +476,8 @@ def _add_regularization_args(parser):
'numerical stability')
group.add_argument('--sgd-momentum', type=float, default=0.9,
help='Momentum factor for sgd')

group.add_argument('--bitfit', action='store_true',
help='Use BitFit')
return parser


@@ -664,6 +665,8 @@ def _add_checkpointing_args(parser):
help='Do not load optimizer when loading checkpoint.')
group.add_argument('--no-load-rng', action='store_true', default=None,
help='Do not load rng state when loading checkpoint.')
group.add_argument('--reset-progress', action='store_true', default=None,
help='Reset iteration to 0 & do not load args.')
group.add_argument('--finetune', action='store_true',
help='Load model for finetuning. Do not load optimizer '
'or rng state from checkpoint and set iteration to 0. '
@@ -935,6 +938,7 @@ def __call__(self, parser, args, values, option_string=None):
'This is mostly used for prefix_lm training')
group.add_argument("--noise-density", type=float, default=None, help="Span corruption noise density")
group.add_argument("--mean-noise-span-length", type=int, default=None, help="Span corruption mean noise span length")
group.add_argument("--prefixlm", action='store_true', help="Whether to train a PrefixLM - To be used with finetune t0")


return parser
10 changes: 5 additions & 5 deletions megatron/checkpointing.py
@@ -274,8 +274,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
load_dir = getattr(args, load_arg)

if args.deepspeed:
load_optimizer_states = False if args.no_load_optim else True
loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states)
load_optimizer_states = not args.no_load_optim
loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_module_only=not load_optimizer_states, load_optimizer_states=load_optimizer_states, load_lr_scheduler_states=load_optimizer_states)
if loaded_dir is None:
print_rank_0('WARNING: could not find the metadata file {} '.format(
load_dir))
@@ -342,7 +342,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
set_checkpoint_version(state_dict.get('checkpoint_version', 0))

# Set iteration.
if args.finetune or release:
if args.finetune or release or args.reset_progress:
iteration = 0
else:
try:
@@ -361,7 +361,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
# Check arguments.
assert args.consumed_train_samples == 0
assert args.consumed_valid_samples == 0
if 'args' in state_dict:
if 'args' in state_dict and not args.reset_progress:
checkpoint_args = state_dict['args']
if not args.universal_checkpoint:
check_checkpoint_args(checkpoint_args)
@@ -480,4 +480,4 @@ def _checkpoint_info():
return {
"padded_vocab_size": args.padded_vocab_size,
"original_vocab_size": tokenizer.vocab_size,
}
}
23 changes: 12 additions & 11 deletions megatron/data/decoder_packed_mtf_dataset.py
@@ -358,7 +358,7 @@ def pack_samples(self, items):
decoder_tokens[cur_len: cur_len + input_token_len] = token_dict["input_tokens"]
decoder_tokens[cur_len + input_token_len: cur_len + total_len] = token_dict["target_tokens"]
decoder_segment_ids[cur_len: cur_len + total_len] = item_num
decoder_is_inputs[cur_len: cur_len + input_token_len] = 1 # inputs
decoder_is_inputs[cur_len: cur_len + input_token_len] = True # inputs
# targets are already 0 at init, no need to update `decoder_is_inputs`

item_num += 1
@@ -399,7 +399,7 @@ def _build_index_mappings(
shuffle_idx_filename = _filename + '_decoder_packed_shuffle_idx.npy'

# Build the indexed mapping if not exist.
if torch.distributed.get_rank() == 0:
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
if (not os.path.isfile(sample_idx_filename)) or \
(not os.path.isfile(shuffle_idx_filename)):

@@ -437,15 +437,16 @@
print_rank_0(' > elasped time to build and save shuffle-idx and sample-idx mapping'
' (seconds): {:4f}'.format(time.time() - start_time))

# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))
if torch.distributed.is_initialized():
# This should be a barrier but nccl barrier assumes
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group())
assert counts[0].item() == (
torch.distributed.get_world_size() //
torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group()))

# Load mappings.
start_time = time.time()
3 changes: 2 additions & 1 deletion megatron/model/gpt_model.py
@@ -252,7 +252,8 @@ def _to_float16(inputs):
args.num_layers),
layer_number=layer_idx,
# TODO: Change naming of class from GPT to something that encapsulate prefix lm.
self_attn_mask_type=attn_mask_type))
self_attn_mask_type=attn_mask_type)
)

# Undo data format change
def undo(x):
19 changes: 18 additions & 1 deletion megatron/optimizer/__init__.py
@@ -59,6 +59,20 @@ def _get_params_for_weight_decay_optimization(modules):

#return weight_decay_params, no_weight_decay_params

def _get_params_for_bitfit(modules):
"""
Get params to do BitFit
See https://arxiv.org/abs/2106.10199
"""
# No weight (bias) decay is used
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module in modules:
for module_ in module.modules():
no_weight_decay_params['params'].extend(
[p for n, p in list(module_._parameters.items())
if p is not None and 'bias' in n])

return no_weight_decay_params,

def get_megatron_optimizer(model):
args = get_args()
@@ -67,7 +81,10 @@ def get_megatron_optimizer(model):
raise NotImplementedError('need to add cpu adam')

# Base optimizer.
param_groups = _get_params_for_weight_decay_optimization(model)
if args.bitfit:
param_groups = _get_params_for_bitfit(model)
else:
param_groups = _get_params_for_weight_decay_optimization(model)
if args.optimizer == 'adam':
if args.use_bnb_optimizer:
import bitsandbytes as bnb
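For readers unfamiliar with BitFit, a self-contained sketch of the same bias-only selection outside Megatron; the toy model and learning rate are made up, and only the collection pattern mirrors `_get_params_for_bitfit` above:

```python
import torch
import torch.nn as nn

# Toy stand-in for the Megatron module list; layer sizes and lr are illustrative only.
modules = [nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 2))]

# BitFit (https://arxiv.org/abs/2106.10199): optimize only bias terms,
# grouped together with weight decay disabled.
no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
for module in modules:
    for module_ in module.modules():
        no_weight_decay_params['params'].extend(
            p for n, p in module_._parameters.items()
            if p is not None and 'bias' in n)

optimizer = torch.optim.Adam([no_weight_decay_params], lr=1e-4)
print(len(no_weight_decay_params['params']))  # 3 bias tensors in this toy model
```

Note that, as written, the helper does not freeze non-bias parameters; they are simply left out of every optimizer group, so they keep their values while still accumulating gradients.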
3 changes: 2 additions & 1 deletion megatron/tokenizer/tokenizer.py
@@ -380,7 +380,8 @@ def mask(self):

@property
def bos(self):
raise NotImplementedError("Missing <bos>")
candidate = self.tokenizer.bos_token_id
return self._check_token_candidate(candidate)

@property
def eos(self):
30 changes: 23 additions & 7 deletions megatron/utils.py
@@ -261,11 +261,18 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
- segment_ids: torch.IntTensor [batch_size, sequence_length]
Returns:
- attention_mask: torch.BoolTensor [batch_size, 1, sequence_length, sequence_length]

Input example for the mask examples:
att_mask_batch = 1
seq_length = 7
decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]])
segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]])
causal_mask = torch.tril(torch.ones(att_mask_batch, seq_length, seq_length)).view(att_mask_batch, 1, seq_length, seq_length)
"""

"""Causal Inputs Mask:
mask = [[[[1, 1, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
mask = [[[[1, 1, 0, 1, 1, 0, 0],
[1, 1, 0, 1, 1, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 0, 0],
@@ -299,7 +306,7 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0]]]]
[0, 0, 0, 0, 0, 0, 1]]]]
"""
segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :]

@@ -311,13 +318,22 @@ def get_packed_attention_mask(is_causal: bool, causal_mask: torch.Tensor, decode
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0]]]]

If is_causal=True:
mask = [[[[1, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 1, 0, 0],
[0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0]]]]

"""
attention_mask = causal_inputs_mask * padding_mask * segment_mask

# Convert attention mask to binary:
attention_mask = (attention_mask < 0.5)
attention_mask = causal_inputs_mask * padding_mask * segment_mask

return attention_mask
# True for places we do not want to attend to
return ~attention_mask

def param_size(parameter):
return parameter.ds_numel if hasattr(parameter, 'ds_id') else parameter.nelement()
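To make the docstring matrices above easy to reproduce, here is a standalone sketch of how the final mask is assembled. The constructions of `causal_inputs_mask` and `padding_mask` are inferred from the example matrices (they are not shown verbatim in the visible part of the diff), so treat this as an approximation of the real helper:

```python
import torch

# Toy inputs taken from the docstring example.
b, s = 1, 7
decoder_is_inputs = torch.tensor([[1, 1, 0, 1, 1, 0, 0]]).bool()
segment_ids = torch.tensor([[1, 1, 1, 2, 2, 2, 0]]).long()
causal_mask = torch.tril(torch.ones(b, s, s)).view(b, 1, s, s).bool()

# Non-causal over inputs: any two input positions may attend to each other.
inputs_mask = decoder_is_inputs[:, None, :, None] * decoder_is_inputs[:, None, None, :]
causal_inputs_mask = causal_mask | inputs_mask

# Disallow attention to/from padding (segment id 0) and across packed segments.
not_padding = segment_ids != 0
padding_mask = not_padding[:, None, :, None] * not_padding[:, None, None, :]
segment_mask = segment_ids[:, None, :, None] == segment_ids[:, None, None, :]

attention_mask = causal_inputs_mask * padding_mask * segment_mask
print((~attention_mask).int())  # 1 marks positions that must not be attended to
```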
6 changes: 3 additions & 3 deletions tests/test_dataloaders.py
@@ -7,7 +7,7 @@
import deepspeed
import torch

import finetune_t0_non_causal_decoder
import finetune_t0
from megatron import global_vars, get_tokenizer, initialize_megatron, get_args
from megatron.data import mlm_dataset, mtf_dataset, decoder_packed_mtf_dataset
from megatron.data.data_samplers import build_pretraining_data_loader
@@ -241,7 +241,7 @@ def test_decoder_packed_mtf_dataloader(self):
last_padding_size = len([None for segment_id in items["decoder_segment_ids"][micro_batch_size - 1] if segment_id == 0])


def test_finetune_t0_non_causal_decoder_get_batch_pipe(self):
def test_finetune_t0_get_batch_pipe(self):
command_args = get_default_args()
command_args["--position-embedding-type"] = "alibi"

@@ -263,7 +263,7 @@ def test_finetune_t0_non_causal_decoder_get_batch_pipe(self):
special_tokens_ids={tokenizer.pad}
)

(tokens, position_ids, attention_mask), (labels, loss_mask) = finetune_t0_non_causal_decoder.get_batch_pipe(data)
(tokens, position_ids, attention_mask), (labels, loss_mask) = finetune_t0.get_batch_pipe(data)

tokens = tokens.cpu()
position_ids = position_ids.cpu()
4 changes: 2 additions & 2 deletions tests/test_model.py
@@ -20,7 +20,7 @@
from megatron.training import setup_model_and_optimizer
import pretrain_gpt
import pretrain_prefix_lm
import finetune_t0_non_causal_decoder
import finetune_t0


def get_default_args(test_file_dir: str):
@@ -456,7 +456,7 @@ def test_non_causal_decoder_model_with_packed_input_passed_with_attention_mask_i
vocab_size=args.padded_vocab_size,
special_tokens_ids={tokenizer.pad}
)
model, _, _ = setup_model_and_optimizer(finetune_t0_non_causal_decoder.model_provider)
model, _, _ = setup_model_and_optimizer(finetune_t0.model_provider)
model = model[0]
model._config.train_micro_batch_size_per_gpu = args.micro_batch_size
model.set_train_batch_size(args.micro_batch_size)
2 changes: 1 addition & 1 deletion tests/test_training.py
@@ -557,7 +557,7 @@ def test_training_t0(self):
--deepspeed-activation-checkpointing
""".split()

script = [f"{self.src_dir}/finetune_t0_non_causal_decoder.py"]
script = [f"{self.src_dir}/finetune_t0.py"]
launcher = get_launcher(num_gpus)

cmd = launcher + script + args + ds_args
13 changes: 11 additions & 2 deletions tools/preprocess_data.py
@@ -82,13 +82,18 @@ def encode(self, json_line):
ids = {}
for key in self.args.json_keys:
text = data[key]
if self.args.prepend_space:
text = f" {text}"
doc_ids = []
for sentence in Encoder.splitter.tokenize(text):
sentence_ids = Encoder.tokenizer.tokenize(sentence)
if len(sentence_ids) > 0:
doc_ids.append(sentence_ids)
if len(doc_ids) > 0 and self.args.append_eod:
doc_ids[-1].append(Encoder.tokenizer.eod)
if len(doc_ids) > 0:
if self.args.append_eod:
doc_ids[-1].append(Encoder.tokenizer.eod)
elif self.args.append_bos:
doc_ids[-1].append(Encoder.tokenizer.bos)
ids[key] = doc_ids
return ids, len(json_line)

@@ -117,6 +122,10 @@ def get_args():
help='Path to the BPE merge file (if necessary).')
group.add_argument('--append-eod', action='store_true',
help='Append an <eod> token to the end of a document.')
group.add_argument('--append-bos', action='store_true',
help='Append a bos token to the end of a document.')
group.add_argument('--prepend-space', action='store_true',
help='Prepends a space to the beginning of a document')
group.add_argument("--tokenizer-name-or-path", type=str, default=None,
help="Name or path of the huggingface tokenizer.")
group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
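Finally, a small sketch of the resulting `encode` behaviour with the two new preprocessing flags; the toy tokenizer and its ids are made up, and only the flag logic mirrors the diff (note that `--append-bos` has no effect when `--append-eod` is also set, because of the `elif`):

```python
# Toy tokenizer: only `tokenize`, `eod` and `bos` mirror the attributes used in encode().
class ToyTokenizer:
    eod = 2
    bos = 1
    def tokenize(self, text):
        return [len(word) for word in text.split()]  # fake ids: word lengths

def encode_doc(text, tokenizer, prepend_space=False, append_eod=False, append_bos=False):
    if prepend_space:
        text = f" {text}"
    doc_ids = [tokenizer.tokenize(text)]  # single "sentence" for brevity
    if doc_ids and doc_ids[-1]:
        if append_eod:
            doc_ids[-1].append(tokenizer.eod)
        elif append_bos:
            doc_ids[-1].append(tokenizer.bos)
    return doc_ids

print(encode_doc("hello world", ToyTokenizer(), append_bos=True))  # [[5, 5, 1]]
```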