Skip to content

Commit 0511d10

Browse files
authored
BETA CUDA interface: NVCUVID decoder implementation 1/N (#910)
1 parent b3e2e2c commit 0511d10

22 files changed

+3233
-81
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 576 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
// BETA CUDA device interface that provides direct control over NVDEC
8+
// while keeping FFmpeg for demuxing. A lot of the logic, particularly the use
9+
// of a cache for the decoders, is inspired by DALI's implementation which is
10+
// APACHE 2.0:
11+
// https://github.com/NVIDIA/DALI/blob/c7539676a24a8e9e99a6e8665e277363c5445259/dali/operators/video/frames_decoder_gpu.cc#L1
12+
//
13+
// NVDEC / NVCUVID docs:
14+
// https://docs.nvidia.com/video-technologies/video-codec-sdk/13.0/nvdec-video-decoder-api-prog-guide/index.html#using-nvidia-video-decoder-nvdecode-api
15+
16+
#pragma once
17+
18+
#include "src/torchcodec/_core/Cache.h"
19+
#include "src/torchcodec/_core/DeviceInterface.h"
20+
#include "src/torchcodec/_core/FFMPEGCommon.h"
21+
#include "src/torchcodec/_core/NVDECCache.h"
22+
23+
#include <map>
24+
#include <memory>
25+
#include <mutex>
26+
#include <queue>
27+
#include <unordered_map>
28+
#include <vector>
29+
30+
#include "src/torchcodec/_core/nvcuvid_include/cuviddec.h"
31+
#include "src/torchcodec/_core/nvcuvid_include/nvcuvid.h"
32+
33+
namespace facebook::torchcodec {
34+
35+
class BetaCudaDeviceInterface : public DeviceInterface {
36+
public:
37+
explicit BetaCudaDeviceInterface(const torch::Device& device);
38+
virtual ~BetaCudaDeviceInterface();
39+
40+
void initializeInterface(AVStream* stream) override;
41+
42+
void convertAVFrameToFrameOutput(
43+
const VideoStreamOptions& videoStreamOptions,
44+
const AVRational& timeBase,
45+
UniqueAVFrame& avFrame,
46+
FrameOutput& frameOutput,
47+
std::optional<torch::Tensor> preAllocatedOutputTensor =
48+
std::nullopt) override;
49+
50+
bool canDecodePacketDirectly() const override {
51+
return true;
52+
}
53+
54+
int sendPacket(ReferenceAVPacket& packet) override;
55+
int receiveFrame(UniqueAVFrame& avFrame, int64_t desiredPts) override;
56+
void flush() override;
57+
58+
// NVDEC callback functions (must be public for C callbacks)
59+
int streamPropertyChange(CUVIDEOFORMAT* videoFormat);
60+
int frameReadyForDecoding(CUVIDPICPARAMS* pPicParams);
61+
62+
private:
63+
// Apply bitstream filter, modifies packet in-place
64+
void applyBSF(ReferenceAVPacket& packet);
65+
66+
class FrameBuffer {
67+
public:
68+
struct Slot {
69+
CUVIDPARSERDISPINFO dispInfo;
70+
int64_t guessedPts;
71+
bool occupied = false;
72+
73+
Slot() : guessedPts(-1), occupied(false) {
74+
std::memset(&dispInfo, 0, sizeof(dispInfo));
75+
}
76+
};
77+
78+
// TODONVDEC P1: init size should probably be min_num_decode_surfaces from
79+
// video format
80+
FrameBuffer() : frameBuffer_(4) {}
81+
82+
~FrameBuffer() = default;
83+
84+
Slot* findEmptySlot();
85+
Slot* findFrameWithExactPts(int64_t desiredPts);
86+
87+
// Iterator support for range-based for loops
88+
auto begin() {
89+
return frameBuffer_.begin();
90+
}
91+
92+
auto end() {
93+
return frameBuffer_.end();
94+
}
95+
96+
private:
97+
std::vector<Slot> frameBuffer_;
98+
};
99+
100+
UniqueAVFrame convertCudaFrameToAVFrame(
101+
CUdeviceptr framePtr,
102+
unsigned int pitch,
103+
const CUVIDPARSERDISPINFO& dispInfo);
104+
105+
CUvideoparser videoParser_ = nullptr;
106+
UniqueCUvideodecoder decoder_;
107+
CUVIDEOFORMAT videoFormat_ = {};
108+
109+
FrameBuffer frameBuffer_;
110+
111+
std::queue<int64_t> packetsPtsQueue;
112+
113+
bool eofSent_ = false;
114+
115+
// Flush flag to prevent decode operations during flush (like DALI's
116+
// isFlushing_)
117+
bool isFlushing_ = false;
118+
119+
AVRational timeBase_ = {0, 0};
120+
121+
UniqueAVBSFContext bitstreamFilter_;
122+
123+
// Default CUDA interface for color conversion.
124+
// TODONVDEC P2: we shouldn't need to keep a separate instance of the default.
125+
// See other TODO there about how interfaces should be completely independent.
126+
std::unique_ptr<DeviceInterface> defaultCudaInterface_;
127+
};
128+
129+
} // namespace facebook::torchcodec

src/torchcodec/_core/CMakeLists.txt

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ function(make_torchcodec_libraries
9898
)
9999

100100
if(ENABLE_CUDA)
101-
list(APPEND core_sources CudaDeviceInterface.cpp)
101+
list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp)
102102
endif()
103103

