Skip to content

Commit 918fedf

Browse files
authored
[None][refactor] Simplify finish reasons handling in DecoderState (#6524)
Signed-off-by: Robin Kobus <[email protected]>
1 parent 67a3fd8 commit 918fedf

File tree

6 files changed

+9
-55
lines changed

6 files changed

+9
-55
lines changed

cpp/include/tensorrt_llm/runtime/decoderState.h

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class DecoderState
7171
//! @returns [batchSize], number of finished sequences per request, on gpu
7272
[[nodiscard]] TensorPtr getFinishedSum() const;
7373

74-
//! @returns [batchSize, beamWidth], FinishedState value, on gpu
74+
//! @returns [batchSize, beamWidth], finished states of type FinishedState, on gpu
7575
[[nodiscard]] TensorPtr getFinishReasons() const;
7676

7777
//! @returns [batchSize, maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
@@ -134,9 +134,6 @@ class DecoderState
134134
//! @returns [batchSize, maxAcceptedDraftTokensPerStep], accepted paths packed into continuous tensor, on gpu
135135
[[nodiscard]] TensorPtr getAcceptedPackedPaths() const;
136136

137-
//! @returns [maxTokensPerStep, batchSize, beamWidth], finished states of type FinishedState, on gpu
138-
[[nodiscard]] TensorPtr getFinishedSteps() const;
139-
140137
[[nodiscard]] SizeType32 getMaxBatchSize() const;
141138

142139
[[nodiscard]] SizeType32 getMaxBeamWidth() const;
@@ -221,10 +218,6 @@ class DecoderState
221218
//! @brief Stateful outputs for the decoder. Allocated for maxBatchSize slots.
222219
DecodingOutputPtr mJointDecodingOutput;
223220

224-
//! @brief [maxTokensPerStep, batchSize, beamWidth] finished states of type FinishedState for each generated token
225-
//! of maxTokensPerStep, on gpu
226-
TensorPtr mFinishedSteps;
227-
228221
//! @brief Workspace for beam search in streaming mode.
229222
std::unique_ptr<BeamSearchBuffers> mBeamSearchBuffers;
230223

cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -272,22 +272,8 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder
272272
manager.setZero(*newTokensVec);
273273
}
274274

275-
// FIXME: we call setZero mMaxDecodingEngineTokens times for only 1 element
276-
for (SizeType32 ti = 0; ti < decoderState.getMaxDecodingEngineTokens(); ++ti)
277-
{
278-
TensorPtr const finishedStepsView = ITensor::slice(decoderState.getFinishedSteps(), ti, 1);
279-
finishedStepsView->squeeze(0);
280-
TensorPtr const finishedSteps = ITensor::slice(finishedStepsView, batchSlot, 1);
281-
if (ti < numDecodingEngineTokens)
282-
{
283-
manager.setZero(*finishedSteps);
284-
}
285-
else
286-
{
287-
runtime::kernels::invokeFill(
288-
*finishedSteps, tk::FinishedState::skipDecoding().toUnderlying(), decoderStream);
289-
}
290-
}
275+
TensorPtr const finishedStepsSlice = ITensor::slice(decoderState.getFinishReasons(), batchSlot, 1);
276+
manager.setZero(*finishedStepsSlice);
291277

292278
// cumLogProb is mandatory for beamWidth > 1
293279
if ((samplingConfig.cumLogProbs.has_value() && samplingConfig.cumLogProbs->at(0)) || beamWidth > 1)

