Description
System Info
Python 3.10.13, CUDA 12.1
GPU = NVIDIA GeForce RTX 2080 Ti. Max memory = 10.747 GB.
torch==2.2.1
torchaudio==2.1.0
torchvision==0.16.0
tokenizers==0.15.2
transformers==git+https://github.com/huggingface/transformers@dd1c9052159ae824c8acef7c2552f9fad5ca020a
triton==2.2.0
causal_conv1d==git+https://github.com/Dao-AILab/causal-conv1d.git@96456720c00393a5c32872d8352d7a7ec31fb3db#egg=causal_conv1d
mamba_ssm==git+https://github.com/state-spaces/mamba.git@9127d1f47f367f5c9cc49c73ad73557089d02cb8#egg=mamba_ssm
Who can help?
text models: @ArthurZucker and @younesbelkada
generate: @gante
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
The key model initialization and generation steps are given below.
Original code repo
In the original code repo:
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m")
model.eval()
model.generate(
    input_ids=input_ids,
    max_length=max_length,
    cg=True,  # the flag in question: cached CUDA-graph decoding in mamba_ssm
)
With this, the throughput for generating a 1K-token continuation is:
Number of parameters: 129135360
Prompt length: 100, generation length: 1000
Prompt processing + decoding time: 1011 ms
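For context, here is a fuller, self-contained version of this timing, as a sketch. The tokenizer name ("EleutherAI/gpt-neox-20b", the tokenizer these checkpoints are usually paired with) and the placeholder prompt are my assumptions; my real prompt is 100 tokens long.
import time
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")  # assumed tokenizer
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m").to(device)
model.eval()

input_ids = tokenizer("placeholder prompt", return_tensors="pt").input_ids.to(device)  # real prompt: 100 tokens
max_length = input_ids.shape[1] + 1000  # generate 1000 new tokens

torch.cuda.synchronize()
start = time.time()
out = model.generate(input_ids=input_ids, max_length=max_length, cg=True)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f} ms")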
Using the HF library
from transformers import MambaForCausalLM
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
model.eval()
model.generate(
    input_ids=input_ids,
    max_length=max_length
)
With this, the throughput for the same 1K-token generation is:
Number of parameters: 129135360
Prompt length: 100, generation length: 1000
state-spaces/mamba-130m-hf prompt processing + decoding time: 15970 ms
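For the HF side, an analogous timing sketch (same caveat that the prompt below is a placeholder for my real 100-token prompt):
import time
import torch
from transformers import AutoTokenizer, MambaForCausalLM

device = "cuda"
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf").to(device)
model.eval()

input_ids = tokenizer("placeholder prompt", return_tensors="pt").input_ids.to(device)
max_length = input_ids.shape[1] + 1000

torch.cuda.synchronize()
start = time.time()
out = model.generate(input_ids=input_ids, max_length=max_length)
torch.cuda.synchronize()
print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f} ms")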
Expected behavior
The "cg=True" is confirmed to be the part has a significant impact on the generation performance for mamba.
I have tried:
- Passing the "use_cache=True" as follows won't affect the results
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", use_cache=True)
or
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", cache_params={use_cache: True})
or
model.config.use_cache=True
- Modifying the Mamba modeling code to force use_cache=True inside MambaModel, but this still does not work.
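For completeness, the remaining variant I can think of is passing use_cache directly to generate() (a standard transformers generation flag); a sketch, reusing the HF model and input_ids from above:
# use_cache passed at call time; whether this actually reaches the Mamba
# cache_params path is exactly what I am unsure about.
out = model.generate(
    input_ids=input_ids,
    max_length=max_length,
    use_cache=True,
)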
I assume this is related to #29605, but modifying the argument directly does not seem to solve the problem.