@@ -137,6 +137,24 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
137137 return UniqueCUvideodecoder (decoder, CUvideoDecoderDeleter{});
138138}
139139
140+ cudaVideoCodec validateCodecSupport (AVCodecID codecId) {
141+ switch (codecId) {
142+ case AV_CODEC_ID_H264:
143+ return cudaVideoCodec_H264;
144+ case AV_CODEC_ID_HEVC:
145+ return cudaVideoCodec_HEVC;
146+ // TODONVDEC P0: support more codecs
147+ // case AV_CODEC_ID_AV1: return cudaVideoCodec_AV1;
148+ // case AV_CODEC_ID_MPEG4: return cudaVideoCodec_MPEG4;
149+ // case AV_CODEC_ID_VP8: return cudaVideoCodec_VP8;
150+ // case AV_CODEC_ID_VP9: return cudaVideoCodec_VP9;
151+ // case AV_CODEC_ID_MJPEG: return cudaVideoCodec_JPEG;
152+ default : {
153+ TORCH_CHECK (false , " Unsupported codec type: " , avcodec_get_name (codecId));
154+ }
155+ }
156+ }
157+
140158} // namespace
141159
142160BetaCudaDeviceInterface::BetaCudaDeviceInterface (const torch::Device& device)
@@ -162,36 +180,100 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
162180 }
163181}
164182
165- void BetaCudaDeviceInterface::initialize (const AVStream* avStream) {
183+ void BetaCudaDeviceInterface::initialize (
184+ const AVStream* avStream,
185+ const UniqueDecodingAVFormatContext& avFormatCtx) {
166186 torch::Tensor dummyTensorForCudaInitialization = torch::empty (
167187 {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
168188
169- TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
170- timeBase_ = avStream->time_base ;
171-
172189 auto cudaDevice = torch::Device (torch::kCUDA );
173190 defaultCudaInterface_ =
174191 std::unique_ptr<DeviceInterface>(createDeviceInterface (cudaDevice));
175192 AVCodecContext dummyCodecContext = {};
176- defaultCudaInterface_->initialize (avStream);
193+ defaultCudaInterface_->initialize (avStream, avFormatCtx );
177194 defaultCudaInterface_->registerHardwareDeviceWithCodec (&dummyCodecContext);
178195
179- const AVCodecParameters* codecpar = avStream->codecpar ;
180- TORCH_CHECK (codecpar != nullptr , " CodecParameters cannot be null" );
196+ TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
197+ timeBase_ = avStream->time_base ;
198+
199+ const AVCodecParameters* codecPar = avStream->codecpar ;
200+ TORCH_CHECK (codecPar != nullptr , " CodecParameters cannot be null" );
201+
202+ initializeBSF (codecPar, avFormatCtx);
203+
204+ // Create parser. Default values that aren't obvious are taken from DALI.
205+ CUVIDPARSERPARAMS parserParams = {};
206+ parserParams.CodecType = validateCodecSupport (codecPar->codec_id );
207+ parserParams.ulMaxNumDecodeSurfaces = 8 ;
208+ parserParams.ulMaxDisplayDelay = 0 ;
209+ // Callback setup, all are triggered by the parser within a call
210+ // to cuvidParseVideoData
211+ parserParams.pUserData = this ;
212+ parserParams.pfnSequenceCallback = pfnSequenceCallback;
213+ parserParams.pfnDecodePicture = pfnDecodePictureCallback;
214+ parserParams.pfnDisplayPicture = pfnDisplayPictureCallback;
181215
216+ CUresult result = cuvidCreateVideoParser (&videoParser_, &parserParams);
182217 TORCH_CHECK (
183- // TODONVDEC P0 support more
184- avStream->codecpar ->codec_id == AV_CODEC_ID_H264,
185- " Can only do H264 for now" );
218+ result == CUDA_SUCCESS, " Failed to create video parser: " , result);
219+ }
186220
221+ void BetaCudaDeviceInterface::initializeBSF (
222+ const AVCodecParameters* codecPar,
223+ const UniqueDecodingAVFormatContext& avFormatCtx) {
187224 // Setup bit stream filters (BSF):
188225 // https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
189- // This is only needed for some formats, like H264 or HEVC. TODONVDEC P1: For
190- // now we apply BSF unconditionally, but it should be optional and dependent
191- // on codec and container.
192- const AVBitStreamFilter* avBSF = av_bsf_get_by_name (" h264_mp4toannexb" );
226+ // This is only needed for some formats, like H264 or HEVC.
227+
228+ TORCH_CHECK (codecPar != nullptr , " codecPar cannot be null" );
229+ TORCH_CHECK (avFormatCtx != nullptr , " AVFormatContext cannot be null" );
230+ TORCH_CHECK (
231+ avFormatCtx->iformat != nullptr ,
232+ " AVFormatContext->iformat cannot be null" );
233+ std::string filterName;
234+
235+ // Matching logic is taken from DALI
236+ switch (codecPar->codec_id ) {
237+ case AV_CODEC_ID_H264: {
238+ const std::string formatName = avFormatCtx->iformat ->long_name
239+ ? avFormatCtx->iformat ->long_name
240+ : " " ;
241+
242+ if (formatName == " QuickTime / MOV" ||
243+ formatName == " FLV (Flash Video)" ||
244+ formatName == " Matroska / WebM" || formatName == " raw H.264 video" ) {
245+ filterName = " h264_mp4toannexb" ;
246+ }
247+ break ;
248+ }
249+
250+ case AV_CODEC_ID_HEVC: {
251+ const std::string formatName = avFormatCtx->iformat ->long_name
252+ ? avFormatCtx->iformat ->long_name
253+ : " " ;
254+
255+ if (formatName == " QuickTime / MOV" ||
256+ formatName == " FLV (Flash Video)" ||
257+ formatName == " Matroska / WebM" || formatName == " raw HEVC video" ) {
258+ filterName = " hevc_mp4toannexb" ;
259+ }
260+ break ;
261+ }
262+
263+ default :
264+ // No bitstream filter needed for other codecs
265+ // TODONVDEC P1 MPEG4 will need one!
266+ break ;
267+ }
268+
269+ if (filterName.empty ()) {
270+ // Only initialize BSF if we actually need one
271+ return ;
272+ }
273+
274+ const AVBitStreamFilter* avBSF = av_bsf_get_by_name (filterName.c_str ());
193275 TORCH_CHECK (
194- avBSF != nullptr , " Failed to find h264_mp4toannexb bitstream filter" );
276+ avBSF != nullptr , " Failed to find bitstream filter: " , filterName );
195277
196278 AVBSFContext* avBSFContext = nullptr ;
197279 int retVal = av_bsf_alloc (avBSF, &avBSFContext);
@@ -202,7 +284,7 @@ void BetaCudaDeviceInterface::initialize(const AVStream* avStream) {
202284
203285 bitstreamFilter_.reset (avBSFContext);
204286
205- retVal = avcodec_parameters_copy (bitstreamFilter_->par_in , codecpar );
287+ retVal = avcodec_parameters_copy (bitstreamFilter_->par_in , codecPar );
206288 TORCH_CHECK (
207289 retVal >= AVSUCCESS,
208290 " Failed to copy codec parameters: " ,
@@ -213,22 +295,6 @@ void BetaCudaDeviceInterface::initialize(const AVStream* avStream) {
213295 retVal == AVSUCCESS,
214296 " Failed to initialize bitstream filter: " ,
215297 getFFMPEGErrorStringFromErrorCode (retVal));
216-
217- // Create parser. Default values that aren't obvious are taken from DALI.
218- CUVIDPARSERPARAMS parserParams = {};
219- parserParams.CodecType = cudaVideoCodec_H264;
220- parserParams.ulMaxNumDecodeSurfaces = 8 ;
221- parserParams.ulMaxDisplayDelay = 0 ;
222- // Callback setup, all are triggered by the parser within a call
223- // to cuvidParseVideoData
224- parserParams.pUserData = this ;
225- parserParams.pfnSequenceCallback = pfnSequenceCallback;
226- parserParams.pfnDecodePicture = pfnDecodePictureCallback;
227- parserParams.pfnDisplayPicture = pfnDisplayPictureCallback;
228-
229- CUresult result = cuvidCreateVideoParser (&videoParser_, &parserParams);
230- TORCH_CHECK (
231- result == CUDA_SUCCESS, " Failed to create video parser: " , result);
232298}
233299
234300// This callback is called by the parser within cuvidParseVideoData when there
0 commit comments