Describe the bug
When using chunked attention with left padding, the first chunk may contain fewer valid tokens than chunk_size, which introduces a distribution mismatch between training and inference.
Since training typically uses right padding or no padding, the model never sees a chunk where most positions are padding. This may degrade performance, especially for shorter sequences.
To Reproduce
Example:
- chunk_size = 4
- Two sequences: length = 8 and length = 15
With right padding (common during training):
Seq (len=8):  [x1 x2 x3 x4] [x5 x6 x7 x8] [PAD PAD PAD PAD] [PAD PAD PAD]
Seq (len=15): [x1 x2 x3 x4] [x5 x6 x7 x8] [x9..x12] [x13..x15]
👉 Each chunk contains up to 4 valid tokens, consistent with training.
With left padding (common in batched inference):
Seq (len=8, padded to 15): [PAD PAD PAD PAD] [PAD PAD PAD x1] [x2 x3 x4 x5] [x6 x7 x8]
Seq (len=15): [x1 x2 x3 x4] [x5 x6 x7 x8] [x9..x12] [x13..x15]
👉 The first chunk containing real tokens of the shorter sequence holds only 1 valid token (x1), which never happens during training.
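The split above falls directly out of absolute-position chunking; a tiny standalone sketch (plain Python, not the transformers implementation) makes the arithmetic explicit:

# Standalone sketch: chunk membership from absolute position, i.e. position // chunk_size.
chunk_size = 4

def chunks(tokens, chunk_size):
    groups = {}
    for pos, tok in enumerate(tokens):
        groups.setdefault(pos // chunk_size, []).append(tok)
    return list(groups.values())

right_padded = [f"x{i}" for i in range(1, 9)] + ["PAD"] * 7
left_padded = ["PAD"] * 7 + [f"x{i}" for i in range(1, 9)]

print(chunks(right_padded, chunk_size))  # [x1..x4] [x5..x8] [PAD x4] [PAD x3]
print(chunks(left_padded, chunk_size))   # [PAD x4] [PAD PAD PAD x1] [x2..x5] [x6 x7 x8]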
Minimal code sample:
import torch

from transformers import Llama4TextConfig, Llama4TextModel
from transformers.masking_utils import create_chunked_causal_mask

configuration = Llama4TextConfig(
    hidden_size=512,
    intermediate_size=2048,
    intermediate_size_mlp=4096,
    num_hidden_layers=4,
    num_attention_heads=40,
    num_key_value_heads=8,
    use_cache=True,
    attention_chunk_size=4,
)
model = Llama4TextModel(configuration)

# Batch of two sequences padded to length 15; the first 7 positions of
# sequence 0 are masked out (left padding).
input_ids = torch.tensor([
    [151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 9707, 11, 1246, 525, 498, 30, 28],
    [3838, 374, 279, 6722, 315, 9625, 30, 3555, 374, 279, 6722, 315, 9856, 30, 28],
])
attention_mask = torch.tensor([
    [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
])
inputs_embeds = model.embed_tokens(input_ids)
cache_position = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14])
# position_ids restart from 0 at the first real token of each sequence.
position_ids = torch.tensor([
    [1, 1, 1, 1, 1, 1, 1, 0, 1, 2, 3, 4, 5, 6, 7],
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
])

mask_kwargs = {
    "config": configuration,
    "input_embeds": inputs_embeds,
    "attention_mask": attention_mask,
    "cache_position": cache_position,
    "past_key_values": None,
    "position_ids": position_ids,
}
chunked_attention = create_chunked_causal_mask(**mask_kwargs)
print(chunked_attention[0])
Output:
tensor([[[[False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False],
[False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False],
[False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False],
[False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False],
[False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False],
[False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False],
[False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False],
[False, False, False, False, False, False, False, True, False, False,
False, False, False, False, False],
[False, False, False, False, False, False, False, False, True, False,
False, False, False, False, False],
[False, False, False, False, False, False, False, False, True, True,
False, False, False, False, False],
[False, False, False, False, False, False, False, False, True, True,
True, False, False, False, False],
[False, False, False, False, False, False, False, False, True, True,
True, True, False, False, False],
[False, False, False, False, False, False, False, False, False, False,
False, False, True, False, False],
[False, False, False, False, False, False, False, False, False, False,
False, False, True, True, False],
[False, False, False, False, False, False, False, False, False, False,
False, False, True, True, True]]]])
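Reading the mask above programmatically (using the tensors from the code sample; reshape is used so the leading broadcast dimensions do not matter), one can count how many real keys each real query of the left-padded sequence is allowed to attend to:

# For each real query position of the left-padded sequence, count the real
# (non-pad) key positions the chunked mask lets it attend to.
mask0 = chunked_attention[0].reshape(15, 15)   # (query, key) boolean mask for sequence 0
real_keys = attention_mask[0].bool()           # True where the key is a real token
visible = (mask0 & real_keys).sum(dim=-1)      # attendable real tokens per query
print(visible[attention_mask[0].bool()])
# Per the printed mask: tensor([1, 1, 2, 3, 4, 1, 2, 3]) for x1..x8.
# With right padding (the training layout) the counts would be [1, 2, 3, 4, 1, 2, 3, 4]:
# x2, for example, never sees x1 here, although it always does during training.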
Expected behavior
Chunk partitioning should skip leading padding so that the effective sequence is partitioned as if no left pads existed. This would ensure that all chunks contain up to chunk_size valid tokens, aligning inference with training behavior.
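A rough sketch of that idea (illustration only, not a patch against masking_utils; the helper name below is made up): subtract each sequence's number of leading pads from the query/key positions before computing the chunk index, so the chunk grid starts at the first real token.

import torch

def chunked_mask_skipping_left_pads(attention_mask: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # attention_mask: (batch, seq_len) with 1 for real tokens, 0 for (left) pads.
    batch_size, seq_len = attention_mask.shape
    positions = torch.arange(seq_len)
    # Number of leading pads per sequence.
    left_pads = (attention_mask.cumsum(dim=-1) == 0).sum(dim=-1)             # (batch,)
    # Positions counted from the first real token; pad positions clamp to chunk 0.
    shifted = (positions.unsqueeze(0) - left_pads.unsqueeze(-1)).clamp(min=0)
    chunk_ids = shifted // chunk_size                                        # (batch, seq_len)
    causal = positions.unsqueeze(-1) >= positions.unsqueeze(0)               # (q, k)
    same_chunk = chunk_ids.unsqueeze(-1) == chunk_ids.unsqueeze(-2)          # (batch, q, k)
    real_key = attention_mask.bool().unsqueeze(-2)                           # (batch, 1, k)
    return (causal & same_chunk & real_key).unsqueeze(1)                     # (batch, 1, q, k)

# With the batch from the repro above, the shorter sequence would again be
# chunked as [x1..x4] [x5..x8], matching the right-padded training layout.
fixed_mask = chunked_mask_skipping_left_pads(attention_mask, chunk_size=4)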
Environment info
- transformers version: 4.56.0.dev0
- Platform: Linux-5.10.134-008.16.kangaroo.al8.x86_64-x86_64-with-glibc2.39
- Python version: 3.10.18
- Huggingface_hub version: 0.34.4
- Safetensors version: 0.5.3
- Accelerate version: 1.8.0
- Accelerate config: not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.7.1+cu126 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): 0.7.0 (cpu)
- Jax version: 0.4.13
- JaxLib version: 0.4.13
- Using distributed or parallel set-up in script?: no
- Using GPU in script?: no
- Model: Llama 4 (or other models using chunked attention)
Additional context
- Right padding avoids this issue, but some workflows (e.g. batched decoding, GPT-style left padding) rely on left padding.
- Skipping leading pads when chunking could be a simple fix.
- This may be especially relevant for short sequences, where the first chunk is disproportionately affected.
Would you consider adjusting the chunked attention implementation to handle this case?