
Commit b6ca677

refactor: remove decoder request from decoder interface (#5129)
Signed-off-by: Robin Kobus <[email protected]>
1 parent 4f9fa9f commit b6ca677


14 files changed: +106 -95 lines changed


cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h

Lines changed: 5 additions & 3 deletions
@@ -71,7 +71,8 @@ class CreateNewDecoderRequests : Algorithm
     {
     }
 
-    std::tuple<TensorPtr, std::vector<runtime::decoder_batch::Request>, std::vector<runtime::SamplingConfig>>
+    std::tuple<TensorPtr, std::vector<runtime::SamplingConfig>, std::vector<runtime::ITensor::SharedConstPtr>,
+        std::vector<executor::LookaheadDecodingConfig>>
     operator()(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
         executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
         runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers,
@@ -113,8 +114,9 @@ class CreateNewDecoderRequests : Algorithm
     static void newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request,
         runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream);
 
-    [[nodiscard]] std::vector<runtime::decoder_batch::Request> createDecoderRequests(
-        RequestVector const& finishedContextRequests, TensorPtr const& inputIds,
+    [[nodiscard]] std::tuple<std::vector<runtime::ITensor::SharedConstPtr>,
+        std::vector<executor::LookaheadDecodingConfig>>
+    createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds,
         executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState,
         runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType,
         runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
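
For orientation, a hedged sketch of how a caller might unpack the reworked result; variable names are illustrative, and the functor call mirrors the call-site update in trtGptModelInflightBatching.cpp further down in this commit:

    // Illustrative only (names assumed): createNewDecoderRequests is an instance of the algorithm above.
    auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
        = createNewDecoderRequests(modelConfig, worldConfig, decodingConfig, contextRequests, bufferManager,
            logitsType, inputBuffers, decoderState, runtimeStream, decoderStream, maxSequenceLength, beamWidth,
            medusaBuffers);
    // The decoder_batch::Request vector is gone from the result; lookahead prompt slices and
    // per-request lookahead configs are returned directly instead.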

cpp/include/tensorrt_llm/runtime/gptDecoder.h

Lines changed: 15 additions & 9 deletions
@@ -53,9 +53,12 @@ class IGptDecoder
 
     virtual ~IGptDecoder() = default;
 
+    /// @param explicitDraftTokensDType is only used by ExplicitDraftTokens model to WAR the lack of bf16 decoder.
     virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize, TensorConstPtr const& batchSlots,
         std::optional<DecodingOutput> const& output = std::nullopt,
-        std::optional<std::vector<decoder_batch::Request> const> const& requests = std::nullopt)
+        std::optional<nvinfer1::DataType> explicitDraftTokensDType = std::nullopt,
+        std::optional<std::vector<TensorConstPtr>> const& lookaheadPrompt = std::nullopt,
+        std::optional<std::vector<executor::LookaheadDecodingConfig>> const& lookaheadAlgoConfigs = std::nullopt)
         = 0;
 
     virtual void forwardAsync(DecodingOutput& output, DecodingInput const& input) = 0;
@@ -69,7 +72,7 @@ class IGptDecoder
         = 0;
 
     static std::unique_ptr<IGptDecoder> create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
-        size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
+        size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded,
         BufferManager::CudaStreamPtr const& stream,
         std::shared_ptr<SpeculativeDecodingModule const> const& speculativeDecodingModule = nullptr);
 };
@@ -83,12 +86,15 @@ class GptDecoder : public virtual IGptDecoder
     using TensorPtr = std::shared_ptr<ITensor>;
 
     GptDecoder(executor::DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize,
-        size_t vocabSizePadded, size_t maxSequenceLength, CudaStreamPtr const& stream,
+        size_t vocabSizePadded, CudaStreamPtr const& stream,
         std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModule = nullptr);
 
     void setup(SamplingConfig const& samplingConfig, size_t batchSize, TensorConstPtr const& batchSlots,
         std::optional<DecodingOutput> const& output = std::nullopt,
-        std::optional<std::vector<decoder_batch::Request> const> const& requests = std::nullopt) override;
+        std::optional<nvinfer1::DataType> explicitDraftTokensDType = std::nullopt,
+        std::optional<std::vector<TensorConstPtr>> const& lookaheadPrompt = std::nullopt,
+        std::optional<std::vector<executor::LookaheadDecodingConfig>> const& lookaheadAlgoConfigs
+        = std::nullopt) override;
 
     void forwardAsync(DecodingOutput& output, DecodingInput const& input) override;
 
@@ -117,18 +123,18 @@ class GptDecoder : public virtual IGptDecoder
 };
 
 inline std::unique_ptr<IGptDecoder> IGptDecoder::create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
-    size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
+    size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded,
    BufferManager::CudaStreamPtr const& stream,
    std::shared_ptr<SpeculativeDecodingModule const> const& speculativeDecodingModule)
 {
     switch (dtype)
     {
     case nvinfer1::DataType::kFLOAT:
-        return std::make_unique<GptDecoder<float>>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
-            maxSequenceLength, stream, speculativeDecodingModule);
+        return std::make_unique<GptDecoder<float>>(
+            mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule);
     case nvinfer1::DataType::kHALF:
-        return std::make_unique<GptDecoder<half>>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
-            maxSequenceLength, stream, speculativeDecodingModule);
+        return std::make_unique<GptDecoder<half>>(
+            mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule);
     default:
         TLLM_THROW("Unsupported decoder data type: %d. Use either kFLOAT or kHALF.", static_cast<int>(dtype));
         return nullptr;
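
Callers now hand the decoder the data it actually needs instead of a vector of decoder_batch::Request objects. A hedged sketch of a call against the new interface (variable names are illustrative; the real call site is updated in trtGptModelInflightBatching.cpp below):

    // Illustrative only: setup() takes the explicit-draft-tokens dtype and the lookahead
    // prompt/config vectors directly; the decoder_batch::Request parameter is gone.
    decoder.setup(samplingConfig, batchSize, batchSlots,
        {jointDecodingOutput},     // std::optional<DecodingOutput>
        modelConfig.getDataType(), // explicitDraftTokensDType, WAR for the missing bf16 decoder
        lookaheadPrompt,           // std::vector<TensorConstPtr>, filled only for lookahead decoding
        lookaheadAlgoConfigs);     // std::vector<executor::LookaheadDecodingConfig>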

cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h

Lines changed: 1 addition & 2 deletions
@@ -48,8 +48,7 @@ class GptDecoderBatched : public IGptDecoderBatched
     explicit GptDecoderBatched(CudaStreamPtr stream);
 
     void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
-        SizeType32 maxSequenceLength, nvinfer1::DataType dtype, ModelConfig const& modelConfig,
-        WorldConfig const& worldConfig) override;
+        nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) override;
 
     void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override;

cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h

Lines changed: 1 addition & 2 deletions
@@ -119,8 +119,7 @@ class IGptDecoderBatched
 
     //! @brief Setup the decoder before calling `forward()`
     virtual void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
-        SizeType32 maxSequenceLength, nvinfer1::DataType dtype, ModelConfig const& modelConfig,
-        WorldConfig const& worldConfig)
+        nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig)
         = 0;
 
     //! @brief Disable Lookahead decoding.
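
The batched-decoder interfaces drop the now-unused maxSequenceLength parameter as well; a minimal sketch of the trimmed call, with illustrative variable names:

    // Illustrative only: maxSequenceLength is no longer threaded through the batched decoder setup.
    batchedDecoder.setup(decodingMode, maxBatchSize, maxBeamWidth, dtype, modelConfig, worldConfig);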

cpp/include/tensorrt_llm/runtime/request.h

Lines changed: 0 additions & 1 deletion
@@ -57,7 +57,6 @@ class Request
     std::optional<TensorPtr> draftLogits; // [generatedTokensPerEngineStep - 1, vocabSize] on gpu
     TensorPtr medusaPaths;   // [maxDecodingTokens, maxPathLen], on gpu
     TensorPtr medusaTreeIds; // [maxDecodingTokens], on gpu
-    nvinfer1::DataType dtype; // Request data type, only used by explicit draft tokens.
     std::optional<executor::LookaheadDecodingConfig> lookaheadRuntimeConfig;
     std::optional<executor::EagleConfig> eagleConfig;
 };

cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp

Lines changed: 25 additions & 16 deletions
@@ -122,7 +122,8 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe
 
 } // namespace
 
-std::tuple<TensorPtr, std::vector<runtime::decoder_batch::Request>, std::vector<runtime::SamplingConfig>>
+std::tuple<TensorPtr, std::vector<runtime::SamplingConfig>, std::vector<runtime::ITensor::SharedConstPtr>,
+    std::vector<executor::LookaheadDecodingConfig>>
 CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
     executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
     runtime::BufferManager const& bufferManager, nvinfer1::DataType logitsType, DecoderInputBuffers& inputBuffers,
@@ -139,9 +140,9 @@ CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, ru
     copySequenceLengths(finishedContextRequests, inputBuffers, *decoderState.getSequenceLengths(), beamWidth,
         bufferManager, runtimeStream);
 
-    auto decoderRequests = createDecoderRequests(finishedContextRequests, inputBuffers.inputsIds, decodingConfig,
-        decoderState, bufferManager, logitsType, modelConfig, worldConfig, runtimeStream, decoderStream,
-        maxSequenceLength, medusaBuffers);
+    auto [lookaheadPrompt, lookaheadAlgoConfigs] = createDecoderRequests(finishedContextRequests,
+        inputBuffers.inputsIds, decodingConfig, decoderState, bufferManager, logitsType, modelConfig, worldConfig,
+        runtimeStream, decoderStream, maxSequenceLength, medusaBuffers);
 
     auto const batchSize = finishedContextRequests.size();
 
@@ -155,7 +156,8 @@ CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, ru
     TensorPtr batchSlotsView = runtime::ITensor::slice(inputBuffers.setupBatchSlots, 0, batchSize);
 
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
-    return {std::move(batchSlotsView), std::move(decoderRequests), std::move(samplingConfigs)};
+    return {std::move(batchSlotsView), std::move(samplingConfigs), std::move(lookaheadPrompt),
+        std::move(lookaheadAlgoConfigs)};
 }
 
 void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder_batch::Request const& request,
@@ -555,8 +557,8 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
 }
 
-[[nodiscard]] std::vector<runtime::decoder_batch::Request> CreateNewDecoderRequests::createDecoderRequests(
-    RequestVector const& finishedContextRequests, TensorPtr const& inputIds,
+std::tuple<std::vector<runtime::ITensor::SharedConstPtr>, std::vector<executor::LookaheadDecodingConfig>>
+CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds,
     executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState,
     BufferManager const& bufferManager, nvinfer1::DataType logitsType, runtime::ModelConfig const& modelConfig,
     runtime::WorldConfig const& worldConfig, runtime::CudaStream const& runtimeStream,
@@ -574,6 +576,16 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
     std::vector<decoder_batch::Request> decoderRequests;
     decoderRequests.reserve(finishedContextRequests.size());
 
+    std::vector<runtime::ITensor::SharedConstPtr> lookaheadPrompt;
+    std::vector<executor::LookaheadDecodingConfig> lookaheadAlgoConfigs;
+    if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
+    {
+        TLLM_CHECK_WITH_INFO(
+            decodingConfig.getLookaheadDecodingConfig().has_value(), "Lookahead decoding config must be provided");
+        lookaheadPrompt.reserve(finishedContextRequests.size());
+        lookaheadAlgoConfigs.reserve(finishedContextRequests.size());
+    }
+
     SizeType32 inputOffset{0};
     for (auto const& llmReq : finishedContextRequests)
     {
@@ -620,14 +632,11 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
         }
         else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
         {
-            decoderRequest.lookaheadRuntimeConfig = llmReq->getLookaheadConfig()
-                ? llmReq->getLookaheadConfig()
-                : decodingConfig.getLookaheadDecodingConfig();
-        }
-        else if (modelConfig.getSpeculativeDecodingMode().isExplicitDraftTokens())
-        {
-            // Only Explicit draft tokens model needs dtype to WAR the lack of bf16 decoder.
-            decoderRequest.dtype = modelConfig.getDataType();
+            lookaheadPrompt.emplace_back(ITensor::slice(decoderRequest.ids, 0, decoderRequest.inputLen));
+
+            auto const& lookaheadRuntimeConfig
+                = llmReq->getLookaheadConfig().value_or(decodingConfig.getLookaheadDecodingConfig().value());
+            lookaheadAlgoConfigs.emplace_back(lookaheadRuntimeConfig);
         }
         else if (modelConfig.getSpeculativeDecodingMode().isEagle())
         {
@@ -659,7 +668,7 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec
         inputOffset += promptLen;
     }
 
-    return decoderRequests;
+    return {std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)};
 }
 
 std::shared_ptr<runtime::ITensor> CreateNewDecoderRequests::retrieveDraftLogits(ModelConfig const& modelConfig,
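
One behavioural detail in the lookahead branch above: a request-level lookahead config still takes precedence over the executor-level one, now written with value_or instead of a ternary. A self-contained sketch of that precedence rule, with plain int standing in for executor::LookaheadDecodingConfig:

    #include <cassert>
    #include <optional>

    // Simplified stand-in for the fallback used when collecting lookaheadAlgoConfigs:
    // use the request's own config if present, otherwise the global decoding config
    // (whose presence is guaranteed by the TLLM_CHECK_WITH_INFO above).
    int pickLookaheadConfig(std::optional<int> requestConfig, int globalConfig)
    {
        return requestConfig.value_or(globalConfig);
    }

    int main()
    {
        assert(pickLookaheadConfig(7, 3) == 7);            // per-request config wins
        assert(pickLookaheadConfig(std::nullopt, 3) == 3); // otherwise fall back to the global default
        return 0;
    }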

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp

Lines changed: 7 additions & 7 deletions
@@ -1424,8 +1424,8 @@ void TrtGptModelInflightBatching::createDecoder(std::optional<executor::Decoding
     }
 
     mDecoder = std::make_unique<runtime::GptDecoderBatched>(mRuntime->getStreamPtr());
-    mDecoder->setup(decodingMode, getMaxNumSequences(), mOperatingBeamWidth, getMaxSequenceLen(), decoderType,
-        mModelConfig, mWorldConfig);
+    mDecoder->setup(
+        decodingMode, getMaxNumSequences(), mOperatingBeamWidth, decoderType, mModelConfig, mWorldConfig);
 
     mDecoderState = std::make_unique<runtime::decoder::DecoderState>(decoderType, mRuntime->getBufferManager());
     if (!mModelConfig.getSpeculativeDecodingMode().isNone())
@@ -1786,18 +1786,18 @@ void TrtGptModelInflightBatching::setupDecoderStep(
 {
     auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");
 
-    auto [batchSlots, decoderRequests, samplingConfigs]
+    auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
         = (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests,
             mRuntime->getBufferManager(), logitsType, inputBuffers, *mDecoderState, mRuntime->getStream(),
            *mDecoder->getDecoderStream(), getMaxSequenceLen(), mOperatingBeamWidth, buffers.mMedusaBuffers);
 
-    if (!decoderRequests.empty())
+    auto const localBatchSize = batchSlots->getSize();
+    if (localBatchSize > 0)
     {
-        // Setup underlying decoder.
-        auto const localBatchSize = batchSlots->getSize();
         auto samplingConfig = SamplingConfig(samplingConfigs);
         mDecoder->getUnderlyingDecoder().setup(samplingConfig, localBatchSize, batchSlots,
-            {mDecoderState->getJointDecodingOutput()}, {decoderRequests});
+            {mDecoderState->getJointDecodingOutput()}, mModelConfig.getDataType(), lookaheadPrompt,
+            lookaheadAlgoConfigs);
 
         auto const& stream = mDecoder->getDecoderStream();
         CudaEvent event{};

cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp

Lines changed: 5 additions & 5 deletions
@@ -158,12 +158,12 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
             tensorrt_llm::runtime::CudaStream const& decoderStream, SizeType32 maxSequenceLength,
             SizeType32 beamWidth, OptionalRef<MedusaBuffers const> medusaBuffers = std::nullopt)
         {
-            auto [batchSlots, decoderRequests, samplingConfigs] = self(modelConfig, worldConfig, decodingConfig,
-                contextRequests, bufferManager, logitsType, inputBuffers, decoderState, runtimeStream,
-                decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
+            auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] = self(modelConfig,
+                worldConfig, decodingConfig, contextRequests, bufferManager, logitsType, inputBuffers, decoderState,
+                runtimeStream, decoderStream, maxSequenceLength, beamWidth, medusaBuffers);
 
-            return std::tuple{
-                runtime::Torch::tensor(batchSlots), std::move(decoderRequests), std::move(samplingConfigs)};
+            return std::tuple{runtime::Torch::tensor(batchSlots), std::move(samplingConfigs),
+                std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)};
         },
         py::arg("model_config"), py::arg("world_config"), py::arg("decoding_config"), py::arg("context_requests"),
         py::arg("buffer_manager"), py::arg("logits_type"), py::arg("decoder_input_buffers"),

cpp/tensorrt_llm/pybind/runtime/bindings.cpp

Lines changed: 15 additions & 7 deletions
@@ -38,13 +38,15 @@
 #include "tensorrt_llm/runtime/speculativeDecodingMode.h"
 #include "tensorrt_llm/runtime/tllmRuntime.h"
 #include "tensorrt_llm/runtime/torchView.h"
+
 #include <ATen/ATen.h>
 #include <c10/cuda/CUDAStream.h>
 #include <pybind11/stl.h>
 #include <pybind11/stl_bind.h>
 #include <torch/extension.h>
 
 namespace tr = tensorrt_llm::runtime;
+namespace te = tensorrt_llm::executor;
 
 class PyITensor : public tensorrt_llm::runtime::ITensor
 {
@@ -160,9 +162,12 @@ class PyIGptDecoder : public tr::IGptDecoder
     void setup(tr::SamplingConfig const& samplingConfig, size_t batchSize,
         tr::DecodingInput::TensorConstPtr const& batchSlots,
         std::optional<tr::DecodingOutput> const& output = std::nullopt,
-        std::optional<std::vector<tr::decoder_batch::Request> const> const& requests = std::nullopt) override
+        std::optional<nvinfer1::DataType> explicitDraftTokensDType = std::nullopt,
+        std::optional<std::vector<tr::ITensor::SharedConstPtr>> const& lookaheadPrompt = std::nullopt,
+        std::optional<std::vector<te::LookaheadDecodingConfig>> const& lookaheadAlgoConfigs = std::nullopt) override
     {
-        PYBIND11_OVERRIDE_PURE(void, IGptDecoder, setup, samplingConfig, batchSize, batchSlots, output, requests);
+        PYBIND11_OVERRIDE_PURE(void, IGptDecoder, setup, samplingConfig, batchSize, batchSlots, output,
+            explicitDraftTokensDType, lookaheadPrompt, lookaheadAlgoConfigs);
     }
 
     void forwardAsync(tr::DecodingOutput& output, tr::DecodingInput const& input) override
@@ -314,13 +319,17 @@ void initBindings(pybind11::module_& m)
             "setup",
             [](tr::IGptDecoder& self, tr::SamplingConfig const& samplingConfig, size_t batchSize,
                 at::Tensor const& batchSlots, std::optional<tr::DecodingOutput> const& output = std::nullopt,
-                std::optional<std::vector<tr::decoder_batch::Request> const> const& requests = std::nullopt)
+                std::optional<nvinfer1::DataType> explicitDraftTokensDType = std::nullopt,
+                std::optional<std::vector<tr::ITensor::SharedConstPtr>> const& lookaheadPrompt = std::nullopt,
+                std::optional<std::vector<te::LookaheadDecodingConfig>> const& lookaheadAlgoConfigs = std::nullopt)
             {
                 auto tensorPtrBatchSlots = tr::TorchView::of(batchSlots);
-                return self.setup(samplingConfig, batchSize, std::move(tensorPtrBatchSlots), output, requests);
+                self.setup(samplingConfig, batchSize, std::move(tensorPtrBatchSlots), output, explicitDraftTokensDType,
+                    lookaheadPrompt, lookaheadAlgoConfigs);
             },
             py::arg("sampling_config"), py::arg("batch_size"), py::arg("batch_slots"), py::arg("output") = std::nullopt,
-            py::arg("requests") = std::nullopt);
+            py::arg("explicit_draft_tokens_d_type") = std::nullopt, py::arg("lookahead_prompt") = std::nullopt,
+            py::arg("lookahead_algo_configs") = std::nullopt);
 
     py::class_<tr::decoder::DecoderState>(m, "DecoderState")
         .def(py::init<nvinfer1::DataType, tr::BufferManager const&>(), py::arg("dtype"), py::arg("buffer_manager"))
@@ -381,8 +390,7 @@ void initBindings(pybind11::module_& m)
     py::class_<tr::GptDecoderBatched>(m, "GptDecoderBatched")
         .def(py::init<tr::GptDecoderBatched::CudaStreamPtr>(), py::arg("stream"))
         .def("setup", &tr::GptDecoderBatched::setup, py::arg("mode"), py::arg("max_batch_size"),
-            py::arg("max_beam_width"), py::arg("max_sequence_length"), py::arg("dtype"), py::arg("model_config"),
-            py::arg("world_config"))
+            py::arg("max_beam_width"), py::arg("dtype"), py::arg("model_config"), py::arg("world_config"))
        .def("forward_async", &tr::GptDecoderBatched::forwardAsync, py::arg("decoder_state"), py::arg("output"),
            py::arg("input"))
        .def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, py::return_value_policy::reference)
