
Tutorial for using DeepSpeed's activation checkpointing instead of PyTorch's #32409

@huyiwen

Description

Feature request

Is there a tutorial for using DeepSpeed's activation checkpointing instead of PyTorch's?

I'm using Trainer with ZeRO integration to train my model. Here's my code:

if training_args.deepspeed_gradient_checkpointing and training_args.deepspeed:
    from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint, configure

    # Initialize DeepSpeed's activation checkpointing (no model-parallel unit).
    configure(mpu_=None)
    # Swap in DeepSpeed's checkpoint function for torch.utils.checkpoint.
    model._set_gradient_checkpointing(training_args.deepspeed_gradient_checkpointing, checkpoint)
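For comparison, here is a minimal sketch of the same wiring that passes the options to configure() directly instead of reading them from the JSON config. The keyword names (partition_activations, checkpoint_in_cpu, contiguous_checkpointing, synchronize, profile) are my assumption about configure()'s signature and should be verified against the installed DeepSpeed version:

# Minimal sketch, assuming configure() accepts per-option keyword arguments;
# verify the names against your DeepSpeed version.
from deepspeed.runtime.activation_checkpointing.checkpointing import (
    checkpoint,
    configure,
)

configure(
    mpu_=None,                    # no model-parallel unit
    partition_activations=True,   # mirrors "partition_activations" in the JSON
    checkpoint_in_cpu=True,       # mirrors "cpu_checkpointing"
    contiguous_checkpointing=False,
    synchronize=False,
    profile=False,
)

# Hand DeepSpeed's checkpoint function to the model in place of
# torch.utils.checkpoint.checkpoint (model as in the snippet above).
model._set_gradient_checkpointing(True, checkpoint)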
The relevant part of my DeepSpeed config:

{
  "activation_checkpointing": {
    "partition_activations": true,
    "cpu_checkpointing": true,
    "contiguous_memory_optimization": false,
    "number_checkpoints": null,
    "synchronize_checkpoint_boundary": false,
    "profile": false
  }
}
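For completeness, the config file reaches the ZeRO integration through TrainingArguments. A minimal sketch, where "ds_config.json", model, and train_dataset are placeholders for the values used above:

# Minimal sketch: hand the DeepSpeed JSON config to the HF Trainer.
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="out",              # placeholder output directory
    deepspeed="ds_config.json",    # path to the config shown above
    gradient_checkpointing=False,  # keep the native PyTorch checkpointing off
)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)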
The launch command (--gradient_checkpointing False keeps the native PyTorch/Transformers checkpointing disabled so that only DeepSpeed's is active):

torchrun --nproc_per_node=8 \
    --nnodes=${NNODES} \
    --node_rank=${NODE_RANK} \
    --master_addr=${MASTER_ADDR} \
    --master_port=${MASTER_PORT} \
    train.py \
    --deepspeed ${DEEPSPEED_CONFIG_PATH} \
    --gradient_checkpointing False

However, I then got this error in the FlashAttention2 forward pass:

class XXXFlashAttention2(XXXAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        output_attentions = False

        bsz, q_len, _ = hidden_states.size()  # <-- this line raises the ValueError

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
  File "modeling_xxx.py", line 518, in forward
    bsz, q_len, _ = hidden_states.size()
ValueError: not enough values to unpack (expected 3, got 2)
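To narrow this down, a shape-logging wrapper around DeepSpeed's checkpoint (a hypothetical debugging helper, not part of either library) can confirm what the active checkpoint function actually passes into the layer:

# Hypothetical debugging helper: log the shapes of tensors entering the
# checkpointed layer to see whether hidden_states arrives 2-D instead of 3-D.
import torch
from deepspeed.runtime.activation_checkpointing.checkpointing import (
    checkpoint as ds_checkpoint,
)

def logging_checkpoint(function, *args):
    for i, arg in enumerate(args):
        if torch.is_tensor(arg):
            print(f"checkpoint input {i}: shape={tuple(arg.shape)}")
    return ds_checkpoint(function, *args)

# Install it in place of the plain DeepSpeed checkpoint (model as above).
model._set_gradient_checkpointing(True, logging_checkpoint)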

Motivation

There doesn't seem to be a tutorial for this at the moment in either the DeepSpeed documentation or the Hugging Face documentation.

Your contribution

I can provide my results.