104104
set(core_library_dependencies
@@ -107,9 +107,27 @@ function(make_torchcodec_libraries
107107
)
108108

109109
if(ENABLE_CUDA)
110+
# Try to find NVCUVID. Try the normal way first. This should work locally.
find_library(NVCUVID_LIBRARY NAMES nvcuvid)
# If not found, try with version suffix, or hardcoded path. Appears
# to be necessary on the CI.
if(NOT NVCUVID_LIBRARY)
  find_library(NVCUVID_LIBRARY NAMES nvcuvid.1 PATHS /usr/lib64 /usr/lib)
endif()
# Last resort: the hardcoded path, but only when the file is really there.
# Setting it unconditionally would make the FATAL_ERROR below unreachable and
# turn a missing library into a late link failure.
if(NOT NVCUVID_LIBRARY AND EXISTS "/usr/lib64/libnvcuvid.so.1")
  set(NVCUVID_LIBRARY "/usr/lib64/libnvcuvid.so.1")
endif()

if(NVCUVID_LIBRARY)
  message(STATUS "Found NVCUVID: ${NVCUVID_LIBRARY}")
else()
  message(FATAL_ERROR "Could not find NVCUVID library")
endif()
126+
110127
list(APPEND core_library_dependencies
111128
${CUDA_nppi_LIBRARY}
112129
${CUDA_nppicc_LIBRARY}
130+
${NVCUVID_LIBRARY}
113131
)
114132
endif()
115133

src/torchcodec/_core/CpuDeviceInterface.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace facebook::torchcodec {
1010
namespace {
1111

1212
static bool g_cpu = registerDeviceInterface(
13-
torch::kCPU,
13+
DeviceInterfaceKey(torch::kCPU),
1414
[](const torch::Device& device) { return new CpuDeviceInterface(device); });
1515

1616
} // namespace

src/torchcodec/_core/CudaDeviceInterface.cpp

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,21 @@ extern "C" {
1313
#include <libavutil/pixdesc.h>
1414
}
1515

16+
// TODONVDEC P1 Changes were made to this file to accommodate the BETA CUDA
17+
// interface (see other TODONVDEC below). That's because the BETA CUDA interface
18+
// relies on this default CUDA interface to do the color conversion. That's
19+
// hacky, ugly, and leads to complicated code. We should refactor all this so
20+
// that an interface doesn't need to know anything about any other interface.
21+
// Note - this is more than just about the BETA CUDA interface: this default
22+
// interface already relies on the CPU interface to do software decoding when
23+
// needed, and that's already leading to similar complications.
24+
1625
namespace facebook::torchcodec {
1726
namespace {
1827

19-
static bool g_cuda =
20-
registerDeviceInterface(torch::kCUDA, [](const torch::Device& device) {
28+
static bool g_cuda = registerDeviceInterface(
29+
DeviceInterfaceKey(torch::kCUDA),
30+
[](const torch::Device& device) {
2131
return new CudaDeviceInterface(device);
2232
});
2333

@@ -216,10 +226,11 @@ std::unique_ptr<FiltersContext> CudaDeviceInterface::initializeFiltersContext(
216226
return nullptr;
217227
}
218228

219-
TORCH_CHECK(
220-
avFrame->hw_frames_ctx != nullptr,
221-
"The AVFrame does not have a hw_frames_ctx. "
222-
"That's unexpected, please report this to the TorchCodec repo.");
229+
if (avFrame->hw_frames_ctx == nullptr) {
230+
// TODONVDEC P2 return early for the beta interface where avFrames don't
231+
// have a hw_frames_ctx. We should get rid of this or improve the logic.
232+
return nullptr;
233+
}
223234

224235
auto hwFramesCtx =
225236
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
@@ -347,22 +358,23 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
347358
// Above we checked that the AVFrame was on GPU, but that's not enough, we
348359
// also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
349360
// because this is what the NPP color conversion routines expect.
350-
TORCH_CHECK(
351-
avFrame->hw_frames_ctx != nullptr,
352-
"The AVFrame does not have a hw_frames_ctx. "
353-
"That's unexpected, please report this to the TorchCodec repo.");
354-
355-
auto hwFramesCtx =
356-
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
357-
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
361+
// TODONVDEC P2 this can be hit from the beta interface, but there's no
362+
// hw_frames_ctx in this case. We should try to understand how that affects
363+
// this validation.
364+
AVHWFramesContext* hwFramesCtx = nullptr;
365+
if (avFrame->hw_frames_ctx != nullptr) {
366+
hwFramesCtx =
367+
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
368+
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
358369

359-
TORCH_CHECK(
360-
actualFormat == AV_PIX_FMT_NV12,
361-
"The AVFrame is ",
362-
(av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat)
363-
: "unknown"),
364-
", but we expected AV_PIX_FMT_NV12. "
365-
"That's unexpected, please report this to the TorchCodec repo.");
370+
TORCH_CHECK(
371+
actualFormat == AV_PIX_FMT_NV12,
372+
"The AVFrame is ",
373+
(av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat)
374+
: "unknown"),
375+
", but we expected AV_PIX_FMT_NV12. "
376+
"That's unexpected, please report this to the TorchCodec repo.");
377+
}
366378

367379
auto frameDims =
368380
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
@@ -396,19 +408,23 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
396408
// arbitrary, but unfortunately we know it's hardcoded to be the default
397409
// stream by FFmpeg:
398410
// https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388
399-
TORCH_CHECK(
400-
hwFramesCtx->device_ctx != nullptr,
401-
"The AVFrame's hw_frames_ctx does not have a device_ctx. ");
402-
auto cudaDeviceCtx =
403-
static_cast<AVCUDADeviceContext*>(hwFramesCtx->device_ctx->hwctx);
404-
at::cuda::CUDAEvent nvdecDoneEvent;
405-
at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
406-
c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex);
407-
nvdecDoneEvent.record(nvdecStream);
408-
409-
// Don't start NPP work before NVDEC is done decoding the frame!
410411
at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream(deviceIndex);
411-
nvdecDoneEvent.block(nppStream);
412+
if (hwFramesCtx) {
413+
// TODONVDEC P2 this block won't be hit from the beta interface because
414+
// there is no hwFramesCtx, but we should still make sure there's no CUDA
415+
// stream sync issue in the beta interface.
416+
TORCH_CHECK(
417+
hwFramesCtx->device_ctx != nullptr,
418+
"The AVFrame's hw_frames_ctx does not have a device_ctx. ");
419+
auto cudaDeviceCtx =
420+
static_cast<AVCUDADeviceContext*>(hwFramesCtx->device_ctx->hwctx);
421+
at::cuda::CUDAEvent nvdecDoneEvent;
422+
at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
423+
c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex);
424+
nvdecDoneEvent.record(nvdecStream);
425+
// Don't start NPP work before NVDEC is done decoding the frame!
426+
nvdecDoneEvent.block(nppStream);
427+
}
412428

413429
// Create the NPP context if we haven't yet.
414430
nppCtx_->hStream = nppStream.stream();

src/torchcodec/_core/DeviceInterface.cpp

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
namespace facebook::torchcodec {
1212

1313
namespace {
14-
using DeviceInterfaceMap = std::map<torch::DeviceType, CreateDeviceInterfaceFn>;
14+
using DeviceInterfaceMap =
15+
std::map<DeviceInterfaceKey, CreateDeviceInterfaceFn>;
1516
static std::mutex g_interface_mutex;
1617

1718
DeviceInterfaceMap& getDeviceMap() {
@@ -30,50 +31,72 @@ std::string getDeviceType(const std::string& device) {
3031
} // namespace
3132

3233
bool registerDeviceInterface(
33-
torch::DeviceType deviceType,
34+
const DeviceInterfaceKey& key,
3435
CreateDeviceInterfaceFn createInterface) {
3536
std::scoped_lock lock(g_interface_mutex);
3637
DeviceInterfaceMap& deviceMap = getDeviceMap();
3738

3839
TORCH_CHECK(
39-
deviceMap.find(deviceType) == deviceMap.end(),
40-
"Device interface already registered for ",
41-
deviceType);
42-
deviceMap.insert({deviceType, createInterface});
40+
deviceMap.find(key) == deviceMap.end(),
41+
"Device interface already registered for device type ",
42+
key.deviceType,
43+
" variant '",
44+
key.variant,
45+
"'");
46+
deviceMap.insert({key, createInterface});
4347

4448
return true;
4549
}
4650

47-
torch::Device createTorchDevice(const std::string device) {
51+
void validateDeviceInterface(
52+
const std::string device,
53+
const std::string variant) {
4854
std::scoped_lock lock(g_interface_mutex);
4955
std::string deviceType = getDeviceType(device);
56+
5057
DeviceInterfaceMap& deviceMap = getDeviceMap();
5158

59+
// Find device interface that matches device type and variant
60+
torch::DeviceType deviceTypeEnum = torch::Device(deviceType).type();
61+
5262
auto deviceInterface = std::find_if(
5363
deviceMap.begin(),
5464
deviceMap.end(),
55-
[&](const std::pair<torch::DeviceType, CreateDeviceInterfaceFn>& arg) {
56-
return device.rfind(
57-
torch::DeviceTypeName(arg.first, /*lcase*/ true), 0) == 0;
65+
[&](const std::pair<DeviceInterfaceKey, CreateDeviceInterfaceFn>& arg) {
66+
return arg.first.deviceType == deviceTypeEnum &&
67+
arg.first.variant == variant;
5868
});
59-
TORCH_CHECK(
60-
deviceInterface != deviceMap.end(), "Unsupported device: ", device);
6169

62-
return torch::Device(device);
70+
TORCH_CHECK(
71+
deviceInterface != deviceMap.end(),
72+
"Unsupported device: ",
73+
device,
74+
" (device type: ",
75+
deviceType,
76+
", variant: ",
77+
variant,
78+
")");
6379
}
6480

6581
std::unique_ptr<DeviceInterface> createDeviceInterface(
66-
const torch::Device& device) {
67-
auto deviceType = device.type();
82+
const torch::Device& device,
83+
const std::string_view variant) {
84+
DeviceInterfaceKey key(device.type(), variant);
6885
std::scoped_lock lock(g_interface_mutex);
6986
DeviceInterfaceMap& deviceMap = getDeviceMap();
7087

71-
TORCH_CHECK(
72-
deviceMap.find(deviceType) != deviceMap.end(),
73-
"Unsupported device: ",
74-
device);
88+
auto it = deviceMap.find(key);
89+
if (it != deviceMap.end()) {
90+
return std::unique_ptr<DeviceInterface>(it->second(device));
91+
}
7592

76-
return std::unique_ptr<DeviceInterface>(deviceMap[deviceType](device));
93+
TORCH_CHECK(
94+
false,
95+
"No device interface found for device type: ",
96+
device.type(),
97+
" variant: '",
98+
variant,
99+
"'");
77100
}
78101

79102
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)