
Commit b0c522d

clean code
Signed-off-by: Yue Weng <[email protected]>
1 parent eea4054 commit b0c522d

File tree

17 files changed, +771 -604 lines changed


cpp/tensorrt_llm/common/attentionOp.cpp

Lines changed: 1 addition & 6 deletions

@@ -1318,7 +1318,6 @@ MLA_FUNC_DEFINE(__nv_bfloat16)
 template <typename T, typename KVCacheBuffer>
 int AttentionOp::enqueueContext(EnqueueContextParams<T> const& params, cudaStream_t stream)
 {
-    printf("++++++++++++++++++++++++ in enqueueContext +++++++++++++++++++++++++\n");
     int const headSize = getHeadSize();

     int const local_hidden_units_qo = mNumHeads * headSize;
@@ -2162,7 +2161,6 @@ template int AttentionOp::enqueueContext<__nv_bfloat16, KVBlockArray>(
 template <typename T, typename KVCacheBuffer>
 int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cudaStream_t stream)
 {
-    printf("++++++++++++++++++++++++ in enqueueGeneration +++++++++++++++++++++++++\n");
     int const headSize = getHeadSize();
     float const q_scaling = mQScaling;
     float const* logn_scaling_ptr = isLognScaling() ? params.logn_scaling_ptr : nullptr;
@@ -2272,10 +2270,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams<T> const& params, cud
     // self attn
     XQAParams xqaParams{};
     this->template convertMMHAParamsToXQAParams<T, KVCacheBuffer>(xqaParams, params, /*forConfigurePlugin=*/false);
-    // if (mEnableXQA && mXqaDispatcher->shouldUse(xqaParams))
-    bool shouldUseXQA = mEnableXQA && mXqaDispatcher->shouldUse(xqaParams);
-    printf("++++++++++++++++++++++++ in enqueueGeneration, mEnableXQA: %d, shouldUseXQA: %d +++++++++++++++++++++++++\n", mEnableXQA, shouldUseXQA);
-    if (mEnableXQA && shouldUseXQA)
+    if (mEnableXQA && mXqaDispatcher->shouldUse(xqaParams))
     {
         TLLM_LOG_DEBUG("XQA kernels are selected in the generation phase.");
         xqaParams.stream = stream;

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplJIT/decoderXQAImplJIT.cpp

Lines changed: 0 additions & 3 deletions

@@ -98,9 +98,6 @@ bool DecoderXQAImplJIT::mayHavePerfGain(XQAParams const& xqaParams) const

 bool DecoderXQAImplJIT::shouldUse(XQAParams const& umbrellaXQAParams, bool forConfigurePlugin)
 {
-    printf("++++++++++++++++++++++++ in decoderXQAImplJIT::shouldUse +++++++++++++++++++++++++\n");
-    // printf("++++++++++++++++++++++++ umbrellaXQAParams: %s +++++++++++++++++++++++++\n", umbrellaXQAParams.toString().c_str());
-
     if (forConfigurePlugin)
     {
         for (int beam_width = 1; beam_width <= umbrellaXQAParams.beam_width; ++beam_width)

cpp/tensorrt_llm/kernels/xqaDispatcher.cpp

Lines changed: 0 additions & 113 deletions

@@ -491,119 +491,6 @@ void XqaDispatcher::runImpl(
     }
     else
     {
-        std::cout << "===== debug: in xqaDispatcher.cpp::runImpl" << std::endl;
-        std::cout << "============== debug: print XQA params ==============" << std::endl;
-
-        // batch_size
-        auto batch_size = params.batch_size;
-        auto beam_width = params.beam_width;
-        auto max_draft_len = 12; // hard code
-        std::cout << "===== debug: batch_size: " << batch_size << ", beam_width: " << beam_width << std::endl;
-
-        // host_past_key_value_lengths
-        if (params.host_past_key_value_lengths)
-        {
-            std::cout << "===== debug: host_past_key_value_lengths: " << *(runtime::ITensor::wrap((void*) params.host_past_key_value_lengths, nvinfer1::DataType::kINT32,
-                runtime::ITensor::makeShape({batch_size}))) << std::endl;
-        }
-        // host_context_lengths
-        if (params.host_context_lengths)
-        {
-            std::cout << "===== debug: host_context_lengths: " << *(runtime::ITensor::wrap((void*) params.host_context_lengths, nvinfer1::DataType::kINT32,
-                runtime::ITensor::makeShape({batch_size}))) << std::endl;
-        }
-        // chunked_attention_size
-        auto chunked_attention_size = params.chunked_attention_size;
-        std::cout << "===== debug: chunked_attention_size: " << chunked_attention_size << std::endl;
-        // max_attention_window_size
-        auto max_attention_window_size = params.max_attention_window_size;
-        std::cout << "===== debug: max_attention_window_size: " << max_attention_window_size << std::endl;
-        // cyclic_attention_window_size
-        auto cyclic_attention_window_size = params.cyclic_attention_window_size;
-        std::cout << "===== debug: cyclic_attention_window_size: " << cyclic_attention_window_size << std::endl;
-        // sink_token_length
-        auto sink_token_length = params.sink_token_length;
-        std::cout << "===== debug: sink_token_length: " << sink_token_length << std::endl;
-        // max_past_kv_length
-        auto max_past_kv_length = params.max_past_kv_length;
-        // sequence_lengths
-        if (params.sequence_lengths)
-        {
-            std::cout << "===== debug: sequence_lengths: " << *(runtime::ITensor::wrap((void*) params.sequence_lengths, nvinfer1::DataType::kINT32,
-                runtime::ITensor::makeShape({batch_size * beam_width}))) << std::endl;
-        }
-        // context_lengths
-        if (params.context_lengths)
-        {
-            std::cout << "===== debug: context_lengths: " << *(runtime::ITensor::wrap((void*) params.context_lengths, nvinfer1::DataType::kINT32,
-                runtime::ITensor::makeShape({batch_size * beam_width}))) << std::endl;
-        }
-
-        // spec_decoding_packed_mask
-        if (params.spec_decoding_packed_mask)
-        {
-            std::cout << "===== debug: spec_decoding_packed_mask: " << *(runtime::ITensor::wrap((void*) params.spec_decoding_packed_mask, nvinfer1::DataType::kINT32,
-                runtime::ITensor::makeShape({max_draft_len+1, batch_size}))) << std::endl;
-        }
-        // spec_decoding_position_offsets
-        if (params.spec_decoding_position_offsets)
-        {
-            std::cout << "===== debug: spec_decoding_position_offsets: " << *(runtime::ITensor::wrap((void*) params.spec_decoding_position_offsets, nvinfer1::DataType::kINT32,
-                runtime::ITensor::makeShape({batch_size, max_draft_len+1}))) << std::endl;
-        }
-        // spec_decoding_generation_lengths
-        if (params.spec_decoding_generation_lengths)
-        {
-            std::cout << "===== debug: spec_decoding_generation_lengths: " << *(runtime::ITensor::wrap((void*) params.spec_decoding_generation_lengths, nvinfer1::DataType::kINT32,
-                runtime::ITensor::makeShape({batch_size}))) << std::endl;
-        }
-        // spec_decoding_is_generation_length_variable
-        auto spec_decoding_is_generation_length_variable = params.spec_decoding_is_generation_length_variable;
-        std::cout << "===== debug: spec_decoding_is_generation_length_variable: " << spec_decoding_is_generation_length_variable << std::endl;
-        // spec_decoding_max_generation_length
-        auto spec_decoding_max_generation_length = params.spec_decoding_max_generation_length;
-        std::cout << "===== debug: spec_decoding_max_generation_length: " << spec_decoding_max_generation_length << std::endl;
-        // generation_input_length
-        auto generation_input_length = params.generation_input_length;
-        std::cout << "===== debug: generation_input_length: " << generation_input_length << std::endl;
-        // num_q_heads
-        auto num_q_heads = params.num_q_heads;
-        std::cout << "===== debug: num_q_heads: " << num_q_heads << std::endl;
-        // num_kv_heads
-        auto num_kv_heads = params.num_kv_heads;
-        std::cout << "===== debug: num_kv_heads: " << num_kv_heads << std::endl;
-        // head_size
-        auto head_size = params.head_size;
-        std::cout << "===== debug: head_size: " << head_size << std::endl;
-        // unidirectional
-        auto unidirectional = params.unidirectional;
-        std::cout << "===== debug: unidirectional: " << unidirectional << std::endl;
-        // position_shift_enabled
-        auto position_shift_enabled = params.position_shift_enabled;
-        std::cout << "===== debug: position_shift_enabled: " << position_shift_enabled << std::endl;
-        // remove_padding
-        auto remove_padding = params.remove_padding;
-        std::cout << "===== debug: remove_padding: " << remove_padding << std::endl;
-        // mask_type
-        auto mask_type = params.mask_type;
-        std::cout << "===== debug: mask_type: " << static_cast<int>(mask_type) << std::endl;
-        // paged_kv_cache
-        auto paged_kv_cache = params.paged_kv_cache;
-        std::cout << "===== debug: paged_kv_cache: " << paged_kv_cache << std::endl;
-        // tokens_per_block
-        auto tokens_per_block = params.tokens_per_block;
-        std::cout << "===== debug: tokens_per_block: " << tokens_per_block << std::endl;
-        // max_blocks_per_sequence
-        auto max_blocks_per_sequence = params.max_blocks_per_sequence;
-        std::cout << "===== debug: max_blocks_per_sequence: " << max_blocks_per_sequence << std::endl;
-        // multi_query_tokens
-        auto multi_query_tokens = params.multi_query_tokens;
-        std::cout << "===== debug: multi_query_tokens: " << multi_query_tokens << std::endl;
-        // total_num_input_tokens
-        auto total_num_input_tokens = params.total_num_input_tokens;
-        std::cout << "===== debug: total_num_input_tokens: " << total_num_input_tokens << std::endl;
-
-        std::cout << "=========================================" << std::endl;
         mDecoderXqaRunner->template dispatch<KVCacheBuffer>(params, kv_cache_buffer, params.stream);
     }
 }

examples/llm-api/quickstart_advanced.py

Lines changed: 3 additions & 4 deletions

@@ -8,10 +8,9 @@
                           TorchCompileConfig)

 example_prompts = [
-    # "Hello, my name is",
-    # "The capital of France is",
-    # "The future of AI is",
-    "You are a good assistant. Please tell me the capital of France is",
+    "Hello, my name is",
+    "The capital of France is",
+    "The future of AI is",
 ]


tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 3 additions & 2 deletions

@@ -330,8 +330,9 @@ def restore_from_spec_dec(self) -> None:
             setattr(self, f, v)
         self._saved_tensors.clear()

-    def update_spec_dec_param(self, is_spec_decoding_enabled, spec_metadata, spec_tree_manager,
-                              max_draft_len, max_total_draft_tokens):
+    def update_spec_dec_param(self, is_spec_decoding_enabled, spec_metadata,
+                              spec_tree_manager, max_draft_len,
+                              max_total_draft_tokens):
         """
         Hook to be called when using TRTLLM attention backend in spec-dec mode.
         """

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 10 additions & 7 deletions

@@ -1044,9 +1044,10 @@ def prepare_context_mla_with_cached_kv(self,
             out=self.host_ctx_kv_indptr[1:self.num_contexts + 1])
         self.ctx_kv_indptr[:self.num_contexts + 1].copy_(
             self.host_ctx_kv_indptr[:self.num_contexts + 1], non_blocking=True)
-
-    def update_spec_dec_param(self, is_spec_decoding_enabled, spec_metadata, spec_tree_manager,
-                              max_draft_len, max_total_draft_tokens):
+
+    def update_spec_dec_param(self, is_spec_decoding_enabled, spec_metadata,
+                              spec_tree_manager, max_draft_len,
+                              max_total_draft_tokens):
         # spec_dec mode should only be enabled for pre-Blackwell machines and when there's a spec-dec tree.
         self.is_spec_decoding_enabled = is_spec_decoding_enabled and get_sm_version(
         ) < 100
@@ -1086,16 +1087,18 @@ def update_spec_dec_param(self, is_spec_decoding_enabled, spec_metadata, spec_tr
         is_target_model = not spec_metadata.is_draft_model
         is_using_tree = self.is_spec_dec_tree or self.is_spec_dec_dynamic_tree
         if is_target_model and is_using_tree:
-            assert spec_metadata.spec_dec_mode.is_eagle3(), "Tree decoding is only supported for Eagle3 now"
+            assert spec_metadata.spec_dec_mode.is_eagle3(
+            ), "Tree decoding is only supported for Eagle3 now"
             # If is the dynamic tree
             if self.is_spec_dec_dynamic_tree:
                 # TODO: add dynamic tree logic
                 assert False, "Dynamic tree is not supported yet"
             # If is the static tree
             else:
-                self.spec_decoding_position_offsets[:,].copy_(
-                    spec_tree_manager.spec_dec_position_offsets[0, :],
-                    non_blocking=True)
+                self.spec_decoding_position_offsets[
+                    :,
+                ].copy_(spec_tree_manager.spec_dec_position_offsets[0, :],
+                        non_blocking=True)
                 self.spec_decoding_packed_mask[:, :, :].copy_(
                     spec_tree_manager.spec_dec_packed_mask[0, :, :],
                     non_blocking=True)
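For reference, the static-tree branch above relies on torch.Tensor.copy_ broadcasting a single row into a larger preallocated buffer. A small self-contained sketch of that idiom follows; the tensor names and shapes are made up for illustration and are not taken from this commit.

import torch

max_num_requests, num_tree_nodes = 4, 8  # illustrative sizes only
# Preallocated metadata buffer: one row of position offsets per request slot.
spec_decoding_position_offsets = torch.zeros(max_num_requests, num_tree_nodes,
                                             dtype=torch.int32)
# Stand-in for the tree manager's [batch, nodes] position-offset tensor.
tree_position_offsets = torch.arange(num_tree_nodes,
                                     dtype=torch.int32).unsqueeze(0)

# Row 0 is broadcast into every request slot, as in the diff above.
# non_blocking=True only matters for host-to-device copies; on CPU it is a no-op.
spec_decoding_position_offsets[:, ].copy_(tree_position_offsets[0, :],
                                          non_blocking=True)
print(spec_decoding_position_offsets)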

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 30 additions & 19 deletions

@@ -1222,7 +1222,7 @@ def _prepare_tp_inputs(
         num_accepted_draft_tokens = []  # per request
         # if using tree decoding, we need to store the request type and accepted path for each request,
         # which will be used to update the hidden_states_read_indices.
-        request_accepted_path = {} # per request
+        request_accepted_path = {}  # per request

         for request in scheduled_requests.context_requests:
             request_ids.append(request.py_request_id)
@@ -1237,7 +1237,9 @@ def _prepare_tp_inputs(
             gather_ids.append(len(input_ids) - 1)
             sequence_lengths.append(len(prompt_tokens))
             num_accepted_draft_tokens.append(len(prompt_tokens) - 1)
-            request_accepted_path[request.py_request_id] = request.py_num_accepted_draft_tokens_indices
+            request_accepted_path[
+                request.
+                py_request_id] = request.py_num_accepted_draft_tokens_indices
             prompt_lengths.append(len(prompt_tokens))
             past_seen_token_num = begin_compute
             num_cached_tokens_per_seq.append(past_seen_token_num)
@@ -1323,7 +1325,9 @@ def _prepare_tp_inputs(
         previous_pos_indices = []
         for request in extend_requests:
             request_ids.append(request.py_request_id)
-            request_accepted_path[request.py_request_id] = request.py_num_accepted_draft_tokens_indices
+            request_accepted_path[
+                request.
+                py_request_id] = request.py_num_accepted_draft_tokens_indices
             # the request has no previous tensor:
             # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
             # (2) a dummy request; or
@@ -1359,13 +1363,15 @@ def _prepare_tp_inputs(
                 assert spec_tree_manager is not None
                 assert num_draft_tokens == spec_tree_manager.max_total_draft_tokens
                 position_ids.extend(
-                    past_seen_token_num + spec_tree_manager.spec_dec_position_offsets[0] # [max_total_draft_tokens + 1]
+                    past_seen_token_num +
+                    spec_tree_manager.spec_dec_position_offsets[
+                        0]  # [max_total_draft_tokens + 1]
                 )
             else:
                 position_ids.extend(
                     list(
-                        range(past_seen_token_num, past_seen_token_num + 1 +
-                              num_draft_tokens)))
+                        range(past_seen_token_num,
+                              past_seen_token_num + 1 + num_draft_tokens)))
             num_cached_tokens_per_seq.append(past_seen_token_num)
             request.cached_tokens = num_cached_tokens_per_seq[-1]
             # update batch index
@@ -1390,12 +1396,15 @@ def _prepare_tp_inputs(
                 assert spec_tree_manager is not None
                 assert num_draft_tokens == spec_tree_manager.max_total_draft_tokens
                 position_ids.extend(
-                    past_seen_token_num + spec_tree_manager.spec_dec_position_offsets[0] # [max_total_draft_tokens + 1]
+                    past_seen_token_num +
+                    spec_tree_manager.spec_dec_position_offsets[
+                        0]  # [max_total_draft_tokens + 1]
                 )
             else:
                 position_ids.extend(
                     list(
-                        range(past_seen_token_num, past_seen_token_num + 1 +
+                        range(
+                            past_seen_token_num, past_seen_token_num + 1 +
                             self.runtime_draft_len)))
             # previous tensor
             previous_batch_indices.append(previous_batch_idx)
@@ -1433,7 +1442,9 @@ def _prepare_tp_inputs(
             sequence_lengths.append(1 + self.original_max_draft_len)
             num_accepted_draft_tokens.append(
                 request.py_num_accepted_draft_tokens)
-            request_accepted_path[request.py_request_id] = request.py_num_accepted_draft_tokens_indices
+            request_accepted_path[
+                request.
+                py_request_id] = request.py_num_accepted_draft_tokens_indices
             prompt_lengths.append(request.py_prompt_len)
             past_seen_token_num = begin_compute
             num_cached_tokens_per_seq.append(past_seen_token_num)
@@ -2241,15 +2252,14 @@ def _get_lora_params_from_requests(self,
         return lora_params

     @nvtx_range("_prepare_inputs")
-    def _prepare_inputs(
-            self,
-            scheduled_requests: ScheduledRequests,
-            kv_cache_manager: KVCacheManager,
-            attn_metadata: AttentionMetadata,
-            spec_metadata: Optional[SpecMetadata] = None,
-            new_tensors_device: Optional[SampleStateTensors] = None,
-            cache_indirection_buffer: Optional[torch.Tensor] = None,
-            resource_manager: Optional[ResourceManager] = None):
+    def _prepare_inputs(self,
+                        scheduled_requests: ScheduledRequests,
+                        kv_cache_manager: KVCacheManager,
+                        attn_metadata: AttentionMetadata,
+                        spec_metadata: Optional[SpecMetadata] = None,
+                        new_tensors_device: Optional[SampleStateTensors] = None,
+                        cache_indirection_buffer: Optional[torch.Tensor] = None,
+                        resource_manager: Optional[ResourceManager] = None):
         if self.mapping is not None and 'cp_type' in self.mapping.cp_config:
             cp_type = self.mapping.cp_config['cp_type']
             if CpType.STAR == cp_type:
@@ -2297,7 +2307,8 @@ def forward(
                     self.model_is_wrapped, spec_metadata.is_spec_dec_tree)
                 attn_metadata.update_spec_dec_param(
                     is_spec_dec_mode, spec_metadata, spec_tree_manager,
-                    self.original_max_draft_len, self.original_max_total_draft_tokens)
+                    self.original_max_draft_len,
+                    self.original_max_total_draft_tokens)
             else:
                 spec_resource_manager = None
                 spec_metadata = None
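The request_accepted_path bookkeeping above maps each request id to the indices of its accepted draft tokens so the accepted path can later drive hidden_states_read_indices. A tiny standalone sketch of that mapping follows; FakeRequest is a hypothetical stand-in, and only the two attribute names are taken from the diff.

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class FakeRequest:
    # Hypothetical stand-in for a scheduled request; only the attribute
    # names mirror the diff above.
    py_request_id: int
    py_num_accepted_draft_tokens_indices: List[int] = field(default_factory=list)


requests = [FakeRequest(0, [0, 2, 5]), FakeRequest(1, [0, 1])]

# Per-request map from request id to the accepted draft-token indices.
request_accepted_path: Dict[int, List[int]] = {
    r.py_request_id: r.py_num_accepted_draft_tokens_indices for r in requests
}
print(request_accepted_path)  # {0: [0, 2, 5], 1: [0, 1]}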

0 commit comments
