
Conversation

lintangsutawika (Contributor)

No description provided.

thomasw21 (Member) left a comment:

Had a quick look at MLMDataset. Thanks again for tackling this issue! Will read more when I have more time.

                           'They are used for span masking in the T5 model')
    group.add_argument('--seq-length', type=int, default=None,
                       help='Maximum sequence length to process.')
    group.add_argument('--input-length', type=int, default=None,
thomasw21 (Member):

This feels like a weird argument given that you already set seq-length

lintangsutawika (Contributor Author):

As far as I've understood, seq-length also initializes the max position embeddings. So for MLM, an input of 512 tokens results in a final sequence length of 626.

I didn't know how to handle this discrepancy, so I added a new input-length argument to handle the 512-token input and kept seq-length for the final length of 626.
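For reference, the final length follows from the span-corruption arithmetic. Here is a minimal sketch adapted from random_spans_helper in the T5 codebase (the function name and the noise_density / mean_noise_span_length defaults are assumptions, matching the values discussed later in this thread):

def tokens_to_inputs_and_targets_length(tokens_length,
                                        noise_density=0.15,
                                        mean_noise_span_length=3.0):
    # tokens that will be replaced by noise spans
    num_noise_tokens = int(round(tokens_length * noise_density))
    num_noise_spans = max(int(round(num_noise_tokens / mean_noise_span_length)), 1)
    # inputs keep the non-noise tokens, plus one sentinel per span and an EOS
    inputs_length = tokens_length - num_noise_tokens + num_noise_spans + 1
    # targets hold the noise tokens, plus one sentinel per span and an EOS
    targets_length = num_noise_tokens + num_noise_spans + 1
    return inputs_length, targets_length

>>> tokens_to_inputs_and_targets_length(568)
(512, 114)  # 512 + 114 = 626, the discrepancy described above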

Comment on lines +1 to +14
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
thomasw21 (Member):

Not needed

from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset


def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
thomasw21 (Member):

What's the difference with

def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
?

If the goal is to replace GPTDataset with NonCausalMLMDataset it seems like a lot of duplication.

thomasw21 (Member):

Also, Megatron-LM already has some sort of T5 dataset; why not use that one? https://github.com/NVIDIA/Megatron-LM/blob/d898a8991d1a08d29074f87819d1bf41517e35f5/megatron/data/t5_dataset.py#L29

lintangsutawika (Contributor Author):

I adjusted the build_train_valid_test_datasets input arguments.

Also, the T5 dataset didn't account for the prefix length. And my understanding is that get_samples_mapping in T5Dataset doesn't always return the required sequence length without padding. So I adapted from the original T5 codebase and also wrote a get_samples_mapping that concatenates every obtained sample up to the desired input length.


tokenizer = AutoTokenizer.from_pretrained('bigscience/tokenizer')

tokenizer.add_special_tokens({
thomasw21 (Member):

Maybe this could be additional config to _AutoTokenizer?

lintangsutawika (Contributor Author):

I didn't want to change that, since each user might have a different preference for their tokenizer.

thomasw21 (Member):

Hmm, but you are making some assumptions when building the dataset: self.sentinel_tokens = tokenizer.additional_special_tokens_ids ...

lintangsutawika (Contributor Author):

Oh yeah, that didn't get used. Maybe it's better to remove that line.

thomasw21 (Member):

They did get used (otherwise your MLM implementation sounds wrong: you do need to provide a list of sentinel tokens).

Comment on lines +23 to +24
INPUT_LEN=1675
TARGET_LEN=373
thomasw21 (Member):

I'm thinking we need a way to compute these values from a given SEQ_LEN. Typically, given a noise_density, mean_noise_span_length, and sequence_length, we should be able to compute an input and a target length, no? The reason is that what we really care about is that SEQ_LEN is 2048 (for performance); the rest we don't really care about, as they are implementation details.

lintangsutawika (Contributor Author):

Yes, I agree. But I'm not sure where to put this function.

Comment on lines +421 to +433
tokens_length = inputs_length

while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
    tokens_length += 1

inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)

# minor hack to get the targets length to be equal to inputs length
# which is more likely to have been set to a nice round number.
if noise_density == 0.5 and targets_length > inputs_length:
    tokens_length -= 1
    targets_length -= 1
return tokens_length, targets_length
thomasw21 (Member):

Suggested change

Original:

tokens_length = inputs_length
while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length:
    tokens_length += 1
inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
# minor hack to get the targets length to be equal to inputs length
# which is more likely to have been set to a nice round number.
if noise_density == 0.5 and targets_length > inputs_length:
    tokens_length -= 1
    targets_length -= 1
return tokens_length, targets_length

Suggested:

tokens_length = inputs_length
while sum(_tokens_length_to_inputs_length_targets_length(tokens_length)) > inputs_length:
    tokens_length -= 1
inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length)
# minor hack to get the targets length to be equal to inputs length
# which is more likely to have been set to a nice round number.
if noise_density == 0.5 and targets_length > inputs_length:
    tokens_length -= 1
    targets_length -= 1
# tokens_length is the number of raw tokens we need to get
# inputs_length will be the input
# targets_length will be the target
return tokens_length, inputs_length, targets_length

So typically:

>>> compute_input_and_target_lengths(2048, 0.15, 3)
(1860, 1675, 373)

And so you really only need to pass the sequence length in arguments.py.

lintangsutawika (Contributor Author):

Thanks for the input; this actually settles the arguments discussion above.

Comment on lines 348 to 350

@property
def eod(self):
    return self.tokenizer.eos_token_id
thomasw21 (Member) commented Jun 18, 2022:

So we've overloaded <eos>, as it signals:

  • end of document
  • separation between input and target in the MLM case.

Since we already use eos as eod, can we create a new special token then? Maybe use sep instead? cc @SaulLu as she's our guru.
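A minimal sketch of that idea, assuming the HF tokenizer from the diff below and a hypothetical <sep> token name:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bigscience/tokenizer')
# register a dedicated separator so <eos> keeps meaning "end of document"
tokenizer.add_special_tokens({'sep_token': '<sep>'})
sep_id = tokenizer.sep_token_id  # use between input and target instead of eos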

labels_sentinel = create_sentinel_ids(labels_mask.astype(np.int8), vocab_len=len(vocab_id_list))

tokens = np.asarray([tokens])
input_tokens_ids = filter_input_ids(tokens, input_ids_sentinel, eos_id)[0]
thomasw21 (Member) commented Jun 18, 2022:

Typically this shouldn't use eos_id, as we already use it to separate documents.

current_len += _len

print_rank_0(' > done building sapmles index maping')
np.save(indexmap_filename, samples_mapping, allow_pickle=True)
thomasw21 (Member):

Unless I'm mistaken, you're missing a shuffling step here. I think it's quite important to shuffle it, otherwise you have a single document within a single batch, and I believe @leandro and @loubnabnl have seen poor performance when you don't actually shuffle.

lintangsutawika (Contributor Author):

I was under the assumption that any pretraining data would already be shuffled, or at least not have any continuity between samples. Wouldn't shuffling during training also happen within the dataloader, not during data preparation in the dataset object?

I'm not sure how to shuffle in this part during initialization of the dataset.

thomasw21 (Member):

It's shuffled here for gpt dataset:

shuffle_idx = _build_shuffle_idx(num_samples_,

So there are a few things:

  • document-level shuffling, which can be assumed, I guess
  • row-level shuffling, i.e. once you create your segments of length 2048. Since you process documents sequentially, consecutive rows are often from the same document, and that's usually an issue because you build a batch from consecutive rows. To change that, the GPT dataset has a shuffling index that lets you shuffle everything together (should be the link I shared) before saving; see the sketch below.
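A minimal sketch of that row-level shuffling, along the lines of _build_shuffle_idx in the GPT dataset (the function signature here is an assumption):

import numpy as np

def build_shuffle_idx(num_samples, seed):
    # one global permutation over row indices; rows are then read through
    # this permutation so consecutive batch entries come from different documents
    rng = np.random.RandomState(seed)
    shuffle_idx = np.arange(num_samples, dtype=np.int64)
    rng.shuffle(shuffle_idx)
    return shuffle_idx

# usage: sample i of the epoch reads samples_mapping[shuffle_idx[i]]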

thomasw21 (Member):

DataLoader should not be loading things in random order:

def __iter__(self):
    batch = []
    # Last batch will be dropped if drop_last is not set False
    for idx in range(self.consumed_samples, self.total_samples):
        batch.append(idx)
        if len(batch) == self.micro_batch_times_data_parallel_size:
            start_idx, end_idx = self.get_start_end_idx()
            yield batch[start_idx:end_idx]
            batch = []
    # Check the last partial batch and see drop_last is set
    if len(batch) > 0 and not self.drop_last:
        start_idx, end_idx = self.get_start_end_idx()
        yield batch[start_idx:end_idx]

lintangsutawika (Contributor Author):

Okay, I can add a shuffling process then.

leandro:

@thomasw21 I guess you mentioned me here by accident.

thomasw21 (Member):

Oops @leandro my bad, I meant @lvwerra

lvwerra:

Yes, we saw that very long documents in the pretraining corpus could dominate a single batch leading to significantly slower convergence.

Comment on lines +304 to +325
for doc_idx, sample_len in zip(indexed_dataset.doc_idx, indexed_dataset.sizes):
    _idx = 0

    if current_len + sample_len > max_len:
        end_idx = max_len - current_len
        sample_indices.append([doc_idx, 0, end_idx])
        samples_mapping.append(sample_indices)
        sample_indices = []
        current_len = 0
        sample_len -= end_idx
        _idx = end_idx

    break_len = current_len + sample_len

    indices = breakdown(sample_len, max_len=max_len)
    for _start_idx, _end_idx in indices:
        _len = _end_idx - _start_idx
        if _len == max_len:
            samples_mapping.append([[doc_idx, _start_idx+_idx, _end_idx+_idx]])
        else:
            sample_indices.append([doc_idx, _start_idx+_idx, _end_idx+_idx])
            current_len += _len
thomasw21 (Member):

Are you doing the same thing as helpers.build_sample_idx?

lintangsutawika (Contributor Author):

I'm concatenating sequences up to the desired sequence length, like in run_t5_mlm_flax.py:

https://github.com/huggingface/transformers/blob/6589e510fa4e6c442059de2fab84752535de9b23/examples/flax/language-modeling/run_t5_mlm_flax.py#L678-L691

I noticed that get_samples_mapping from dataset_utils.py doesn't always return a sequence with the desired length, which then requires padding, but that would mess up the T5-style MLM.
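Roughly, the referenced run_t5_mlm_flax.py logic looks like this (a sketch assuming examples maps feature names to lists of tokenized documents, and expanded_inputs_length is the raw tokens_length computed earlier in the thread):

from itertools import chain

def group_texts(examples, expanded_inputs_length):
    # concatenate all documents, then slice into fixed-size chunks so every
    # sample has exactly the raw length needed for span corruption, no padding
    concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    # drop the tail that doesn't fill a full chunk
    total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
    return {
        k: [t[i:i + expanded_inputs_length]
            for i in range(0, total_length, expanded_inputs_length)]
        for k, t in concatenated.items()
    }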

thomasw21 (Member):

Hmm, this feels very close to what I expect helpers.build_sample_idx does. There's a Python version of that same code:

def _build_sample_idx(sizes, doc_idx, seq_length,

Comment on lines +385 to +387
input_ids = np.concatenate(
    [input_ids, np.full((batch_size, 1), eos_id, dtype=np.int32)], axis=-1
)
thomasw21 (Member):

Wait, this feels like input_ids would be 1676 instead of 1675. I don't see where you make sure that it's 1675 so that the final sequence length is 2048.

lintangsutawika (Contributor Author):

I can re-run to make sure.
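A hedged sanity check for that re-run could look like this (input_tokens_ids appears in the diff above; target_tokens_ids is a hypothetical name for its counterpart):

# hypothetical check after building one training sample
assert len(input_tokens_ids) == 1675, len(input_tokens_ids)
assert len(target_tokens_ids) == 373, len(target_tokens_ids)
# input + target must add up to the model's SEQ_LEN
assert len(input_tokens_ids) + len(target_tokens_ids) == 2048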

Comment on lines +280 to +282
indexmap_filename = data_prefix
indexmap_filename += '_{}_indexmap'.format(name)
indexmap_filename += '.npy'
thomasw21 (Member):

Can we change the naming to something that mentions it's the MLM dataset? Ideally we could have one index for the GPT dataset and one for the MLM dataset based on the same underlying data.

thomasw21 (Member):

Also, stupid question: why do you have only one file? The GPT dataset seems to generate three: samples / index / shuffle.
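On the naming point, a minimal sketch (the _mlm tag is a hypothetical choice, not what the PR currently uses):

# hypothetical: tag the cache so GPT and MLM indexmaps built from the same
# data_prefix don't collide
indexmap_filename = '{}_{}_mlm_indexmap.npy'.format(data_prefix, name)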

for doc_idx, start_index, end_index in indices:
    sample.append(self.indexed_dataset.get(doc_idx)[start_index:end_index])

return build_training_sample(
thomasw21 (Member) commented Jun 22, 2022:

I think we should cache that result, i.e. create a number_of_samples x 2048 array with the already-preprocessed MLM samples. I can clearly see that step being a bit slow, so loading directly from an array should make things much faster. Actually, scratch that: it requires some thinking, since you might not be able to load the whole dataset in memory, so you'd need to read partially from disk...
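One way to get both speed and bounded memory is a disk-backed array; a minimal sketch with numpy memmap (the file name and sizes are assumptions):

import numpy as np

num_samples, seq_length = 1_000_000, 2048  # assumed sizes

# preprocessing: write each preprocessed MLM sample once
cache = np.memmap('mlm_samples.bin', dtype=np.int32, mode='w+',
                  shape=(num_samples, seq_length))
# ... cache[i] = preprocessed sample i ...
cache.flush()

# training: rows are paged in from disk on demand, so memory stays bounded
cache = np.memmap('mlm_samples.bin', dtype=np.int32, mode='r',
                  shape=(num_samples, seq_length))
sample = cache[42]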

Muennighoff (Collaborator):

This PR has been fully superseded by mlm & mtf, correct?

lintangsutawika (Contributor Author):

Yes, we can close this if needed.

thomasw21 (Member) commented Jun 26, 2022:

Let's keep it open. There's still the code from the training scripts. We can close it when everything is merged.

lintangsutawika (Contributor Author):

Closing, as the MTF and MLM datasets have been merged in separate PRs.

lintangsutawika deleted the mt0 branch July 22, 2022 03:28
adammoody pushed a commit to adammoody/Megatron-DeepSpeed that referenced this pull request Dec 18, 2023