feat(input_pipeline): Add support for chunking long sequences instead of truncation #2354
base: main
Changes from all commits
2c0512e
04e8255
4138263
b0b78c3
3316f5d
755ad01
8334e32
d4dfc19
fb770c7
@@ -24,8 +24,8 @@
 @dataclasses.dataclass
-class TokenizeAndTrim(grain.MapTransform):
-  """Tokenize and trim features to sequence length."""
+class TokenizerTransformBase:
+  """Base class for tokenizer transforms with common functionality."""

   # pylint: disable=attribute-defined-outside-init
   feature_names: str | Sequence[str]
Review comment:
Let's keep TokenizeAndTrim as it was to allow supporting multiple columns; I think it's needed for the DPO support (#947).

Reply:
Thanks for the context about the DPO support. I've reverted.
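For illustration of the multi-column case raised above, a minimal sketch of how TokenizeAndTrim could serve two DPO-style columns. It assumes only the fields visible in this diff (feature_names, sequence_length, tokenizer); the "chosen"/"rejected" column names and the toy tokenizer are hypothetical, not part of this PR.

from MaxText.input_pipeline import _grain_tokenizer

class ToyTokenizer:
  """Toy stand-in tokenizer: encodes each word as its length."""
  def encode(self, text: str) -> list[int]:
    return [len(word) for word in text.split()]

# Hypothetical DPO-style usage: one max length per column.
trim_op = _grain_tokenizer.TokenizeAndTrim(
    feature_names=["chosen", "rejected"],
    sequence_length=[4, 4],
    tokenizer=ToyTokenizer(),
)
element = {"chosen": "a bb ccc dddd eeeee", "rejected": "x yy"}
out = trim_op.map(element)
# out["chosen"]   -> int32 array [1, 2, 3, 4]  (trimmed to 4 tokens)
# out["rejected"] -> int32 array [1, 2]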
@@ -37,22 +37,23 @@ class TokenizeAndTrim(grain.MapTransform):
   def __post_init__(self):
     self._processor = None
     self._initialize_processor_lock = threading.Lock()
+    # Convert single values to lists for consistent processing
     if isinstance(self.feature_names, str):
       self.feature_names = [self.feature_names]
     if isinstance(self.sequence_length, int):
       self.sequence_length = [self.sequence_length] * len(self.feature_names)

-  def map(self, element: dict[str, Any]) -> dict[str, Any]:
-    """Maps to each element."""
+  def _get_processor(self):
     if self._processor is None:
       with self._initialize_processor_lock:
-        if self._processor is None:  # Ensures only one thread initializes SPP.
+        if self._processor is None:  # Ensures only one thread initializes processor.
           self._processor = self.tokenizer
-    for feature_name, sequence_length in zip(self.feature_names, self.sequence_length, strict=True):
-      text = element[feature_name]
-      token_ids = self._processor.encode(text)[:sequence_length]
-      element[feature_name] = np.asarray(token_ids, dtype=np.int32)
-    return element
+    return self._processor
+
+  def _encode(self, text: str) -> list[int]:
+    """Common method to encode text using the tokenizer."""
+    processor = self._get_processor()
+    return processor.encode(text)

   def __getstate__(self):
     state = self.__dict__.copy()

@@ -64,3 +65,51 @@ def __setstate__(self, state):
     self.__dict__.update(state)
     self._processor = None
     self._initialize_processor_lock = threading.Lock()
+
+
+@dataclasses.dataclass
+class TokenizeAndTrim(TokenizerTransformBase, grain.MapTransform):
+  """Tokenize and trim features to sequence length."""
+
+  def __post_init__(self):
+    super().__post_init__()
+
+  def map(self, element: dict[str, Any]) -> dict[str, Any]:
+    """Maps to each element."""
+    for feature_name, max_length in zip(self.feature_names, self.sequence_length, strict=True):
+      text = element[feature_name]
+      token_ids = self._encode(text)[:max_length]
+      element[feature_name] = np.asarray(token_ids, dtype=np.int32)
+    return element
+
+
+@dataclasses.dataclass
+class TokenizeAndChunk(TokenizerTransformBase, grain.experimental.FlatMapTransform):
+  """Tokenize and chunk features into multiple examples of sequence length."""
+
+  max_fan_out: int = 2048
+
+  def __post_init__(self):
+    super().__post_init__()
+    # TokenizeAndChunk only supports single feature for chunking
+    assert len(self.feature_names) == 1, "TokenizeAndChunk only supports single feature name"
+    assert len(self.sequence_length) == 1, "TokenizeAndChunk only supports single sequence length"
+    self.feature_name = self.feature_names[0]  # For backward compatibility
+    self.sequence_length = self.sequence_length[0]  # Convert back to int for chunking
+
+  def flat_map(self, element: dict[str, Any]) -> list[dict[str, Any]]:
+    text = element[self.feature_name]
+    chunk_size = self.sequence_length
+
+    token_ids = self._encode(text)
+
+    if not token_ids:
+      return []
+
+    output_elements = []
+    for start_idx in range(0, len(token_ids), chunk_size):
+      chunk = np.asarray(token_ids[start_idx : start_idx + chunk_size], dtype=np.int32)
+      new_element = {self.feature_name: chunk}
+      output_elements.append(new_element)
+
+    return output_elements
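As a quick reference for the fan-out behavior, a standalone sketch of the chunking arithmetic used by flat_map above: a sequence of n tokens yields ceil(n / sequence_length) output elements, and only the last chunk may be shorter. The chunk_tokens helper name is illustrative only, not part of the PR.

import numpy as np

def chunk_tokens(token_ids: list[int], chunk_size: int) -> list[np.ndarray]:
  """Split token_ids into consecutive chunks of at most chunk_size tokens."""
  if not token_ids:
    return []
  return [
      np.asarray(token_ids[start : start + chunk_size], dtype=np.int32)
      for start in range(0, len(token_ids), chunk_size)
  ]

chunks = chunk_tokens(list(range(1, 13)), chunk_size=5)
# ceil(12 / 5) == 3 chunks: [1..5], [6..10], [11, 12]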
@@ -0,0 +1,160 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for tokenizer."""

import unittest

import grain.python as grain
import numpy as np
from MaxText.input_pipeline import _grain_tokenizer
from MaxText.input_pipeline import _input_pipeline_utils
from numpy.testing import assert_array_equal


class MockTokenizer:
  """
  Mocks a tokenizer by splitting on space and mapping letters to simple ints.
  e.g., "a b c" -> [1, 2, 3]
  """

  def encode(self, text: str) -> list[int]:
    if not text:
      return []
    # Simple 'a'=1, 'b'=2, ... mapping
    return [ord(c) - ord('a') + 1 for c in text.split(' ')]


class TokenizerTransformTest(unittest.TestCase):
  """Tests for chunking, trimming, and padding transformations."""

  def setUp(self):
    self.max_len = 5
    self.pad_length = 7
    self.pad_id = 0
    self.text_column = "text"
    self.mock_tokenizer = MockTokenizer()
    self.source_data = [
        {"text": "a b c"},
        {"text": "d e f g h i j"},
        {"text": ""},
        {"text": "k l m n o p q r s t"}
    ]
    self.base_ds = grain.MapDataset.source(self.source_data).to_iter_dataset()

  def test_tokenize_and_trim(self):
    """Tests the 1:1 MapTransform (truncation) logic."""
    trim_op = _grain_tokenizer.TokenizeAndTrim(
        text_column=self.text_column,
        sequence_length=self.max_len,
        add_bos=False,
        add_eos=False,
        tokenizer=self.mock_tokenizer
    )
    trim_ds = self.base_ds.map(trim_op)
    results = list(trim_ds)
    self.assertEqual(len(results), len(self.source_data))
    expected_inputs = [
        np.array([1, 2, 3], dtype=np.int32),
        np.array([4, 5, 6, 7, 8], dtype=np.int32),
        np.array([], dtype=np.int32),
        np.array([11, 12, 13, 14, 15], dtype=np.int32)
    ]
    result_inputs = [r["text"] for r in results]
    self.assertEqual(len(result_inputs), len(expected_inputs))
    for res, exp in zip(result_inputs, expected_inputs):
      assert_array_equal(res, exp)

  def test_tokenize_and_chunk(self):
    """Tests the 1:N FlatMapTransform (chunking) logic."""
    chunk_op = _grain_tokenizer.TokenizeAndChunk(
        text_column=self.text_column,
        sequence_length=self.max_len,
        add_bos=False,
        add_eos=False,
        tokenizer=self.mock_tokenizer
    )
    chunk_ds = self.base_ds.apply(chunk_op)
    results = list(chunk_ds)
    self.assertEqual(len(results), 5)
    expected_inputs = [
        np.array([1, 2, 3], dtype=np.int32),
        np.array([4, 5, 6, 7, 8], dtype=np.int32),
        np.array([9, 10], dtype=np.int32),
        np.array([11, 12, 13, 14, 15], dtype=np.int32),
        np.array([16, 17, 18, 19, 20], dtype=np.int32)
    ]
    result_inputs = [r["text"] for r in results]
    self.assertEqual(len(result_inputs), len(expected_inputs))
    for res, exp in zip(result_inputs, expected_inputs):
      assert_array_equal(res, exp)

  def test_trim_and_pad_chaining(self):
    """Tests chaining TokenizeAndTrim.map() -> PadToMaxLength.map()."""
    trim_op = _grain_tokenizer.TokenizeAndTrim(
        text_column=self.text_column,
        sequence_length=self.max_len,
        add_bos=False,
        add_eos=False,
        tokenizer=self.mock_tokenizer
    )
    pad_op = _input_pipeline_utils.PadToMaxLength(
        max_length=self.pad_length,
        pad_id=self.pad_id
    )
    chained_ds = self.base_ds.map(trim_op).map(pad_op)
    results = list(chained_ds)
    self.assertEqual(len(results), len(self.source_data))
    expected_inputs = [
        np.array([1, 2, 3, 0, 0, 0, 0], dtype=np.int32),
        np.array([4, 5, 6, 7, 8, 0, 0], dtype=np.int32),
        np.array([0, 0, 0, 0, 0, 0, 0], dtype=np.int32),
        np.array([11, 12, 13, 14, 15, 0, 0], dtype=np.int32)
    ]
    result_inputs = [r["text"] for r in results]
    self.assertEqual(len(result_inputs), len(expected_inputs))
    for res, exp in zip(result_inputs, expected_inputs):
      assert_array_equal(res, exp)

  def test_chunk_and_pad_chaining(self):
    """Tests chaining TokenizeAndChunk.apply() -> PadToMaxLength.map()."""
    chunk_op = _grain_tokenizer.TokenizeAndChunk(
        text_column=self.text_column,
        sequence_length=self.max_len,
        add_bos=False,
        add_eos=False,
        tokenizer=self.mock_tokenizer
    )
    pad_op = _input_pipeline_utils.PadToMaxLength(
        max_length=self.pad_length,
        pad_id=self.pad_id
    )
    chained_ds = self.base_ds.apply(chunk_op).map(pad_op)
    results = list(chained_ds)
    self.assertEqual(len(results), 5)
    expected_inputs = [
        np.array([1, 2, 3, 0, 0, 0, 0], dtype=np.int32),
        np.array([4, 5, 6, 7, 8, 0, 0], dtype=np.int32),
        np.array([9, 10, 0, 0, 0, 0, 0], dtype=np.int32),
        np.array([11, 12, 13, 14, 15, 0, 0], dtype=np.int32),
        np.array([16, 17, 18, 19, 20, 0, 0], dtype=np.int32),
    ]
    result_inputs = [r["text"] for r in results]
    self.assertEqual(len(result_inputs), len(expected_inputs))
    for res, exp in zip(result_inputs, expected_inputs):
      assert_array_equal(res, exp)


if __name__ == "__main__":
  unittest.main()
Review comment:
What is the reason for using dataset.apply and not dataset.map, similar to how TokenizeAndTrim uses it at line 119?

Reply:
That is because MapDataset's map only takes 1:1 mapping transformations (returning MapMapDataset), while apply supports 1:N transformations (returning FlatMapMapDataset). You can find it here: https://github.com/google/grain/blob/main/grain/_src/python/dataset/transformations/map.py
I also want it to support this with map for consistency 😂
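To make the 1:1 vs 1:N distinction concrete, a minimal sketch that mirrors the grain usage already shown in this PR's tests (MapDataset.source, map, to_iter_dataset, apply). The AddOne and Duplicate transforms are illustrative only and assume the same grain behavior the tests rely on.

import grain.python as grain

class AddOne(grain.MapTransform):
  """1:1 transform: exactly one output element per input element."""
  def map(self, x):
    return x + 1

class Duplicate(grain.experimental.FlatMapTransform):
  """1:N transform: may emit several output elements per input element."""
  max_fan_out = 2  # upper bound on outputs per input

  def flat_map(self, x):
    return [x, x]

ds = grain.MapDataset.source([1, 2, 3])
print(list(ds.map(AddOne())))                         # expected: [2, 3, 4]
print(list(ds.to_iter_dataset().apply(Duplicate())))  # expected: [1, 1, 2, 2, 3, 3]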