
Conversation

@bzantium (Contributor) commented on Sep 16, 2025

This PR introduces "chunking" as an alternative to "truncation" in the Grain input pipeline.

Previously, the TokenizeAndTrim operation (MapTransform) would truncate any document longer than max_target_length, discarding all subsequent tokens. This change introduces a new TokenizeAndChunk operation (FlatMapTransform) that splits a single long document into multiple training examples, each no longer than max_target_length.

This new behavior is controlled by a new configuration flag, use_truncation.

  • Why is this change being made?
    The default truncation behavior is highly data-inefficient for corpora with many long documents (like C4). It wastes significant amounts of data, compute, and storage, and may bias the model by only ever training on the beginning of documents.

  • The problem being solved and any relevant context:
    This PR solves the problem of data loss during tokenization for long sequences. By using a 1:N FlatMapTransform, we can map one long input document to a list of multiple, valid training chunks, ensuring 100% of the tokenized data is used.

  • Why this is a good solution:
    This solution is efficient and flexible. It utilizes the FlatMapTransform provided by Grain, which is designed for this 1:N mapping. It is also fully backwards-compatible, as the new chunking behavior is "opt-in" by setting use_truncation = False in the config. The default behavior remains truncation.

  • Some information about the specific implementation:

    1. _grain_tokenizer.py: A new TokenizeAndChunk class has been added. It inherits from grain.experimental.FlatMapTransform and implements the flat_map method to split a list of token IDs into multiple chunks (see the sketch after this list).
    2. _grain_data_processing.py: The pretrain_preprocessing_pipeline function has been updated with a conditional check for config.use_truncation:
      • If True, it uses the existing dataset.map(TokenizeAndTrim(...)).
      • If False, it uses dataset.apply(TokenizeAndChunk(...)).
    3. Requirement: The dataset.apply() method and support for FlatMapTransform are recent features in Grain. This PR requires a version of Grain installed directly from the main branch.
      pip install git+https://github.com/google/grain.git
  • Shortcomings of the solution and possible future improvements.
    The max_fan_out attribute in TokenizeAndChunk is set with a class-level default (2048). If a document is exceptionally long and produces more chunks than this, it will error. This could be exposed as a configuration option in the future if needed.
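
For illustration, here is a minimal, self-contained sketch of the idea above. This is not the PR's actual code: ToyTokenizer, ChunkTokens, and the dict keys are made up for the example, while grain.experimental.FlatMapTransform, flat_map, max_fan_out, and dataset.apply() are the pieces named in this description.

import dataclasses

import grain


class ToyTokenizer:
  """Stand-in tokenizer: one integer id per whitespace-separated word."""

  def encode(self, text: str) -> list[int]:
    return list(range(1, len(text.split()) + 1))


@dataclasses.dataclass
class ChunkTokens(grain.experimental.FlatMapTransform):
  """Toy 1:N transform, roughly what TokenizeAndChunk does."""

  tokenizer: object  # anything exposing encode(str) -> list[int]
  max_target_length: int
  max_fan_out: int = 2048  # cap on chunks per document, as in the PR

  def flat_map(self, element: dict) -> list[dict]:
    ids = self.tokenizer.encode(element["text"])
    # One output example per max_target_length-sized slice; no tokens are dropped.
    return [
        {"inputs": ids[i : i + self.max_target_length]}
        for i in range(0, len(ids), self.max_target_length)
    ]


ds = grain.MapDataset.source([{"text": "one two three four five six seven"}])
# use_truncation = True keeps the existing 1:1 path:  ds.map(TokenizeAndTrim(...))
# use_truncation = False takes the new 1:N path:
ds = ds.apply(ChunkTokens(ToyTokenizer(), max_target_length=5))
# The 7-token document becomes two examples: 5 tokens and 2 tokens.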

Tests

This change is tested with a new, self-contained unit test file: tests/tokenizer_transform_test.py.

  • This test does not require real data (like C4) or JAX/TPU.
  • It uses a MockTokenizer to provide known, deterministic tokenization ("a b c" -> [1, 2, 3]).
  • It uses an in-memory grain.MapDataset.source with a small, known dataset to test edge cases (short text, long text, and multi-chunk text).
  • Four separate test cases were added to verify the logic:
    1. test_tokenize_and_trim: Verifies the original 1:1 truncation logic is correct.
    2. test_tokenize_and_chunk: Verifies the new 1:N chunking logic (e.g., an input with 7 tokens and max_len=5 correctly produces two new examples with 5 and 2 tokens).
    3. test_trim_and_pad_chaining: Verifies that the output of TokenizeAndTrim can be correctly chained into a subsequent PadToMaxLength transform.
    4. test_chunk_and_pad_chaining: Verifies that all outputs from TokenizeAndChunk are correctly chained into PadToMaxLength (e.g., both the 5-token chunk and the 2-token chunk are correctly padded).

To reproduce, you can run the new test file directly:

python -m unittest tests/tokenizer_transform_test.py
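
For reference, here is a condensed, illustrative version of the chunking test. It reuses the ChunkTokens/ToyTokenizer sketch from the description above; the real MockTokenizer and assertions in tests/tokenizer_transform_test.py may differ.

import unittest


class MockTokenizer:
  """Deterministic stand-in: 'a b c' -> [1, 2, 3]."""

  def encode(self, text: str) -> list[int]:
    return list(range(1, len(text.split()) + 1))


class TokenizeAndChunkTest(unittest.TestCase):

  def test_seven_tokens_with_max_len_five_yield_two_chunks(self):
    transform = ChunkTokens(MockTokenizer(), max_target_length=5)
    chunks = transform.flat_map({"text": "t1 t2 t3 t4 t5 t6 t7"})
    # Expect a 5-token chunk followed by a 2-token chunk, with nothing dropped.
    self.assertEqual([len(c["inputs"]) for c in chunks], [5, 2])


if __name__ == "__main__":
  unittest.main()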

Fixes: #2344

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@aireenmei (Collaborator) left a comment:

Thanks for the feature! And great unit tests! Just some minor comments.

@@ -458,6 +458,7 @@ tokenize_train_data: True # False if the dataset is pre-tokenized
tokenize_eval_data: True # False if the dataset is pre-tokenized
add_bos: True
add_eos: True
use_truncation: True # If False, use chunking for long sequences instead of truncation
aireenmei (Collaborator):

Could you point to the class for users who are interested in the detailed implementation?

@bzantium (Contributor, Author):

Thanks for the great feedback!

I've pushed the changes addressing all your points:

  • base.yml comment: Added the comment to the use_truncation flag to point to the implementation classes as you suggested.
  • Tokenizer Refactoring:
    • I've simplified both TokenizeAndTrim and TokenizeAndChunk to operate on a single text_column.
    • The Rekey transform has been moved to execute after the tokenization step for both code paths, which cleans up the pipeline logic nicely (see the sketch below).

to: @aireenmei
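
Roughly, the reordering reads as in the toy sketch below (not the PR's code; RekeyToInputsTargets stands in for the real Rekey op in _input_pipeline_utils, which now runs after either tokenization path):

import grain


class RekeyToInputsTargets(grain.transforms.MapTransform):
  """Toy stand-in for the Rekey step, applied after tokenization."""

  def map(self, element: dict) -> dict:
    return {"inputs": element["text"], "targets": element["text"]}


# Pretend the single text column has already been tokenized (trimmed or chunked).
ds = grain.MapDataset.source([{"text": [1, 2, 3]}])
ds = ds.map(RekeyToInputsTargets())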

@@ -95,9 +95,9 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra
dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))

assert len(data_columns) == 1
rekey_dict = {"inputs": "text", "targets": "text"}
dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
text_column = "text"
aireenmei (Collaborator):

Can we have text_column = data_columns[0]? Also, let's move the rekey_dict definition to right before the Rekey op on line 131.

@bzantium (Contributor, Author):

Got it. I've updated the variable to text_column and moved the rekey_dict as requested. The changes have been pushed.

@@ -28,30 +28,26 @@ class TokenizeAndTrim(grain.MapTransform):
"""Tokenize and trim features to sequence length."""

# pylint: disable=attribute-defined-outside-init
feature_names: str | Sequence[str]
aireenmei (Collaborator):

Let's keep TokenizeAndTrim as it was to allow supporting multiple columns; I think it's needed for the DPO support (#947).

@bzantium (Contributor, Author) commented on Sep 19, 2025:

Thanks for the context about the DPO support. I've reverted TokenizeAndTrim back to its original implementation (but refactored as recommended below) to support multiple columns, and I've pushed the change.

@aireenmei (Collaborator):

Looks like the GitHub Actions tests need to be triggered by a maintainer. Please take a look at the test failures; you can also run them locally.

)
)
else:
dataset = dataset.apply(
aireenmei (Collaborator):

What is the reason for using dataset.apply and not dataset.map, similar to how TokenizeAndTrim uses it at line 119?

@bzantium (Contributor, Author) commented on Sep 19, 2025:

That is because MapDataset's map only accepts 1:1 mapping transformations (returning a MapMapDataset), while apply supports 1:N transformations (returning a FlatMapMapDataset). You can find it here: https://github.com/google/grain/blob/main/grain/_src/python/dataset/transformations/map.py
I would also like map to support this, for consistency 😂.

  def map(
      self, transform: transforms.MapTransform | Callable[[T], S]
  ) -> MapDataset[S]:
    """Returns a dataset containing the elements transformed by ``transform``.

    Example usage::

      ds = MapDataset.range(5)
      ds = ds.map(lambda x: x + 10)
      list(ds) == [10, 11, 12, 13, 14]

    Args:
      transform: Either a ``MapTransform`` containing the ``map`` method or a
        callable that takes an element and returns a new element.

    Returns:
      A dataset containing the elements of the original dataset transformed by
      ``transform``.
    """
    # Loaded lazily due to a circular dependency (dataset <-> map).
    # pylint: disable=g-import-not-at-top
    from grain._src.python.dataset.transformations import (
        map as map_dataset,
    )
    # pylint: enable=g-import-not-at-top
    return map_dataset.MapMapDataset(parent=self, transform=transform)

  def apply(
      self,
      transformations: transforms.Transformation | transforms.Transformations,
  ) -> IterDataset:
    """Returns a dataset with the given transformation(s) applied.

    Syntactic sugar to avoid dispatch by transformation type.

    Example usage::

      ds = grain.MapDataset.range(5).to_iter_dataset()
      ds = ds.apply([AddOne(), grain.transforms.Batch(2)])
      list(ds) == [np.ndarray([1, 2]), np.ndarray([3, 4]), np.ndarray([5])]

    Args:
      transformations: one or more transformations to apply.

    Returns:
      Dataset with the given transformations applied.
    """
    return apply_transformations(self, transformations)

def apply_transformations(
    ds: _ConsistentDatasetType,
    transformations: transforms.Transformation | transforms.Transformations,
) -> _ConsistentDatasetType:
  """Applies transformations to a dataset.

  DEPRECATED: Use `ds.apply(transformations)` instead.

  Args:
    ds: `MapDataset` or `IterDataset` to apply the transformations to.
    transformations: one or more transformations to apply.

  Returns:
    Dataset of the same type with transformations applied.
  """
  if not isinstance(transformations, Sequence):
    transformations = (transformations,)
  for transformation in transformations:
    match transformation:
      case transforms.Batch():
        ds = ds.batch(
            transformation.batch_size,
            drop_remainder=transformation.drop_remainder,
        )
      case transforms.MapTransform():
        ds = ds.map(transformation)
      case transforms.RandomMapTransform():
        ds = ds.random_map(transformation)
      case transforms.MapWithIndex():
        ds = ds.map_with_index(transformation)
      case transforms.FlatMapTransform():
        # Loaded lazily due to a circular dependency (dataset <-> flatmap).
        # pylint: disable=g-import-not-at-top
        from grain._src.python.dataset.transformations import flatmap
        # pylint: enable=g-import-not-at-top
        if isinstance(ds, MapDataset):
          ds = flatmap.FlatMapMapDataset(ds, transformation)
        else:
          ds = flatmap.FlatMapIterDataset(ds, transformation)
      case transforms.Filter():
        ds = ds.filter(transformation)
      case _:
        raise NotImplementedError(
            f"Transformation type: {transformation} is not supported."
        )
  return ds

- Added comment that TokenizeAndChunk removes all columns except the text_column
- Modified _grain_tokenizer.py with latest changes
- Added note that use_truncation=False is only available in grain's pretrain preprocessing pipeline
- Move feature_names, sequence_length, add_bos, add_eos, and tokenizer to TokenizerTransformBase
- Consolidate initialization logic in base class __post_init__
- Simplify TokenizeAndTrim and TokenizeAndChunk by removing duplicate parameters
- Add common _encode method to eliminate code duplication
- Maintain backward compatibility and specialized behavior for each class
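
A minimal sketch of the structure those bullets describe (the shapes, field names, and BOS/EOS handling are assumptions, not the PR's actual code; only the split into a shared TokenizerTransformBase, a common _encode helper, and the two subclasses is taken from the bullets above):

import dataclasses
from collections.abc import Sequence

import grain


@dataclasses.dataclass
class TokenizerTransformBase:
  """Shared fields and tokenization helper for both transforms."""

  feature_names: str | Sequence[str]
  sequence_length: int | Sequence[int]
  add_bos: bool
  add_eos: bool
  tokenizer: object  # anything exposing encode(str) -> list[int]

  def __post_init__(self):
    # Normalize scalar arguments so subclasses can iterate uniformly.
    if isinstance(self.feature_names, str):
      self.feature_names = [self.feature_names]
    if isinstance(self.sequence_length, int):
      self.sequence_length = [self.sequence_length] * len(self.feature_names)

  def _encode(self, text: str) -> list[int]:
    # Common tokenization path; the BOS/EOS ids here are placeholders.
    ids = list(self.tokenizer.encode(text))
    if self.add_bos:
      ids = [0] + ids
    if self.add_eos:
      ids = ids + [1]
    return ids


@dataclasses.dataclass
class TokenizeAndTrim(TokenizerTransformBase, grain.transforms.MapTransform):
  """1:1 path: each listed feature is tokenized and truncated in place."""

  def map(self, element: dict) -> dict:
    for name, length in zip(self.feature_names, self.sequence_length):
      element[name] = self._encode(element[name])[:length]
    return element


@dataclasses.dataclass
class TokenizeAndChunk(TokenizerTransformBase, grain.experimental.FlatMapTransform):
  """1:N path: the single text column is split into <= sequence_length chunks."""

  max_fan_out: int = 2048

  def flat_map(self, element: dict) -> list[dict]:
    name, length = self.feature_names[0], self.sequence_length[0]
    ids = self._encode(element[name])
    return [{name: ids[i : i + length]} for i in range(0, len(ids), length)]
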
Successfully merging this pull request may close these issues.

[feature request] Support chunking (splitting) long sequences instead of truncation during tokenization