Skip to content

Commit 1e5b5b2

Browse files
authored
Add CUDA decoding support (#242)
1 parent e2ed57c commit 1e5b5b2

File tree

9 files changed

+398
-19
lines changed

9 files changed

+398
-19
lines changed
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import argparse
2+
import os
3+
import pathlib
4+
import time
5+
from concurrent.futures import ThreadPoolExecutor
6+
7+
import torch
8+
9+
import torch.utils.benchmark as benchmark
10+
11+
import torchcodec
12+
import torchvision.transforms.v2.functional as F
13+
14+
RESIZED_WIDTH = 256
15+
RESIZED_HEIGHT = 256
16+
17+
18+
def transfer_and_resize_frame(frame, resize_device_string):
    """Move a decoded frame to the resize device and resize it there.

    Args:
        frame: Decoded frame; must support ``.to(device)`` and be accepted by
            torchvision's ``F.resize``.
        resize_device_string: Device string (e.g. ``"cpu"`` or ``"cuda:0"``)
            on which the resize is performed.

    Returns:
        The frame resized to ``(RESIZED_HEIGHT, RESIZED_WIDTH)``.
    """
    # Expected to be a no-op when the frame already lives on the target device.
    frame_on_device = frame.to(resize_device_string)
    return F.resize(frame_on_device, (RESIZED_HEIGHT, RESIZED_WIDTH))
23+
24+
25+
def decode_full_video(video_path, decode_device_string, resize_device_string):
    """Decode every frame of ``video_path`` once and report the throughput.

    Args:
        video_path: Path of the video file to decode.
        decode_device_string: Device handed to the decoder (e.g. ``"cpu"`` or
            ``"cuda:0"``). Any string containing ``"cuda"`` makes the video
            stream use ``num_threads=1``.
        resize_device_string: ``"none"`` skips resizing; a string containing
            ``"native"`` asks the decoder to emit frames of
            ``RESIZED_WIDTH`` x ``RESIZED_HEIGHT`` directly; any other value is
            used as the device on which each decoded frame is resized.

    Returns:
        ``(frame_count, elapsed_seconds)`` for the decode loop. Creating the
        decoder and adding the stream are not part of the measured time.
    """
    # We use the core API instead of SimpleVideoDecoder because the core API
    # allows us to natively resize as part of the decode step.
    print(f"{decode_device_string=} {resize_device_string=}")
    core = torchcodec.decoders._core
    decoder = core.create_from_file(video_path)

    native_resize = "native" in resize_device_string
    core._add_video_stream(
        decoder,
        stream_index=-1,
        device=decode_device_string,
        num_threads=1 if "cuda" in decode_device_string else None,
        width=RESIZED_WIDTH if native_resize else None,
        height=RESIZED_HEIGHT if native_resize else None,
    )
    resize_after_decode = resize_device_string != "none" and not native_resize

    # perf_counter() is monotonic and high resolution, unlike time.time().
    start_time = time.perf_counter()
    frame_count = 0
    while True:
        try:
            frame, *_ = core.get_next_frame(decoder)
            if resize_after_decode:
                frame = transfer_and_resize_frame(frame, resize_device_string)

            frame_count += 1
        except Exception as e:
            # An exception is the only way out of this loop -- presumably the
            # decoder raises at end of stream (TODO confirm). Any other decode
            # or resize failure also ends up here and is only printed.
            print("EXCEPTION", e)
            break

    elapsed = time.perf_counter() - start_time
    # Guard the division: the loop can end before the clock advances.
    if elapsed > 0:
        fps = frame_count / elapsed
    else:
        fps = float("inf") if frame_count else 0.0
    print(
        f"****** DECODED full video {decode_device_string=} {frame_count=} {elapsed=} {fps=}"
    )
    return frame_count, elapsed
67+
68+
69+
def decode_videos_using_threads(
    video_path,
    decode_device_string,
    resize_device_string,
    num_videos,
    num_threads,
    use_multiple_gpus,
):
    """Decode ``video_path`` ``num_videos`` times on a pool of worker threads.

    Args:
        video_path: Path of the video file every task decodes.
        decode_device_string: Decode device for each task. When it contains
            ``"cuda"`` and ``use_multiple_gpus`` is true, task ``i`` decodes on
            ``cuda:{i % torch.cuda.device_count()}`` instead.
        resize_device_string: Forwarded unchanged to ``decode_full_video``.
        num_videos: Number of decode tasks to submit.
        num_threads: Maximum number of worker threads in the pool.
        use_multiple_gpus: Spread CUDA decode tasks round-robin over all
            visible GPUs.

    Raises:
        RuntimeError: If GPU spreading is requested but no CUDA device is
            visible.
        Exception: Whatever a decode task raised; worker failures are
            propagated instead of being dropped with the future.
    """
    spread_over_gpus = "cuda" in decode_device_string and use_multiple_gpus
    gpu_count = torch.cuda.device_count() if spread_over_gpus else 0
    if spread_over_gpus and gpu_count == 0:
        raise RuntimeError(
            "use_multiple_gpus was requested for a CUDA decode device, but "
            "torch.cuda.device_count() is 0"
        )

    # The context manager waits for the workers and shuts the pool down even
    # when submitting or collecting a task fails.
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        futures = []
        for i in range(num_videos):
            actual_decode_device = decode_device_string
            if spread_over_gpus:
                actual_decode_device = f"cuda:{i % gpu_count}"
            futures.append(
                executor.submit(
                    decode_full_video,
                    video_path,
                    actual_decode_device,
                    resize_device_string,
                )
            )
        for future in futures:
            # result() re-raises any exception from the worker, so a failed
            # decode cannot be mistaken for a (fast) successful run.
            future.result()
86+
87+
88+
def main():
    """Command-line entry point of the decode+resize benchmark.

    Parses the options, then either

    * runs ``decode_full_video`` once per decode-device / resize-device pair
      without warmup (``--no-use_torch_benchmark``), or
    * times every pair with ``torch.utils.benchmark`` and prints a comparison
      table; with ``--num_threads`` > 1 each measurement decodes
      ``--num_videos`` videos through ``decode_videos_using_threads``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--devices",
        default="cuda:0,cpu",
        type=str,
        help="Comma-separated devices to test decoding on.",
    )
    parser.add_argument(
        "--resize_devices",
        default="cuda:0,cpu,native,none",
        type=str,
        help="Comma-separated devices to test preproc (resize) on. Use 'none' to specify no resize.",
    )
    parser.add_argument(
        "--video",
        type=str,
        default=str(
            pathlib.Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
        ),
    )
    parser.add_argument(
        "--use_torch_benchmark",
        action=argparse.BooleanOptionalAction,
        default=True,
        help=(
            "Use pytorch benchmark to measure decode time with warmup and "
            "autorange. Without this we just run one iteration without warmup "
            "to measure the cold start time."
        ),
    )
    parser.add_argument(
        "--num_threads",
        type=int,
        default=1,
        help="Number of threads to use for decoding. Only used when --use_torch_benchmark is set.",
    )
    parser.add_argument(
        "--num_videos",
        type=int,
        default=50,
        help="Number of videos to decode in parallel. Only used when --num_threads is set.",
    )
    parser.add_argument(
        "--use_multiple_gpus",
        action=argparse.BooleanOptionalAction,
        default=True,
        help=("Use multiple GPUs to decode multiple videos in multi-threaded mode."),
    )
    args = parser.parse_args()
    video_path = args.video

    # Both modes need the resize devices, so parse them before branching.
    # Empty entries are dropped; an empty list means "no resize".
    resize_devices = [d for d in args.resize_devices.split(",") if d != ""]
    if len(resize_devices) == 0:
        resize_devices.append("none")

    if not args.use_torch_benchmark:
        for device in args.devices.split(","):
            print("Testing on", device)
            for resize_device in resize_devices:
                # decode_full_video takes three arguments; the resize device
                # must be passed explicitly.
                decode_full_video(video_path, device, resize_device)
        return

    label = "Decode+Resize Time"

    results = []
    for decode_device_string in args.devices.split(","):
        for resize_device_string in resize_devices:
            # Shorten "cuda:0" to "cuda" in the table labels.
            decode_label = (
                "cuda" if "cuda" in decode_device_string else decode_device_string
            )
            resize_label = (
                "cuda" if "cuda" in resize_device_string else resize_device_string
            )
            print("decode_device", decode_device_string)
            print("resize_device", resize_device_string)
            if args.num_threads > 1:
                t = benchmark.Timer(
                    stmt="decode_videos_using_threads(video_path, decode_device_string, resize_device_string, num_videos, num_threads, use_multiple_gpus)",
                    globals={
                        "decode_device_string": decode_device_string,
                        "video_path": video_path,
                        "decode_full_video": decode_full_video,
                        "decode_videos_using_threads": decode_videos_using_threads,
                        "resize_device_string": resize_device_string,
                        "num_videos": args.num_videos,
                        "num_threads": args.num_threads,
                        "use_multiple_gpus": args.use_multiple_gpus,
                    },
                    label=label,
                    description=f"threads={args.num_threads} work={args.num_videos} video={os.path.basename(video_path)}",
                    sub_label=f"D={decode_label} R={resize_label} T={args.num_threads} W={args.num_videos}",
                ).blocked_autorange()
                results.append(t)
            else:
                t = benchmark.Timer(
                    stmt="decode_full_video(video_path, decode_device_string, resize_device_string)",
                    globals={
                        "decode_device_string": decode_device_string,
                        "video_path": video_path,
                        "decode_full_video": decode_full_video,
                        "resize_device_string": resize_device_string,
                    },
                    label=label,
                    description=f"video={os.path.basename(video_path)}",
                    sub_label=f"D={decode_label} R={resize_label}",
                ).blocked_autorange()
                results.append(t)
    compare = benchmark.Compare(results)
    compare.print()
    print("Key: D=Decode, R=Resize T=threads W=work (number of videos to decode)")
    print("Native resize is done as part of the decode step")
    print("none resize means there is no resize step -- native or otherwise")
205+
if __name__ == "__main__":
206+
main()

src/torchcodec/decoders/_core/CMakeLists.txt

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,16 @@ function(make_torchcodec_library library_name ffmpeg_target)
3434
${Python3_INCLUDE_DIRS}
3535
)
3636

37+
set(NEEDED_LIBRARIES ${ffmpeg_target} ${TORCH_LIBRARIES}
38+
${Python3_LIBRARIES})
39+
if(ENABLE_CUDA)
40+
list(APPEND NEEDED_LIBRARIES ${CUDA_CUDA_LIBRARY}
41+
${CUDA_nppi_LIBRARY} ${CUDA_nppicc_LIBRARY} )
42+
endif()
3743
target_link_libraries(
3844
${library_name}
3945
PUBLIC
40-
${ffmpeg_target}
41-
${TORCH_LIBRARIES}
42-
${Python3_LIBRARIES}
46+
${NEEDED_LIBRARIES}
4347
)
4448

4549
# We already set the library_name to be libtorchcodecN, so we don't want

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,31 @@
11
#include <torch/types.h>
2+
#include "src/torchcodec/decoders/_core/DeviceInterface.h"
23

34
namespace facebook::torchcodec {
45

56
// This file is linked with the CPU-only version of torchcodec.
67
// So all functions will throw an error because they should only be called if
78
// the device is not CPU.
89

9-
void throwUnsupportedDeviceError(const torch::Device& device) {
10+
// Always throws. Called with a CPU device it reports a misuse of the device
// functions; called with any other device it reports that the device is
// unsupported in this build.
[[noreturn]] void throwUnsupportedDeviceError(const torch::Device& device) {
  const bool calledForCpu = device.type() == torch::kCPU;
  TORCH_CHECK(
      !calledForCpu,
      "Device functions should only be called if the device is not CPU.");
  TORCH_CHECK(false, "Unsupported device: " + device.str());
}
1516

16-
void initializeDeviceContext(const torch::Device& device) {
17+
void convertAVFrameToDecodedOutputOnCuda(
18+
const torch::Device& device,
19+
const VideoDecoder::VideoStreamDecoderOptions& options,
20+
AVCodecContext* codecContext,
21+
VideoDecoder::RawDecodedOutput& rawOutput,
22+
VideoDecoder::DecodedOutput& output) {
23+
throwUnsupportedDeviceError(device);
24+
}
25+
26+
void initializeContextOnCuda(
27+
const torch::Device& device,
28+
AVCodecContext* codecContext) {
1729
throwUnsupportedDeviceError(device);
1830
}
1931

Lines changed: 108 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,52 @@
1+
#include <ATen/cuda/CUDAEvent.h>
2+
#include <c10/cuda/CUDAStream.h>
3+
#include <npp.h>
14
#include <torch/types.h>
5+
#include "src/torchcodec/decoders/_core/DeviceInterface.h"
6+
#include "src/torchcodec/decoders/_core/FFMPEGCommon.h"
7+
#include "src/torchcodec/decoders/_core/VideoDecoder.h"
8+
9+
extern "C" {
10+
#include <libavcodec/avcodec.h>
11+
#include <libavutil/hwcontext_cuda.h>
12+
#include <libavutil/pixdesc.h>
13+
}
214

315
namespace facebook::torchcodec {
16+
namespace {
17+
// Creates an FFmpeg hardware-device context of type "cuda" for `device` and
// returns the AVBufferRef produced by av_hwdevice_ctx_create.
// Throws (via TORCH_CHECK) if FFmpeg has no "cuda" device type or if creating
// the context fails.
// NOTE(review): the returned reference is presumably owned by the caller --
// confirm against the FFmpeg docs; in this file it is stored in
// AVCodecContext::hw_device_ctx.
AVBufferRef* getCudaContext(const torch::Device& device) {
  const AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda");
  TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");
  // FFMPEG cannot handle negative device indices.
  // For single GPU- machines libtorch returns -1 for the device index. So for
  // that case we set the device index to 0.
  // TODO: Double check if this works for multi-GPU machines correctly.
  const torch::DeviceIndex deviceIndex =
      std::max<at::DeviceIndex>(device.index(), 0);
  const std::string deviceOrdinal = std::to_string(deviceIndex);
  AVBufferRef* hw_device_ctx = nullptr;
  const int err = av_hwdevice_ctx_create(
      &hw_device_ctx, type, deviceOrdinal.c_str(), nullptr, 0);
  // The message arguments are only evaluated when the check fails. The ": "
  // keeps the FFmpeg error text from running into the preceding sentence.
  TORCH_CHECK(
      err >= 0,
      "Failed to create specified HW device: ",
      getFFMPEGErrorStringFromErrorCode(err));
  return hw_device_ctx;
}
38+
39+
// Returns an uninitialized strided tensor of the given shape and dtype
// (uint8 by default) allocated on `device`.
torch::Tensor allocateDeviceTensor(
    at::IntArrayRef shape,
    torch::Device device,
    const torch::Dtype dtype = torch::kUInt8) {
  const auto tensorOptions = torch::TensorOptions()
                                 .dtype(dtype)
                                 .layout(torch::kStrided)
                                 .device(device);
  return torch::empty(shape, tensorOptions);
}
450

551
void throwErrorIfNonCudaDevice(const torch::Device& device) {
652
TORCH_CHECK(
@@ -10,13 +56,70 @@ void throwErrorIfNonCudaDevice(const torch::Device& device) {
1056
throw std::runtime_error("Unsupported device: " + device.str());
1157
}
1258
}
59+
} // namespace
1360

14-
void initializeDeviceContext(const torch::Device& device) {
61+
void initializeContextOnCuda(
62+
const torch::Device& device,
63+
AVCodecContext* codecContext) {
1564
throwErrorIfNonCudaDevice(device);
16-
// TODO: https://github.com/pytorch/torchcodec/issues/238: Implement CUDA
17-
// device.
18-
throw std::runtime_error(
19-
"CUDA device is unimplemented. Follow this issue for tracking progress: https://github.com/pytorch/torchcodec/issues/238");
65+
// It is important for pytorch itself to create the cuda context. If ffmpeg
66+
// creates the context it may not be compatible with pytorch.
67+
// This is a dummy tensor to initialize the cuda context.
68+
torch::Tensor dummyTensorForCudaInitialization = torch::empty(
69+
{1}, torch::TensorOptions().dtype(torch::kUInt8).device(device));
70+
codecContext->hw_device_ctx = getCudaContext(device);
71+
return;
72+
}
73+
74+
void convertAVFrameToDecodedOutputOnCuda(
75+
const torch::Device& device,
76+
const VideoDecoder::VideoStreamDecoderOptions& options,
77+
AVCodecContext* codecContext,
78+
VideoDecoder::RawDecodedOutput& rawOutput,
79+
VideoDecoder::DecodedOutput& output) {
80+
AVFrame* src = rawOutput.frame.get();
81+
82+
TORCH_CHECK(
83+
src->format == AV_PIX_FMT_CUDA,
84+
"Expected format to be AV_PIX_FMT_CUDA, got " +
85+
std::string(av_get_pix_fmt_name((AVPixelFormat)src->format)));
86+
int width = options.width.value_or(codecContext->width);
87+
int height = options.height.value_or(codecContext->height);
88+
NppiSize oSizeROI = {width, height};
89+
Npp8u* input[2] = {src->data[0], src->data[1]};
90+
torch::Tensor& dst = output.frame;
91+
dst = allocateDeviceTensor({height, width, 3}, options.device);
92+
93+
// Use the user-requested GPU for running the NPP kernel.
94+
c10::cuda::CUDAGuard deviceGuard(device);
95+
96+
auto start = std::chrono::high_resolution_clock::now();
97+
98+
NppStatus status = nppiNV12ToRGB_8u_P2C3R(
99+
input,
100+
src->linesize[0],
101+
static_cast<Npp8u*>(dst.data_ptr()),
102+
dst.stride(0),
103+
oSizeROI);
104+
TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame.");
105+
// Make the pytorch stream wait for the npp kernel to finish before using the
106+
// output.
107+
at::cuda::CUDAEvent nppDoneEvent;
108+
at::cuda::CUDAStream nppStreamWrapper =
109+
c10::cuda::getStreamFromExternal(nppGetStream(), device.index());
110+
nppDoneEvent.record(nppStreamWrapper);
111+
nppDoneEvent.block(at::cuda::getCurrentCUDAStream());
112+
113+
auto end = std::chrono::high_resolution_clock::now();
114+
115+
std::chrono::duration<double, std::micro> duration = end - start;
116+
VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width
117+
<< " took: " << duration.count() << "us" << std::endl;
118+
if (options.dimensionOrder == "NCHW") {
119+
// The docs guaranty this to return a view:
120+
// https://pytorch.org/docs/stable/generated/torch.permute.html
121+
dst = dst.permute({2, 0, 1});
122+
}
20123
}
21124

22125
} // namespace facebook::torchcodec

0 commit comments

Comments
 (0)