Feature request
Is there a tutorial for using DeepSpeed's activation checkpointing instead of PyTorch's?
I'm using Trainer with the ZeRO integration to train my model. Here's my code:
if training_args.deepspeed_gradient_checkpointing and training_args.deepspeed:
    from deepspeed.runtime.activation_checkpointing.checkpointing import configure
    configure(mpu_=None)
    from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
    model._set_gradient_checkpointing(training_args.deepspeed_gradient_checkpointing, checkpoint)
My DeepSpeed config:
{
    "activation_checkpointing": {
        "partition_activations": true,
        "cpu_checkpointing": true,
        "contiguous_memory_optimization": false,
        "number_checkpoints": null,
        "synchronize_checkpoint_boundary": false,
        "profile": false
    }
}
Launch command:
torchrun --nproc_per_node=8 \
    --nnodes=${NNODES} \
    --node_rank=${NODE_RANK} \
    --master_addr=${MASTER_ADDR} \
    --master_port=${MASTER_PORT} \
    train.py \
    --deepspeed ${DEEPSPEED_CONFIG_PATH} \
    --gradient_checkpointing False
However, I got the following error in the FlashAttention2 forward pass:
class XXXFlashAttention2(XXXAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        output_attentions = False
        bsz, q_len, _ = hidden_states.size()  # <-- the error is raised here
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
  File "modeling_xxx.py", line 518, in forward
    bsz, q_len, _ = hidden_states.size()
ValueError: not enough values to unpack (expected 3, got 2)
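The traceback means hidden_states reached the attention layer with two dimensions instead of the three the unpacking expects (presumably batch, sequence, hidden). One way to narrow down where the rank changes might be to wrap the checkpoint function and check the first argument both before the checkpoint call and inside the checkpointed function. The following is only a debugging sketch: with_shape_check and describe_args are hypothetical helper names, and it assumes the checkpoint function is called positionally as checkpoint(function, *args) with hidden_states as the first argument.

```python
from typing import Any, Callable, List, Tuple


def describe_args(args: Tuple[Any, ...]) -> List[str]:
    """Describe each positional argument by its shape, or by its type name if it has none."""
    described = []
    for index, arg in enumerate(args):
        shape = getattr(arg, "shape", None)
        if shape is not None:
            described.append(f"arg{index}: shape={tuple(shape)}")
        else:
            described.append(f"arg{index}: {type(arg).__name__}")
    return described


def with_shape_check(checkpoint_fn: Callable[..., Any], expected_ndim: int = 3) -> Callable[..., Any]:
    """Hypothetical helper: wrap a checkpoint function so a wrong-rank first argument fails early.

    Assumes checkpoint_fn is called as checkpoint_fn(function, *args) and that
    args[0] is the hidden_states tensor.
    """

    def check(where: str, args: Tuple[Any, ...]) -> None:
        if not args:
            return
        shape = getattr(args[0], "shape", None)
        if shape is not None and len(shape) != expected_ndim:
            raise ValueError(
                f"{where}: expected a {expected_ndim}-D first argument, got "
                + "; ".join(describe_args(args))
            )

    def wrapped(function: Callable[..., Any], *args: Any) -> Any:
        # Check what the model passes into the checkpoint call.
        check("before checkpoint", args)

        def checked_function(*inner_args: Any) -> Any:
            # Check what the checkpoint implementation hands to the layer.
            check("inside checkpointed function", inner_args)
            return function(*inner_args)

        return checkpoint_fn(checked_function, *args)

    return wrapped
```

With the code from this issue, the wrapper would be passed in place of the raw function, e.g. model._set_gradient_checkpointing(training_args.deepspeed_gradient_checkpointing, with_shape_check(checkpoint)). An error reported "before checkpoint" would suggest the tensor was already 2-D when the model called the checkpoint function, while one reported "inside checkpointed function" would suggest the shape changed within the checkpoint implementation.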
Motivation
There doesn't seem to be such a tutorial at the moment, either in DeepSpeed's tutorials or in the Hugging Face documentation.
Your contribution
I can provide my results.