Skip to content

[Feature] Packing #3060

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions benchmarks/test_llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import importlib.util

import pytest
import torch
from tensordict import set_list_to_stack, TensorDict
from torchrl.data.llm import History
from torchrl.modules.llm.policies.common import ChatHistory
from torchrl.modules.llm.policies.transformers_wrapper import TransformersWrapper

_has_transformers = importlib.import_module("transformers") is not None


@pytest.fixture(scope="module")
def transformers_wrapper():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with torch.device(device):
model = TransformersWrapper(
model="Qwen/Qwen2.5-0.5B",
tokenizer="Qwen/Qwen2.5-0.5B",
pad_model_input=False,
generate=False,
)
return model


@pytest.mark.skipif(not _has_transformers, reason="transformers not installed")
class TestWrappers:
@pytest.mark.parametrize("packing", [True, False])
@set_list_to_stack(True)
def test_packing(self, benchmark, transformers_wrapper, packing: bool):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with torch.device(device):
transformers_wrapper = TransformersWrapper(
model=transformers_wrapper.model,
tokenizer=transformers_wrapper.tokenizer,
pad_model_input=not packing,
generate=False,
pad_output=False,
)
data = TensorDict(
{
"history": ChatHistory(
full=History(
role=[
["user", "assistant"],
["user", "assistant"],
["user", "assistant"],
["user", "assistant"],
],
content=[
[
"Lorem ipsum dolor sit amet",
"consectetur adipiscing elit",
],
[
"sed do eiusmod tempor incididunt ut labore et dolore magna aliqua",
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat",
],
[
"Lorem ipsum dolor sit amet",
"consectetur adipiscing elit",
],
[
"sed do eiusmod tempor incididunt ut labore et dolore magna aliqua",
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat",
],
],
batch_size=(4, 2),
device=device,
),
batch_size=(4,),
device=device,
)
},
batch_size=(4,),
device=device,
).to_lazystack()

def setup():
if torch.cuda.is_available():
torch.cuda.empty_cache()

benchmark.pedantic(
transformers_wrapper,
(data,),
rounds=10,
warmup_rounds=3,
setup=setup,
)
114 changes: 111 additions & 3 deletions test/llm/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import pytest
import torch
from tensordict import lazy_stack, set_list_to_stack, TensorDict
from tensordict import assert_close, lazy_stack, set_list_to_stack, TensorDict

from tensordict.utils import _zip_strict
from torchrl.data.llm import History
Expand Down Expand Up @@ -163,6 +163,22 @@ def sample_tokens(vllm_instance):
return tokenized["input_ids"], tokenized["attention_mask"]


@pytest.fixture
def sample_tokens_unpadded(vllm_instance):
"""Create sample tokens for testing."""
model, tokenizer = vllm_instance
text = [
"Are you happy? Say yes or no.",
"Explain the difference between a cat and a dog. Be very detailed.",
]
tokenized = tokenizer(text, padding=False)
return torch.nested.nested_tensor(
[torch.tensor(t) for t in tokenized["input_ids"]], layout=torch.jagged
), torch.nested.nested_tensor(
[torch.tensor(t) for t in tokenized["attention_mask"]], layout=torch.jagged
)


def check_output_shapes(out, pad_output, requested_log_probs=False):
if pad_output or not out.ndim:
# We can get all tensors or they are none
Expand Down Expand Up @@ -1656,8 +1672,6 @@ def test_log_probs_consistency(
vllm_lp_result = vllm_lp_wrapper(new_data.copy())
tf_lp_result = tf_lp_wrapper(new_data.copy())

from tensordict import assert_close

assert_close(
vllm_lp_result, tf_lp_result, atol=1e-1, rtol=1e-1, intersection=True
)
Expand Down Expand Up @@ -1825,6 +1839,100 @@ def test_transformers_custom_masking(
assert hasattr(dist, "log_prob")


@pytest.mark.skipif(not _has_transformers, reason="transformers not available")
@pytest.mark.parametrize("pad_output", [False, True])
class TestPacking:
def test_packing_history(
self, transformers_instance, sample_history_assistant, pad_output
):
model, tokenizer = transformers_instance

wrapper_packed = TransformersWrapper(
model,
tokenizer=tokenizer,
input_mode="history",
generate=False,
return_log_probs=True,
pad_output=pad_output,
pad_model_input=False,
)
wrapped_padded = TransformersWrapper(
model,
tokenizer=tokenizer,
input_mode="history",
generate=False,
return_log_probs=True,
pad_output=pad_output,
pad_model_input=True,
)

td = TensorDict(
{"history": ChatHistory(full=sample_history_assistant)}, batch_size=(2,)
).to_lazystack(0)

result_padded = wrapped_padded(td)
result_packed = wrapper_packed(td)
assert_close(result_packed["log_probs"], result_padded["log_probs"])

def test_packing_text(self, transformers_instance, sample_text, pad_output):
model, tokenizer = transformers_instance
wrapper_packed = TransformersWrapper(
model,
tokenizer=tokenizer,
input_mode="text",
generate=False,
return_log_probs=True,
pad_output=pad_output,
pad_model_input=False,
)
wrapped_padded = TransformersWrapper(
model,
tokenizer=tokenizer,
input_mode="text",
generate=False,
return_log_probs=True,
pad_output=pad_output,
pad_model_input=True,
)
td = TensorDict({"text": Text(full=sample_text)}, batch_size=(2,))
result_packed = wrapper_packed(td)
result_padded = wrapped_padded(td)
assert_close(result_packed["log_probs"], result_padded["log_probs"])

def test_packing_tokens(
self, transformers_instance, sample_tokens_unpadded, pad_output
):
model, tokenizer = transformers_instance
wrapper_packed = TransformersWrapper(
model,
tokenizer=tokenizer,
input_mode="tokens",
generate=False,
return_log_probs=True,
pad_output=pad_output,
pad_model_input=False,
)
wrapped_padded = TransformersWrapper(
model,
tokenizer=tokenizer,
input_mode="tokens",
generate=False,
return_log_probs=True,
pad_output=pad_output,
pad_model_input=True,
)
td = TensorDict(
{
"tokens": Tokens(full=sample_tokens_unpadded[0]),
"masks": Masks(all_attention_mask=sample_tokens_unpadded[1]),
},
batch_size=(2,),
).to_lazystack(0)
result_padded = wrapped_padded(td)
result_packed = wrapper_packed(td)
assert_close(result_packed["log_probs"], result_padded["log_probs"])


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
6 changes: 4 additions & 2 deletions torchrl/data/llm/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,10 @@ class History(TensorClass["nocast"]):
:class:`~torchrl.modules.llm.policies.Tokens`: Container for token data.
"""

role: str
content: str | ContentBase
role: str | list[str] | list[list[str]]
content: str | ContentBase | list[str] | list[ContentBase] | list[list[str]] | list[
list[ContentBase]
]
is_complete: bool = True
tool_calls: list[dict] | None = None
tool_responses: list[str] | None = None
Expand Down
2 changes: 2 additions & 0 deletions torchrl/modules/llm/policies/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@ class LLMWrapperBase(TensorDictModuleBase):
generate_kwargs: Additional arguments to pass to the model's generate method.
tokenizer_kwargs: Additional arguments to pass to the tokenizer.
pad_output: Whether to pad the output sequences to a uniform length.
pad_model_input: Whether to pad the model input sequences to a uniform length.
May not be supported by all models.
inplace: Determines how the module should handle in-place operations.
device: The device to use for computation.
layout: The layout to use for the output tensors when pad_output=False.
Expand Down
Loading
Loading