cpp/tensorrt_llm/nanobind/runtime/bindings.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,6 @@ void initBindings(nb::module_& m)
255255
.def_prop_ro("next_draft_tokens_lengths", &tr::decoder::DecoderState::getNextDraftTokensLengths)
256256
.def_prop_ro("accepted_lengths_cum_sum", &tr::decoder::DecoderState::getAcceptedLengthsCumSum)
257257
.def_prop_ro("accepted_packed_paths", &tr::decoder::DecoderState::getAcceptedPackedPaths)
258-
.def_prop_ro("finished_steps", &tr::decoder::DecoderState::getFinishedSteps)
259258
.def_prop_ro("max_beam_width", &tr::decoder::DecoderState::getMaxBeamWidth)
260259
.def_prop_ro("max_sequence_length", &tr::decoder::DecoderState::getMaxSequenceLength)
261260
.def_prop_ro("max_decoding_decoder_tokens", &tr::decoder::DecoderState::getMaxDecodingDecoderTokens)

cpp/tensorrt_llm/pybind/runtime/bindings.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,6 @@ void initBindings(pybind11::module_& m)
349349
.def_property_readonly("next_draft_tokens_lengths", &tr::decoder::DecoderState::getNextDraftTokensLengths)
350350
.def_property_readonly("accepted_lengths_cum_sum", &tr::decoder::DecoderState::getAcceptedLengthsCumSum)
351351
.def_property_readonly("accepted_packed_paths", &tr::decoder::DecoderState::getAcceptedPackedPaths)
352-
.def_property_readonly("finished_steps", &tr::decoder::DecoderState::getFinishedSteps)
353352
.def_property_readonly("max_beam_width", &tr::decoder::DecoderState::getMaxBeamWidth)
354353
.def_property_readonly("max_sequence_length", &tr::decoder::DecoderState::getMaxSequenceLength)
355354
.def_property_readonly("max_decoding_decoder_tokens", &tr::decoder::DecoderState::getMaxDecodingDecoderTokens)

cpp/tensorrt_llm/runtime/decoderState.cpp

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,15 @@ void DecoderState::setupBuffers(nvinfer1::DataType dtype, BufferManager const& b
8989

9090
dOutput->lengths = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
9191

92-
// use batchSize many entries instead of the usual 1
9392
dOutput->finishedSum = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
9493
// we don't need dOutput->lengths because lengths are passed from outside
9594
dOutput->cumLogProbs = bufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
9695
dOutput->logProbs = bufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
9796
dOutput->beamHypotheses.empty(bufferManager);
97+
9898
dOutput->finishReasons
9999
= bufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<tk::FinishedState::UnderlyingType>::value);
100+
dInput->finishReasons = dOutput->finishReasons;
100101

101102
dOutput->logProbsTiled = bufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
102103

@@ -106,8 +107,6 @@ void DecoderState::setupBuffers(nvinfer1::DataType dtype, BufferManager const& b
106107
dInput->badWordsLens = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
107108
dInput->embeddingBias = bufferManager.emptyTensor(MemoryType::kGPU, dtype);
108109

109-
mFinishedSteps = bufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<tk::FinishedState::UnderlyingType>::value);
110-
111110
mBeamSearchBuffers = std::make_unique<BeamSearchBuffers>(bufferManager);
112111

113112
setupCacheIndirectionBuffers(bufferManager);
@@ -245,10 +244,6 @@ void DecoderState::reshapeBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWid
245244
auto& dOutput = *mJointDecodingOutput;
246245
dOutput.ids->reshape(maxTotalTokensShape);
247246

248-
auto const maxNewTokensShape = ITensor::makeShape({mMaxDecodingEngineTokens, mMaxBatchSize, mMaxBeamWidth});
249-
mFinishedSteps->reshape(maxNewTokensShape);
250-
bufferManager.setZero(*mFinishedSteps);
251-
252247
dOutput.finishReasons->reshape(maxBatchSizeXmaxBeamWidthShape);
253248
bufferManager.setZero(*dOutput.finishReasons);
254249

@@ -260,6 +255,7 @@ void DecoderState::reshapeBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWid
260255
dOutput.finishedSum->reshape(maxBatchSizeShape);
261256
bufferManager.setZero(*dOutput.finishedSum);
262257

