5 changes: 5 additions & 0 deletions src/MaxText/configs/base.yml
@@ -458,6 +458,11 @@ tokenize_train_data: True # False if the dataset is pre-tokenized
tokenize_eval_data: True # False if the dataset is pre-tokenized
add_bos: True
add_eos: True
# If False, use chunking for long sequences instead of truncation.
# Note: use_truncation=False is only available in grain's pretrain preprocessing pipeline.
# See the TokenizeAndTrim and TokenizeAndChunk classes in
# `src/MaxText/input_pipeline/_grain_tokenizer.py` for implementation details.
use_truncation: True

# Dataset
per_device_batch_size: 12.0
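As an aside, a minimal sketch (hypothetical values, not MaxText code) of what this flag selects between when a tokenized example is longer than max_target_length:

token_ids = list(range(1, 11))  # pretend the tokenizer produced 10 ids
max_target_length = 4

# use_truncation: True -- TokenizeAndTrim keeps only the first max_target_length ids
truncated = token_ids[:max_target_length]  # [1, 2, 3, 4]; ids 5-10 are dropped

# use_truncation: False -- TokenizeAndChunk emits consecutive chunks, keeping every id
chunks = [token_ids[i : i + max_target_length] for i in range(0, len(token_ids), max_target_length)]
# [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10]]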
26 changes: 18 additions & 8 deletions src/MaxText/input_pipeline/_grain_data_processing.py
@@ -95,9 +95,7 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra
   dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))

   assert len(data_columns) == 1
-  rekey_dict = {"inputs": "text", "targets": "text"}
-  dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
-  data_columns = ("inputs", "targets")
+  text_column = data_columns[0]

   tokenizer_model = tokenizer.build_tokenizer(
       config.tokenizer_path,
@@ -115,11 +113,23 @@
     pad_id = -1

   if tokenize:
-    dataset = dataset.map(
-        _grain_tokenizer.TokenizeAndTrim(
-            data_columns, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model
-        )
-    )
+    if config.use_truncation:
+      dataset = dataset.map(
+          _grain_tokenizer.TokenizeAndTrim(
+              text_column, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model
+          )
+      )
+    else:
+      dataset = dataset.apply(
Collaborator commented:

What is the reason for using dataset.apply and not dataset.map, similar to how TokenizeAndTrim is used at line 119?

@bzantium (Contributor, Author) replied on Sep 19, 2025:

That is because MapDataset's map only takes 1:1 mapping transformations (returning MapMapDataset), while apply supports 1:N transformations (returning FlatMapMapDataset). You can find it here: https://github.com/google/grain/blob/main/grain/_src/python/dataset/transformations/map.py
I also want map to support this, for consistency 😂.

  def map(
      self, transform: transforms.MapTransform | Callable[[T], S]
  ) -> MapDataset[S]:
    """Returns a dataset containing the elements transformed by ``transform``.

    Example usage::

      ds = MapDataset.range(5)
      ds = ds.map(lambda x: x + 10)
      list(ds) == [10, 11, 12, 13, 14]

    Args:
      transform: Either a ``MapTransform`` containing the ``map`` method or a
        callable that takes an element and returns a new element.

    Returns:
      A dataset containing the elements of the original dataset transformed by
      ``transform``.
    """
    # Loaded lazily due to a circular dependency (dataset <-> map).
    # pylint: disable=g-import-not-at-top
    from grain._src.python.dataset.transformations import (
        map as map_dataset,
    )
    # pylint: enable=g-import-not-at-top
    return map_dataset.MapMapDataset(parent=self, transform=transform)
  def apply(
      self,
      transformations: transforms.Transformation | transforms.Transformations,
  ) -> IterDataset:
    """Returns a dataset with the given transformation(s) applied.

    Syntactic sugar to avoid dispatch by transformation type.

    Example usage::

      ds = grain.MapDataset.range(5).to_iter_dataset()
      ds = ds.apply([AddOne(), grain.transforms.Batch(2)])
      list(ds) == [np.ndarray([1, 2]), np.ndarray([3, 4]), np.ndarray([5])]

    Args:
      transformations: one or more transformations to apply.

    Returns:
      Dataset with the given transformations applied.
    """
    return apply_transformations(self, transformations)

def apply_transformations(
    ds: _ConsistentDatasetType,
    transformations: transforms.Transformation | transforms.Transformations,
) -> _ConsistentDatasetType:
  """Applies transformations to a dataset.

  DEPRECATED: Use `ds.apply(transformations)` instead.

  Args:
    ds: `MapDataset` or `IterDataset` to apply the transformations to.
    transformations: one or more transformations to apply.

  Returns:
    Dataset of the same type with transformations applied.
  """
  if not isinstance(transformations, Sequence):
    transformations = (transformations,)
  for transformation in transformations:
    match transformation:
      case transforms.Batch():
        ds = ds.batch(
            transformation.batch_size,
            drop_remainder=transformation.drop_remainder,
        )
      case transforms.MapTransform():
        ds = ds.map(transformation)
      case transforms.RandomMapTransform():
        ds = ds.random_map(transformation)
      case transforms.MapWithIndex():
        ds = ds.map_with_index(transformation)
      case transforms.FlatMapTransform():
        # Loaded lazily due to a circular dependency (dataset <-> flatmap).
        # pylint: disable=g-import-not-at-top
        from grain._src.python.dataset.transformations import flatmap
        # pylint: enable=g-import-not-at-top
        if isinstance(ds, MapDataset):
          ds = flatmap.FlatMapMapDataset(ds, transformation)
        else:
          ds = flatmap.FlatMapIterDataset(ds, transformation)
      case transforms.Filter():
        ds = ds.filter(transformation)
      case _:
        raise NotImplementedError(
            f"Transformation type: {transformation} is not supported."
        )
  return ds
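To make the 1:1 vs 1:N distinction concrete, here is a minimal sketch against grain's public API, mirroring the pipeline shape used in the new tests further down; AddExclamation and SplitWords are hypothetical transforms, not part of this PR:

import grain.python as grain


class AddExclamation(grain.MapTransform):
  """1:1: each element produces exactly one output, so ds.map() accepts it."""

  def map(self, element: str) -> str:
    return element + "!"


class SplitWords(grain.experimental.FlatMapTransform):
  """1:N: one element may expand into several outputs, so it goes through ds.apply()."""

  max_fan_out: int = 16  # upper bound on outputs per element, as TokenizeAndChunk declares

  def flat_map(self, element: str) -> list[str]:
    return element.split(" ")


ds = grain.MapDataset.source(["a b c", "d"]).to_iter_dataset()
print(list(ds.map(AddExclamation())))  # ['a b c!', 'd!']
print(list(ds.apply(SplitWords())))    # ['a', 'b', 'c', 'd']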

+          _grain_tokenizer.TokenizeAndChunk(
+              text_column, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model
+          )
+      )
+
+  data_columns = ("inputs", "targets")
+  rekey_dict = {col: text_column for col in data_columns}
+  dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
+
   # Pack and Batch examples.
   batch_size = config.global_batch_size_to_load // jax.process_count()
   if config.packing:
69 changes: 59 additions & 10 deletions src/MaxText/input_pipeline/_grain_tokenizer.py
@@ -24,8 +24,8 @@


 @dataclasses.dataclass
-class TokenizeAndTrim(grain.MapTransform):
-  """Tokenize and trim features to sequence length."""
+class TokenizerTransformBase:
+  """Base class for tokenizer transforms with common functionality."""

   # pylint: disable=attribute-defined-outside-init
   feature_names: str | Sequence[str]
Collaborator commented:

Let's keep TokenizeAndTrim as it was to allow supporting multiple columns; I think it's needed for the DPO support (#947).

@bzantium (Contributor, Author) replied on Sep 19, 2025:

Thanks for the context about the DPO support. I've reverted TokenizeAndTrim back to its original implementation (but refactored as recommended below) to support multiple columns, and I've pushed the change.
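As an aside, a hedged sketch of the multi-column form this preserves; the column names, lengths, and the tokenizer_model/dataset variables are assumed for illustration (e.g. DPO-style chosen/rejected pairs), not taken from this PR:

from MaxText.input_pipeline import _grain_tokenizer

# Hypothetical multi-column usage: trim two text columns in one pass.
trim_op = _grain_tokenizer.TokenizeAndTrim(
    ["chosen", "rejected"],  # feature_names: one entry per text column
    [2048, 2048],            # sequence_length: one max length per column
    add_bos=True,
    add_eos=True,
    tokenizer=tokenizer_model,
)
dataset = dataset.map(trim_op)  # still a 1:1 MapTransform, so .map() applies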

@@ -37,22 +37,23 @@ class TokenizeAndTrim(grain.MapTransform):
   def __post_init__(self):
     self._processor = None
     self._initialize_processor_lock = threading.Lock()
+    # Convert single values to lists for consistent processing
     if isinstance(self.feature_names, str):
       self.feature_names = [self.feature_names]
     if isinstance(self.sequence_length, int):
       self.sequence_length = [self.sequence_length] * len(self.feature_names)

-  def map(self, element: dict[str, Any]) -> dict[str, Any]:
-    """Maps to each element."""
+  def _get_processor(self):
     if self._processor is None:
       with self._initialize_processor_lock:
-        if self._processor is None:  # Ensures only one thread initializes SPP.
+        if self._processor is None:  # Ensures only one thread initializes processor.
           self._processor = self.tokenizer
-    for feature_name, sequence_length in zip(self.feature_names, self.sequence_length, strict=True):
-      text = element[feature_name]
-      token_ids = self._processor.encode(text)[:sequence_length]
-      element[feature_name] = np.asarray(token_ids, dtype=np.int32)
-    return element
+    return self._processor
+
+  def _encode(self, text: str) -> list[int]:
+    """Common method to encode text using the tokenizer."""
+    processor = self._get_processor()
+    return processor.encode(text)

   def __getstate__(self):
     state = self.__dict__.copy()
@@ -64,3 +65,51 @@ def __setstate__(self, state):
    self.__dict__.update(state)
    self._processor = None
    self._initialize_processor_lock = threading.Lock()


@dataclasses.dataclass
class TokenizeAndTrim(TokenizerTransformBase, grain.MapTransform):
  """Tokenize and trim features to sequence length."""

  def __post_init__(self):
    super().__post_init__()

  def map(self, element: dict[str, Any]) -> dict[str, Any]:
    """Maps to each element."""
    for feature_name, max_length in zip(self.feature_names, self.sequence_length, strict=True):
      text = element[feature_name]
      token_ids = self._encode(text)[:max_length]
      element[feature_name] = np.asarray(token_ids, dtype=np.int32)
    return element


@dataclasses.dataclass
class TokenizeAndChunk(TokenizerTransformBase, grain.experimental.FlatMapTransform):
  """Tokenize and chunk features into multiple examples of sequence length."""

  max_fan_out: int = 2048

  def __post_init__(self):
    super().__post_init__()
    # TokenizeAndChunk only supports single feature for chunking
    assert len(self.feature_names) == 1, "TokenizeAndChunk only supports single feature name"
    assert len(self.sequence_length) == 1, "TokenizeAndChunk only supports single sequence length"
    self.feature_name = self.feature_names[0]  # For backward compatibility
    self.sequence_length = self.sequence_length[0]  # Convert back to int for chunking

  def flat_map(self, element: dict[str, Any]) -> list[dict[str, Any]]:
    text = element[self.feature_name]
    chunk_size = self.sequence_length

    token_ids = self._encode(text)

    if not token_ids:
      return []

    output_elements = []
    for start_idx in range(0, len(token_ids), chunk_size):
      chunk = np.asarray(token_ids[start_idx : start_idx + chunk_size], dtype=np.int32)
      new_element = {self.feature_name: chunk}
      output_elements.append(new_element)

    return output_elements
160 changes: 160 additions & 0 deletions tests/tokenizer_transform_test.py
@@ -0,0 +1,160 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Tests for tokenizer
"""

import unittest

import grain.python as grain
import numpy as np
from MaxText.input_pipeline import _grain_tokenizer
from MaxText.input_pipeline import _input_pipeline_utils
from numpy.testing import assert_array_equal


class MockTokenizer:
  """
  Mocks a tokenizer by splitting on space and mapping letters to simple ints.
  e.g., "a b c" -> [1, 2, 3]
  """

  def encode(self, text: str) -> list[int]:
    if not text:
      return []
    # Simple 'a'=1, 'b'=2, ... mapping
    return [ord(c) - ord('a') + 1 for c in text.split(' ')]


class TokenizerTransformTest(unittest.TestCase):
  """Tests for chunking, trimming, and padding transformations."""

  def setUp(self):
    self.max_len = 5
    self.pad_length = 7
    self.pad_id = 0
    self.text_column = "text"
    self.mock_tokenizer = MockTokenizer()
    self.source_data = [
        {"text": "a b c"},
        {"text": "d e f g h i j"},
        {"text": ""},
        {"text": "k l m n o p q r s t"}
    ]
    self.base_ds = grain.MapDataset.source(self.source_data).to_iter_dataset()

  def test_tokenize_and_trim(self):
    """Tests the 1:1 MapTransform (truncation) logic."""
    trim_op = _grain_tokenizer.TokenizeAndTrim(
        feature_names=self.text_column,
        sequence_length=self.max_len,
        add_bos=False,
        add_eos=False,
        tokenizer=self.mock_tokenizer
    )
    trim_ds = self.base_ds.map(trim_op)
    results = list(trim_ds)
    self.assertEqual(len(results), len(self.source_data))
    expected_inputs = [
        np.array([1, 2, 3], dtype=np.int32),
        np.array([4, 5, 6, 7, 8], dtype=np.int32),
        np.array([], dtype=np.int32),
        np.array([11, 12, 13, 14, 15], dtype=np.int32)
    ]
    result_inputs = [r["text"] for r in results]
    self.assertEqual(len(result_inputs), len(expected_inputs))
    for res, exp in zip(result_inputs, expected_inputs):
      assert_array_equal(res, exp)

  def test_tokenize_and_chunk(self):
    """Tests the 1:N FlatMapTransform (chunking) logic."""
    chunk_op = _grain_tokenizer.TokenizeAndChunk(
        feature_names=self.text_column,
        sequence_length=self.max_len,
        add_bos=False,
        add_eos=False,
        tokenizer=self.mock_tokenizer
    )
    chunk_ds = self.base_ds.apply(chunk_op)
    results = list(chunk_ds)
    self.assertEqual(len(results), 5)
    expected_inputs = [
        np.array([1, 2, 3], dtype=np.int32),
        np.array([4, 5, 6, 7, 8], dtype=np.int32),
        np.array([9, 10], dtype=np.int32),
        np.array([11, 12, 13, 14, 15], dtype=np.int32),
        np.array([16, 17, 18, 19, 20], dtype=np.int32)
    ]
    result_inputs = [r["text"] for r in results]
    self.assertEqual(len(result_inputs), len(expected_inputs))
    for res, exp in zip(result_inputs, expected_inputs):
      assert_array_equal(res, exp)

  def test_trim_and_pad_chaining(self):
    """Tests chaining TokenizeAndTrim.map() -> PadToMaxLength.map()."""
    trim_op = _grain_tokenizer.TokenizeAndTrim(
        feature_names=self.text_column,
        sequence_length=self.max_len,
        add_bos=False,
        add_eos=False,
        tokenizer=self.mock_tokenizer
    )
    pad_op = _input_pipeline_utils.PadToMaxLength(
        max_length=self.pad_length,
        pad_id=self.pad_id
    )
    chained_ds = self.base_ds.map(trim_op).map(pad_op)
    results = list(chained_ds)
    self.assertEqual(len(results), len(self.source_data))
    expected_inputs = [
        np.array([1, 2, 3, 0, 0, 0, 0], dtype=np.int32),
        np.array([4, 5, 6, 7, 8, 0, 0], dtype=np.int32),
        np.array([0, 0, 0, 0, 0, 0, 0], dtype=np.int32),
        np.array([11, 12, 13, 14, 15, 0, 0], dtype=np.int32)
    ]
    result_inputs = [r["text"] for r in results]
    self.assertEqual(len(result_inputs), len(expected_inputs))
    for res, exp in zip(result_inputs, expected_inputs):
      assert_array_equal(res, exp)

  def test_chunk_and_pad_chaining(self):
    """Tests chaining TokenizeAndChunk.apply() -> PadToMaxLength.map()."""
    chunk_op = _grain_tokenizer.TokenizeAndChunk(
        feature_names=self.text_column,
        sequence_length=self.max_len,
        add_bos=False,
        add_eos=False,
        tokenizer=self.mock_tokenizer
    )
    pad_op = _input_pipeline_utils.PadToMaxLength(
        max_length=self.pad_length,
        pad_id=self.pad_id
    )
    chained_ds = self.base_ds.apply(chunk_op).map(pad_op)
    results = list(chained_ds)
    self.assertEqual(len(results), 5)
    expected_inputs = [
        np.array([1, 2, 3, 0, 0, 0, 0], dtype=np.int32),
        np.array([4, 5, 6, 7, 8, 0, 0], dtype=np.int32),
        np.array([9, 10, 0, 0, 0, 0, 0], dtype=np.int32),
        np.array([11, 12, 13, 14, 15, 0, 0], dtype=np.int32),
        np.array([16, 17, 18, 19, 20, 0, 0], dtype=np.int32),
    ]
    result_inputs = [r["text"] for r in results]
    self.assertEqual(len(result_inputs), len(expected_inputs))
    for res, exp in zip(result_inputs, expected_inputs):
      assert_array_equal(res, exp)


if __name__ == "__main__":
  unittest.main()