
mamba generation throughput lower than original due to DecodingCGCache #29699


Description

@y1xia0w

System Info

Python 3.10.13, CUDA 12.1
GPU = NVIDIA GeForce RTX 2080 Ti. Max memory = 10.747 GB.

torch==2.2.1
torchaudio==2.1.0
torchvision==0.16.0
tokenizers==0.15.2
transformers==git+https://github.com/huggingface/transformers@dd1c9052159ae824c8acef7c2552f9fad5ca020a
triton==2.2.0
causal_conv1d==git+https://github.com/Dao-AILab/causal-conv1d.git@96456720c00393a5c32872d8352d7a7ec31fb3db#egg=causal_conv1d
mamba_ssm==git+https://github.com/state-spaces/mamba.git@9127d1f47f367f5c9cc49c73ad73557089d02cb8#egg=mamba_ssm

Who can help?

text models: @ArthurZucker and @younesbelkada
generate: @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The key model initialization and generation parts are given below.

Original code repo

In the original code repo:

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m")
model.eval()

model.generate(
        input_ids=input_ids,
        max_length=max_length,
        cg=True  # enables the CUDA-graph decoding cache (DecodingCGCache)
    )

The timing for generating 1K tokens is:

Number of parameters: 129135360
Prompt length: 100, generation length: 1000
Prompt processing + decoding time: 1011 ms
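
For reference, a minimal sketch of how such numbers can be measured (assumptions on my side: the GPT-NeoX tokenizer that the Mamba repo uses for its benchmarks, a placeholder ~100-token prompt, and a plain synchronize-based timer):

import time
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", device="cuda")
model.eval()

prompt = "My prompt " * 50                       # placeholder, roughly a 100-token prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
max_length = input_ids.shape[1] + 1000           # generate 1000 tokens

torch.cuda.synchronize()
start = time.time()
out = model.generate(input_ids=input_ids, max_length=max_length, cg=True)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f} ms")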

Using the HF library

from transformers import MambaForCausalLM
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
model.eval()

model.generate(
        input_ids=input_ids,
        max_length=max_length
    )

The timing for generating 1K tokens is:

Number of parameters: 129135360
Prompt length: 100, generation length: 1000
state-spaces/mamba-130m-hf prompt processing + decoding time: 15970 ms
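
One way to rule out the cache being silently disabled on the HF side is to inspect a single forward pass (a sketch; it assumes the model output exposes cache_params, as in the Mamba modeling code at the commit listed above):

import torch
from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf").to("cuda")
model.eval()

input_ids = tokenizer("Hello", return_tensors="pt").input_ids.to("cuda")
with torch.no_grad():
    out = model(input_ids, use_cache=True)

# if caching is active, cache_params should be a populated MambaCache, not None
print(model.config.use_cache, type(out.cache_params))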

Expected behavior

The "cg=True" is confirmed to be the part has a significant impact on the generation performance for mamba.

I have tried:

  1. Passing "use_cache=True" in the following ways does not affect the results (consolidated sketch after this list):
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", use_cache=True)
or
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", cache_params={"use_cache": True})
or
model.config.use_cache = True
  2. Modifying the Mamba model to force "use_cache=True" inside MambaModel, which still does not help.
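
Consolidated sketch of attempt 1, enabling the cache via the config and, additionally, via the use_cache kwarg that generate() also accepts (model names and prompt as in the snippets above):

from transformers import AutoTokenizer, MambaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf").to("cuda")
model.eval()
model.config.use_cache = True                     # attempt 1 from above

input_ids = tokenizer("My prompt " * 50, return_tensors="pt").input_ids.to("cuda")
out = model.generate(
    input_ids=input_ids,
    max_length=input_ids.shape[1] + 1000,
    use_cache=True,                               # also a valid generate() kwarg
)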

I assume this is related to #29605, but modifying the argument directly does not seem to solve the problem.