258+
auto const maxNewTokensShape = ITensor::makeShape({mMaxDecodingEngineTokens, mMaxBatchSize, mMaxBeamWidth});
263259
dOutput.newTokensSteps->reshape(maxNewTokensShape);
264260
bufferManager.setZero(*dOutput.newTokensSteps);
265261

@@ -342,8 +338,6 @@ void DecoderState::reshapeSpeculativeDecodingBuffers(SpeculativeDecodingMode con
342338
mMaxDecodingEngineTokens);
343339

344340
auto const maxNewTokensShape = ITensor::makeShape({mMaxDecodingEngineTokens, mMaxBatchSize, mMaxBeamWidth});
345-
mFinishedSteps->reshape(maxNewTokensShape);
346-
bufferManager.setZero(*mFinishedSteps);
347341
dOutput.newTokensSteps->reshape(maxNewTokensShape);
348342
bufferManager.setZero(*dOutput.newTokensSteps);
349343

@@ -454,7 +448,6 @@ void DecoderState::disableLookahead(RequestVector const& genRequests)
454448

455449
auto const maxNewTokensShape = ITensor::makeShape({mMaxDecodingEngineTokens, mMaxBatchSize, mMaxBeamWidth});
456450
mJointDecodingOutput->newTokensSteps->reshape(maxNewTokensShape);
457-
mFinishedSteps->reshape(maxNewTokensShape);
458451

459452
for (auto const& llmReq : genRequests)
460453
{
@@ -562,11 +555,6 @@ TensorPtr DecoderState::getAcceptedPackedPaths() const
562555
return mJointDecodingOutput->speculativeDecodingOutputs->pathsOffsets;
563556
}
564557

565-
TensorPtr DecoderState::getFinishedSteps() const
566-
{
567-
return mFinishedSteps;
568-
}
569-
570558
SizeType32 DecoderState::getMaxBatchSize() const
571559
{
572560
return mMaxBatchSize;

cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -116,16 +116,6 @@ void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step,
116116
dInput.batchSize = static_cast<SizeType32>(dInput.batchSlots->getSize());
117117
dInput.logitsVec = input.logits.at(step);
118118

119-
TensorPtr finishedStepsInput = ITensor::slice(decoderState.getFinishedSteps(), step, 1);
120-
TensorPtr finishedStepsOutput
121-
= ITensor::slice(decoderState.getFinishedSteps(), std::min(input.maxDecoderSteps - 1, step + 1), 1);
122-
finishedStepsInput->squeeze(0);
123-
finishedStepsOutput->squeeze(0);
124-
TensorPtr newTokensStepView
125-
= ITensor::slice(dOutput.newTokensSteps, step, decoderState.getMaxDecodingDecoderTokens());
126-
127-
dInput.finishReasons = finishedStepsInput;
128-
129119
if (speculativeDecodingMode.isDraftTokensExternal())
130120
{
131121
dInput.externalDraftTokensInputs->step = step;
@@ -136,14 +126,13 @@ void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step,
136126
auto batchSlotsRange = BufferRange<SizeType32 const>(*dInput.batchSlots);
137127
for (auto batchSlot : batchSlotsRange)
138128
{
139-
TensorPtr finishedSteps = ITensor::slice(finishedStepsInput, batchSlot, 1);
140-
bufferManager.setZero(*finishedSteps);
129+
TensorPtr finishedStepsSlice = ITensor::slice(decoderState.getFinishReasons(), batchSlot, 1);
130+
bufferManager.setZero(*finishedStepsSlice);
141131
}
142132
}
143133
}
144134

145-
dOutput.newTokens = newTokensStepView;
146-
dOutput.finishReasons = finishedStepsOutput;
135+
dOutput.newTokens = ITensor::slice(dOutput.newTokensSteps, step, decoderState.getMaxDecodingDecoderTokens());
147136

148137
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
149138
}

0 commit comments

Comments (0)