Deepseek00 cuda graph #72

Draft

wants to merge 4 commits into base: deepseek00-mla
Changes from all commits
8 changes: 8 additions & 0 deletions csrc/attention/merge_attn_states.cu
@@ -143,6 +143,14 @@ void merge_attn_states_launcher(torch::Tensor& output,
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
"output heads must be contiguous in memory");
TORCH_CHECK(
prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
"prefix_output heads must be contiguous in memory");
TORCH_CHECK(
suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
"suffix_output heads must be contiguous in memory");
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();
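A note on what the new checks buy: the kernel issues 16-byte packed loads per head (pack_size = 16 / sizeof(scalar_t)), so the last two dimensions of output, prefix_output and suffix_output must be densely packed (stride(-2) == head_size, stride(-1) == 1). The sketch below is plain PyTorch with illustrative shapes, not anything from the kernel; it shows a layout that satisfies the checks and a transposed view that would now be rejected up front instead of being read with the wrong strides.

import torch

# Illustrative shapes only; the real inputs come from the prefix/suffix attention split.
num_tokens, num_heads, head_size = 4, 8, 128
output = torch.empty(num_tokens, num_heads, head_size, dtype=torch.float16)

# Satisfies the new checks: each head's elements are back-to-back in memory.
assert output.stride(-2) == head_size and output.stride(-1) == 1

# Would now be rejected: a transposed view scatters each head across memory.
transposed = output.transpose(-2, -1)
assert not (transposed.stride(-2) == head_size and transposed.stride(-1) == 1)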
73 changes: 73 additions & 0 deletions examples/offline_inference/basic/generate_with_full_graph.py
@@ -0,0 +1,73 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser


def create_parser():
parser = FlexibleArgumentParser()
# Add engine args
engine_group = parser.add_argument_group("Engine arguments")
EngineArgs.add_cli_args(engine_group)
engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct")
# Add sampling params
sampling_group = parser.add_argument_group("Sampling parameters")
sampling_group.add_argument("--max-tokens", type=int)
sampling_group.add_argument("--temperature", type=float)
sampling_group.add_argument("--top-p", type=float)
sampling_group.add_argument("--top-k", type=int)

return parser


def main(args: dict):
# Pop arguments not used by LLM
max_tokens = args.pop("max_tokens")
temperature = args.pop("temperature")
top_p = args.pop("top_p")
top_k = args.pop("top_k")

# Create an LLM
args.pop("compilation_config",
None) # Remove compilation_config if it exists
args.pop("max_num_seqs", None) # Remove max_num_seqs if it exists
llm = LLM(**args,
max_num_seqs=256,
compilation_config={
"full_cuda_graph": True,
"cudagraph_capture_sizes": [64, 256]
})

# Create a sampling params object
sampling_params = llm.get_default_sampling_params()
if max_tokens is not None:
sampling_params.max_tokens = max_tokens
if temperature is not None:
sampling_params.temperature = temperature
if top_p is not None:
sampling_params.top_p = top_p
if top_k is not None:
sampling_params.top_k = top_k

# Generate texts from the prompts. The output is a list of RequestOutput
# objects that contain the prompt, generated text, and other information.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
print("-" * 50)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
print("-" * 50)


if __name__ == "__main__":
parser = create_parser()
args: dict = vars(parser.parse_args())
main(args)
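For reference, the new example runs like the other offline-inference scripts: engine flags come from EngineArgs and the sampling flags are the ones added in create_parser(). One possible invocation (model and token budget are only examples):

python examples/offline_inference/basic/generate_with_full_graph.py \
    --model meta-llama/Llama-3.2-1B-Instruct \
    --max-tokens 32 --temperature 0.0

Note that the script pins max_num_seqs=256 and its own compilation_config (full_cuda_graph with capture sizes 64 and 256), dropping any values passed on the command line for those two options.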
17 changes: 14 additions & 3 deletions vllm/_custom_ops.py
@@ -158,18 +158,29 @@ def rotary_embedding(
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
torch.ops._C.rotary_embedding(positions, query, key, head_size,
cos_sin_cache, is_neox)
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
query_contiguous = query.contiguous()
key_contiguous = key.contiguous()
torch.ops._C.rotary_embedding(positions, query_contiguous, key_contiguous,
head_size, cos_sin_cache, is_neox)
query.copy_(query_contiguous)
key.copy_(key_contiguous)


def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, head_size: int,
cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None:
torch.ops._C.batched_rotary_embedding(positions, query, key, head_size,
# TODO: Remove this contiguous call when the kernel is updated to support tensor slices
query_contiguous = query.contiguous()
key_contiguous = key.contiguous()
torch.ops._C.batched_rotary_embedding(positions, query_contiguous,
key_contiguous, head_size,
cos_sin_cache, is_neox, rot_dim,
cos_sin_cache_offsets)
query.copy_(query_contiguous)
key.copy_(key_contiguous)


# layer norm ops
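The copy-back pattern in both wrappers is what preserves the in-place contract while the kernels still require contiguous inputs: when query or key is a slice of a larger tensor, .contiguous() materializes a separate buffer, the kernel mutates that buffer, and copy_ writes the result back through the original view. A minimal PyTorch sketch of the idea (toy shapes, not vLLM code):

import torch

qkv = torch.zeros(4, 6)
query = qkv[:, :2]                     # non-contiguous slice of a larger buffer
query_contiguous = query.contiguous()  # fresh, densely packed copy
query_contiguous += 1                  # stands in for the in-place CUDA kernel
assert query.sum() == 0                # the caller's view is still untouched
query.copy_(query_contiguous)          # propagate the result back into the slice
assert qkv[:, :2].sum() == 8           # the original buffer now sees the update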
17 changes: 16 additions & 1 deletion vllm/attention/layer.py
@@ -183,7 +183,9 @@ def forward(
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
if self.calculate_kv_scales:
attn_metadata = get_forward_context().attn_metadata
attn_metadata: ForwardContext = get_forward_context().attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(query, key, value)
if self.use_output:
@@ -209,6 +211,8 @@ def forward(
if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
query,
@@ -225,6 +229,8 @@ def forward(
if self.use_direct_call:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
return self.impl.forward(self, query, key, value,
self_kv_cache, attn_metadata)
@@ -343,6 +349,7 @@ def wait_for_kv_layer_from_connector(layer_name: str):
if attn_metadata is None:
return

assert isinstance(attn_metadata, dict)
connector.wait_for_layer_load(layer_name)


@@ -360,6 +367,7 @@ def maybe_save_kv_layer_to_connector(
if attn_metadata is None:
return

assert isinstance(attn_metadata, dict)
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)


@@ -372,6 +380,10 @@ def unified_attention(
wait_for_kv_layer_from_connector(layer_name)

forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]

self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
@@ -410,6 +422,9 @@ def unified_attention_with_output(
wait_for_kv_layer_from_connector(layer_name)
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
self = forward_context.no_compile_layers[layer_name]
kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self,
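The repeated isinstance check reflects that, with full CUDA graph support, forward_context.attn_metadata may now be a dict keyed by layer name instead of a single shared object. If the lookup keeps spreading across call sites, it could be folded into a small helper along these lines (the name and placement are only a suggestion, not existing vLLM API):

from typing import Any, Union


def per_layer_attn_metadata(attn_metadata: Union[dict, Any],
                            layer_name: str) -> Any:
    # Return this layer's metadata whether the forward context stores a
    # shared object or a per-layer dict (as under full CUDA graph capture).
    if isinstance(attn_metadata, dict):
        return attn_metadata[layer_name]
    return attn_metadata


# e.g. attn_metadata = per_layer_attn_metadata(forward_context.attn_metadata, layer_name)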