Commit d9a3530

[nvbug/5393888][nvbug/5393042] Always use py_seq_slot (#6147)
Signed-off-by: Netanel Haber <[email protected]>
1 parent d475c97 · commit d9a3530

File tree: 3 files changed, +16 −16 lines

tensorrt_llm/_torch/pyexecutor/model_engine.py
5 additions & 5 deletions

@@ -1152,7 +1152,7 @@ def _prepare_tp_inputs(
                 if multimodal_params.has_content():
                     multimodal_params_list.append(multimodal_params)
 
-            request.py_batch_idx = request.seq_slot
+            request.py_batch_idx = request.py_seq_slot
 
         num_ctx_requests = len(scheduled_requests.context_requests)
         num_ctx_tokens = len(input_ids)
@@ -1234,11 +1234,11 @@ def _prepare_tp_inputs(
                 num_cached_tokens_per_seq.append(past_seen_token_num)
                 request_ids.append(request.py_request_id)
                 # update batch index
-                request.py_batch_idx = request.seq_slot
+                request.py_batch_idx = request.py_seq_slot
             else:
                 # update batch index
                 previous_batch_idx = request.py_batch_idx
-                request.py_batch_idx = request.seq_slot
+                request.py_batch_idx = request.py_seq_slot
                 # inputs
                 # overlap scheduler can only support the speculative decoding
                 # methods with a fixed number of draft tokens
@@ -1292,8 +1292,8 @@ def _prepare_tp_inputs(
            gather_ids.append(len(position_ids) - 1)
 
            request_ids.append(request.py_request_id)
-           gen_request_seq_slots.append(request.seq_slot)
-           request.py_batch_idx = request.seq_slot
+           gen_request_seq_slots.append(request.py_seq_slot)
+           request.py_batch_idx = request.py_seq_slot
 
        previous_batch_len = len(previous_batch_indices)

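Note: the model_engine.py hunks all touch the batch-index bookkeeping in `_prepare_tp_inputs`; the sequence slot recorded for each scheduled request now always comes from the Python-side `py_seq_slot` attribute. Below is a minimal sketch of that bookkeeping pattern. `ToyRequest` and `prepare_generation_inputs` are made-up stand-ins, not TensorRT-LLM's real `LlmRequest` or engine code; only the `py_seq_slot`/`py_batch_idx` usage mirrors the diff.

```python
# Minimal sketch, not the real LlmRequest: shows the per-request slot /
# batch-index bookkeeping pattern that _prepare_tp_inputs applies.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyRequest:
    py_request_id: int
    py_seq_slot: int                     # Python-side sequence slot (attribute name taken from the diff)
    py_batch_idx: Optional[int] = None   # position of this request in the current batch


def prepare_generation_inputs(generation_requests: list[ToyRequest]):
    request_ids = []
    gen_request_seq_slots = []
    for request in generation_requests:
        request_ids.append(request.py_request_id)
        # Both the slot list used for tensor indexing and the cached batch
        # index are read from the same attribute, py_seq_slot.
        gen_request_seq_slots.append(request.py_seq_slot)
        request.py_batch_idx = request.py_seq_slot
    return request_ids, gen_request_seq_slots


reqs = [ToyRequest(py_request_id=10, py_seq_slot=2),
        ToyRequest(py_request_id=11, py_seq_slot=5)]
print(prepare_generation_inputs(reqs))   # ([10, 11], [2, 5])
```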
tensorrt_llm/_torch/pyexecutor/sampler.py
8 additions & 8 deletions

@@ -194,7 +194,7 @@ def add_token(request: LlmRequest,
               *,
               beam: int,
               step: int = 0) -> int:
-    seq_slot = request.seq_slot
+    seq_slot = request.py_seq_slot
     assert seq_slot is not None
     new_token = int(new_tokens[step, seq_slot, beam])
     request.add_new_token(new_token, beam)
@@ -285,14 +285,14 @@ def _handle_stop_criteria(self, request: LlmRequest,
 
     def handle_logits(self, request: LlmRequest, state: SampleState, *,
                       beam: int, count: int):
-        current_slice = slice(0, count), request.seq_slot, beam
+        current_slice = slice(0, count), request.py_seq_slot, beam
         if request.py_return_generation_logits:
             assert state.host.logits is not None
             current_logits = state.host.logits[current_slice]
             request.py_result.append_generation_logits(current_logits)
         if request.py_return_log_probs:
             assert state.host.log_probs is not None
-            log_probs = state.host.log_probs[request.seq_slot][beam][:count]
+            log_probs = state.host.log_probs[request.py_seq_slot][beam][:count]
             current_tokens = state.host.new_tokens[current_slice]
 
             token_log_probs = [{
@@ -406,7 +406,7 @@ def _process_requests(self,
         no_draft_tokens = len(requests) == sum_steps
         fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None
 
-        seq_slots = torch.as_tensor([r.seq_slot for r in requests])
+        seq_slots = torch.as_tensor([r.py_seq_slot for r in requests])
         seq_slots = seq_slots.to(device="cuda", non_blocking=True)
 
         if fast_path:
@@ -616,9 +616,9 @@ def _update_cache_indirection_buffer(self,
         # Copy cache indirection output to input
         for request in scheduled_requests.generation_requests:
             self.store["decoder_state"].cache_indirection_input[
-                request.seq_slot].copy_(
+                request.py_seq_slot].copy_(
                     self.store["decoder_state"].cache_indirection_output[
-                        request.seq_slot],
+                        request.py_seq_slot],
                     non_blocking=True)
 
     @torch.inference_mode()
@@ -881,7 +881,7 @@ def update_requests_multiple_beams_or_drafting(self,
 
     def _finalize_request(self, request: LlmRequest, streaming: bool):
         """ Finalizes the request. This is necessary for beam search. """
-        seq_slot = request.seq_slot
+        seq_slot = request.py_seq_slot
         event = self.algs.decoder.finalize(self.store["decoder_state"],
                                            seq_slot, request.sampling_config,
                                            streaming)
@@ -893,7 +893,7 @@ def _post_process_request(self, request: LlmRequest,
             request: LlmRequest which shall be post processed
             finalize_event: CudaEvent to wait for the finalize step to finish
         """
-        seq_slot = request.seq_slot
+        seq_slot = request.py_seq_slot
         beam_width = request.sampling_config.beam_width
         # synchronize on the finalize event before continuing the post processing.
         finalize_event.synchronize()

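Note: in sampler.py the slot is what indexes the batched, slot-major sampling tensors, e.g. `new_tokens[step, seq_slot, beam]` in `add_token` and the gathered `seq_slots` tensor in `_process_requests`. The sketch below reproduces that indexing pattern with toy shapes; `ToyRequest` and the tensor sizes are assumptions, not the library's real types.

```python
# Minimal sketch of slot-indexed sampling tensors; shapes and class names are
# illustrative, only the py_seq_slot indexing mirrors the diff.
import torch

max_slots, beam_width, steps = 8, 1, 2
# new_tokens is slot-major: new_tokens[step, slot, beam]
new_tokens = torch.arange(steps * max_slots * beam_width).reshape(
    steps, max_slots, beam_width)


class ToyRequest:
    def __init__(self, py_seq_slot: int):
        self.py_seq_slot = py_seq_slot
        self.tokens: list[int] = []


def add_token(request: ToyRequest, new_tokens: torch.Tensor, *,
              beam: int, step: int = 0) -> int:
    seq_slot = request.py_seq_slot      # always the py_ attribute after this commit
    assert seq_slot is not None
    new_token = int(new_tokens[step, seq_slot, beam])
    request.tokens.append(new_token)
    return new_token


requests = [ToyRequest(3), ToyRequest(6)]
# Batched variant, as in _process_requests: gather every request's slot into
# one tensor and use it to select the matching entries.
seq_slots = torch.as_tensor([r.py_seq_slot for r in requests])
print(new_tokens[0, seq_slots, 0])                        # tokens for slots 3 and 6
print([add_token(r, new_tokens, beam=0) for r in requests])
```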
tensorrt_llm/_torch/speculative/mtp.py
3 additions & 3 deletions

@@ -232,7 +232,7 @@ def _request_common_handling(self, request: LlmRequest,
         assert not request.py_return_context_logits, "return_context_logits not implemented for MTPSampler"
         assert not request.py_return_generation_logits, "return_generation_logits not implemented for MTPSampler"
         assert not request.py_return_log_probs, "return_log_probs not implemented for MTPSampler"
-        request.py_draft_tokens = next_draft_tokens[request.seq_slot]
+        request.py_draft_tokens = next_draft_tokens[request.py_seq_slot]
         request.py_decoding_iter += 1
 
     def update_requests(self, state: SampleStateMTP) -> None:
@@ -253,7 +253,7 @@ def update_requests(self, state: SampleStateMTP) -> None:
         for req in state.scheduled_requests.generation_requests:
             if req.state == LlmRequestState.GENERATION_COMPLETE:
                 continue
-            num_new_tokens = new_tokens_lens[req.seq_slot]
+            num_new_tokens = new_tokens_lens[req.py_seq_slot]
             for i in range(num_new_tokens):
                 new_token = add_token(req, new_tokens, beam=beam_idx, step=i)
                 if self._handle_stop_criteria(req, new_token):
@@ -269,7 +269,7 @@ def sample_async(self, scheduled_requests: ScheduledRequests,
         # next_new_tokens_device: input tokens for the next iteration, device tensor, shape: batch_size, nextn + 1
 
         requests = scheduled_requests.all_requests()
-        slots = torch.as_tensor([r.seq_slot for r in requests])
+        slots = torch.as_tensor([r.py_seq_slot for r in requests])
         slots = slots.to(device="cuda", non_blocking=True)
 
         o_new_tokens = outputs['new_tokens'][:len(requests)]

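Note: the mtp.py changes apply the same rule in the MTP (multi-token prediction) sampler, where both the per-request accepted-token count (`new_tokens_lens`) and the next iteration's draft tokens (`next_draft_tokens`) are looked up by slot. A toy sketch of that per-slot lookup follows; apart from `py_seq_slot`, `py_draft_tokens`, `new_tokens_lens`, and `next_draft_tokens`, all names and shapes are made up.

```python
# Minimal sketch of the MTP-style per-slot lookups; tensors and counts are
# made up, only the py_seq_slot indexing mirrors the diff.
import torch


class ToyRequest:
    def __init__(self, py_seq_slot: int):
        self.py_seq_slot = py_seq_slot
        self.py_draft_tokens: list[int] = []
        self.accepted: list[int] = []


def update_requests(requests, new_tokens, new_tokens_lens, next_draft_tokens,
                    beam: int = 0):
    for req in requests:
        slot = req.py_seq_slot
        # Each request accepted a variable number of tokens this step.
        num_new_tokens = int(new_tokens_lens[slot])
        for step in range(num_new_tokens):
            req.accepted.append(int(new_tokens[step, slot, beam]))
        # Draft tokens proposed for the next iteration, also keyed by slot.
        req.py_draft_tokens = next_draft_tokens[slot].tolist()


max_slots, nextn = 4, 2
new_tokens = torch.arange((nextn + 1) * max_slots).reshape(nextn + 1, max_slots, 1)
new_tokens_lens = torch.tensor([1, 3, 2, 1])
next_draft_tokens = torch.full((max_slots, nextn), 7)

reqs = [ToyRequest(1), ToyRequest(2)]
update_requests(reqs, new_tokens, new_tokens_lens, next_draft_tokens)
print(reqs[0].accepted, reqs[0].py_draft_tokens)   # [1, 5, 9] [7, 7]
```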