feat(input_pipeline): Add support for chunking long sequences instead of truncation #2354
Conversation
Thanks for the feature! And great unit tests! Just some minor comments.
src/MaxText/configs/base.yml
@@ -458,6 +458,7 @@ tokenize_train_data: True # False if the dataset is pre-tokenized
tokenize_eval_data: True # False if the dataset is pre-tokenized
add_bos: True
add_eos: True
use_truncation: True # If False, use chunking for long sequences instead of truncation
Could you point to the class for users who are interested in the detailed implementation?
Thanks for the great feedback! I've pushed the changes addressing all your points:
to: @aireenmei
@@ -95,9 +95,9 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra
  dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))

  assert len(data_columns) == 1
  rekey_dict = {"inputs": "text", "targets": "text"}
  dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
  text_column = "text"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can have text_column = data_columns[0]? Let's move the rekey_dict definition to right before the Rekey op on line 131.
Got it. I've updated the variable to text_column and moved the rekey_dict as requested. The changes have been pushed.
@@ -28,30 +28,26 @@ class TokenizeAndTrim(grain.MapTransform):
  """Tokenize and trim features to sequence length."""

  # pylint: disable=attribute-defined-outside-init
  feature_names: str | Sequence[str]
Let's keep TokenizeAndTrim as it was to allow supporting multiple columns; I think it's needed for the DPO support (#947).
Thanks for the context about the DPO support. I've reverted TokenizeAndTrim back to its original implementation (but refactored as recommended below) to support multiple columns, and I've pushed the change.
Looks like the GitHub Actions tests need to be triggered by a maintainer. Please take a look at the test failures. You can also run them locally.
        )
    )
  else:
    dataset = dataset.apply(
What is the reason for using dataset.apply and not dataset.map, similar to how TokenizeAndTrim is used at line 119?
That is because MapDataset's map only takes 1:1 mapping transformations (returning a MapMapDataset), while apply also supports 1:N transformations (returning a FlatMapMapDataset). You can find it here: https://github.com/google/grain/blob/main/grain/_src/python/dataset/transformations/map.py

I'd also like map to support this, for consistency 😂.
def map(
    self, transform: transforms.MapTransform | Callable[[T], S]
) -> MapDataset[S]:
  """Returns a dataset containing the elements transformed by ``transform``.

  Example usage::

    ds = MapDataset.range(5)
    ds = ds.map(lambda x: x + 10)
    list(ds) == [10, 11, 12, 13, 14]

  Args:
    transform: Either a ``MapTransform`` containing the ``map`` method or a
      callable that takes an element and returns a new element.

  Returns:
    A dataset containing the elements of the original dataset transformed by
    ``transform``.
  """
  # Loaded lazily due to a circular dependency (dataset <-> map).
  # pylint: disable=g-import-not-at-top
  from grain._src.python.dataset.transformations import (
      map as map_dataset,
  )
  # pylint: enable=g-import-not-at-top
  return map_dataset.MapMapDataset(parent=self, transform=transform)


def apply(
    self,
    transformations: transforms.Transformation | transforms.Transformations,
) -> IterDataset:
  """Returns a dataset with the given transformation(s) applied.

  Example usage::

    ds = grain.MapDataset.range(5).to_iter_dataset()
    ds = ds.apply([AddOne(), grain.transforms.Batch(2)])
    list(ds) == [np.ndarray([1, 2]), np.ndarray([3, 4]), np.ndarray([5])]

  Args:
    transformations: one or more transformations to apply.

  Returns:
    Dataset with the given transformations applied.
  """
  return apply_transformations(self, transformations)


def apply_transformations(
    ds: _ConsistentDatasetType,
    transformations: transforms.Transformation | transforms.Transformations,
) -> _ConsistentDatasetType:
  """Applies transformations to a dataset.

  DEPRECATED: Use `ds.apply(transformations)` instead.

  Args:
    ds: `MapDataset` or `IterDataset` to apply the transformations to.
    transformations: one or more transformations to apply.

  Returns:
    Dataset of the same type with transformations applied.
  """
  if not isinstance(transformations, Sequence):
    transformations = (transformations,)
  for transformation in transformations:
    match transformation:
      case transforms.Batch():
        ds = ds.batch(
            transformation.batch_size,
            drop_remainder=transformation.drop_remainder,
        )
      case transforms.MapTransform():
        ds = ds.map(transformation)
      case transforms.RandomMapTransform():
        ds = ds.random_map(transformation)
      case transforms.MapWithIndex():
        ds = ds.map_with_index(transformation)
      case transforms.FlatMapTransform():
        # Loaded lazily due to a circular dependency (dataset <-> flatmap).
        # pylint: disable=g-import-not-at-top
        from grain._src.python.dataset.transformations import flatmap
        # pylint: enable=g-import-not-at-top
        if isinstance(ds, MapDataset):
          ds = flatmap.FlatMapMapDataset(ds, transformation)
        else:
          ds = flatmap.FlatMapIterDataset(ds, transformation)
      case transforms.Filter():
        ds = ds.filter(transformation)
      case _:
        raise NotImplementedError(
            f"Transformation type: {transformation} is not supported."
        )
  return ds
- Added comment that TokenizeAndChunk removes all columns except the text_column
- Modified _grain_tokenizer.py with latest changes
- Added note that use_truncation=False is only available in grain's pretrain preprocessing pipeline
- Move feature_names, sequence_length, add_bos, add_eos, and tokenizer to TokenizerTransformBase
- Consolidate initialization logic in base class __post_init__
- Simplify TokenizeAndTrim and TokenizeAndChunk by removing duplicate parameters
- Add common _encode method to eliminate code duplication
- Maintain backward compatibility and specialized behavior for each class
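A rough sketch of the shape this refactor could take. It is illustrative only: the class, field, and method names come from the notes above, but the signatures, the tokenizer interface, and the element/dictionary layout are assumptions, not the actual MaxText code.

import dataclasses
from collections.abc import Sequence
from typing import Any

import grain  # assumption: recent Grain; MaxText may use "import grain.python as grain"


@dataclasses.dataclass
class TokenizerTransformBase:
  """Shared fields and helpers for the trim and chunk transforms (sketch)."""

  feature_names: str | Sequence[str]
  sequence_length: int | Sequence[int]
  add_bos: bool
  add_eos: bool
  tokenizer: Any  # assumed to expose encode(text) -> list[int]

  def __post_init__(self):
    # Normalize scalar arguments so single- and multi-column usage both work.
    if isinstance(self.feature_names, str):
      self.feature_names = [self.feature_names]
    if isinstance(self.sequence_length, int):
      self.sequence_length = [self.sequence_length] * len(self.feature_names)

  def _encode(self, text: str) -> list[int]:
    # The real implementation also honors add_bos/add_eos; the tokenizer API
    # used here is an assumption.
    return list(self.tokenizer.encode(text))


class TokenizeAndTrim(TokenizerTransformBase, grain.MapTransform):
  """1:1 transform: tokenize each feature and truncate it to its max length."""

  def map(self, element: dict[str, Any]) -> dict[str, Any]:
    for name, max_len in zip(self.feature_names, self.sequence_length):
      element[name] = self._encode(element[name])[:max_len]
    return element


class TokenizeAndChunk(TokenizerTransformBase, grain.experimental.FlatMapTransform):
  """1:N transform: tokenize one text column and split it into chunks."""

  max_fan_out = 2048  # class-level cap on chunks produced per document

  def flat_map(self, element: dict[str, Any]) -> list[dict[str, Any]]:
    name, max_len = self.feature_names[0], self.sequence_length[0]
    ids = self._encode(element[name])
    # All other columns are dropped; each chunk becomes its own example.
    return [{name: ids[i : i + max_len]} for i in range(0, len(ids), max_len)]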
This PR introduces "chunking" as an alternative to "truncation" in the Grain input pipeline.
Previously, the TokenizeAndTrim operation (a MapTransform) would truncate any document longer than max_target_length, discarding all subsequent tokens. This change introduces a new TokenizeAndChunk operation (a FlatMapTransform) that splits a single long document into multiple training examples, each no longer than max_target_length. This new behavior is controlled by a new configuration flag, use_truncation.

Why is this change being made?
The default truncation behavior is highly data-inefficient for corpora with many long documents (like C4). It wastes significant amounts of data, compute, and storage, and may bias the model by only ever training on the beginning of documents.
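As a concrete (hypothetical) illustration: with max_target_length = 2048, a 12,000-token document contributes only 2,048 tokens (about 17% of it) under truncation, whereas chunking turns it into six examples (five of 2,048 tokens plus one of 1,760) and keeps all of it.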
The problem being solved and any relevant context:
This PR solves the problem of data loss during tokenization for long sequences. By using a 1:N FlatMapTransform, we can map one long input document to a list of multiple, valid training chunks, ensuring 100% of the tokenized data is used.

Why this is a good solution:
This solution is efficient and flexible. It utilizes the FlatMapTransform provided by Grain, which is designed for this 1:N mapping. It is also fully backwards-compatible, as the new chunking behavior is "opt-in" by setting use_truncation = False in the config. The default behavior remains truncation.

Some information about the specific implementation:
- _grain_tokenizer.py: A new TokenizeAndChunk class has been added. It inherits from grain.experimental.FlatMapTransform and implements the flat_map method to split a list of token IDs into multiple chunks.
- _grain_data_processing.py: The pretrain_preprocessing_pipeline function has been updated with a conditional check on config.use_truncation (sketched below):
  - If True, it uses the existing dataset.map(TokenizeAndTrim(...)).
  - If False, it uses dataset.apply(TokenizeAndChunk(...)).
- The dataset.apply() method and support for FlatMapTransform are recent features in Grain. This PR requires a version of Grain installed directly from the main branch.
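A minimal sketch of that conditional, continuing the illustrative classes shown earlier; the constructor arguments and the text_column/tokenizer plumbing are assumptions, not the exact MaxText signatures.

def tokenize_stage(dataset, config, text_column, tokenizer):
  """Chooses between 1:1 truncation and 1:N chunking for the tokenize step (sketch)."""
  if config.use_truncation:
    trim = TokenizeAndTrim([text_column], config.max_target_length,
                           config.add_bos, config.add_eos, tokenizer)
    return dataset.map(trim)     # 1:1 - each document is truncated
  chunk = TokenizeAndChunk([text_column], config.max_target_length,
                           config.add_bos, config.add_eos, tokenizer)
  return dataset.apply(chunk)    # 1:N - each document may yield several chunks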
Shortcomings of the solution and possible future improvements.

The max_fan_out attribute in TokenizeAndChunk is set with a class-level default (2048). If a document is exceptionally long and produces more chunks than this, it will error. This could be exposed as a configuration option in the future if needed.
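For scale, the cap is only exceeded by a single document of more than max_fan_out × max_target_length tokens; for example, with max_target_length = 2048 that is 2048 × 2048 ≈ 4.2 million tokens, so the class-level default should be ample for typical corpora.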
Tests

This change is tested with a new, self-contained unit test file: tests/tokenizer_transform_test.py.

- It uses a MockTokenizer to provide known, deterministic tokenization ("a b c" -> [1, 2, 3]).
- It uses grain.MapDataset.source with a small, known dataset to test edge cases (short text, long text, and multi-chunk text).
- test_tokenize_and_trim: Verifies the original 1:1 truncation logic is correct.
- test_tokenize_and_chunk: Verifies the new 1:N chunking logic (e.g., an input with 7 tokens and max_len=5 correctly produces two new examples with 5 and 2 tokens).
- test_trim_and_pad_chaining: Verifies that the output of TokenizeAndTrim can be correctly chained into a subsequent PadToMaxLength transform.
- test_chunk_and_pad_chaining: Verifies that all outputs from TokenizeAndChunk are correctly chained into PadToMaxLength (e.g., both the 5-token chunk and the 2-token chunk are correctly padded).

To reproduce, you can run the new test file directly:
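For example (assuming the test file is pytest-compatible; the exact invocation is not shown in this excerpt): python3 -m pytest tests/tokenizer_transform_test.py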
Fixes: #2344
Checklist
Before submitting this PR, please make sure (put X in square brackets):