Conversation

@Dan-Flores (Contributor) commented Oct 29, 2025

This PR adds CUDA support to the VideoEncoder.

  • Video encoding on a CUDA device is now contained within GpuEncoder.cpp
  • TODO: Write detailed description of changes

@meta-cla bot added the CLA Signed label Oct 29, 2025
@Dan-Flores Dan-Flores changed the title video encoder python file video encoder CUDA support Oct 29, 2025
@Dan-Flores Dan-Flores force-pushed the encode_gpu branch 2 times, most recently from f1678b7 to 997158f November 3, 2025 20:56
@Dan-Flores Dan-Flores marked this pull request as ready for review November 3, 2025 21:07
@Dan-Flores Dan-Flores changed the title video encoder CUDA support Enable CUDA device for video encoder Nov 3, 2025
@Dan-Flores Dan-Flores marked this pull request as draft November 3, 2025 21:35
@Dan-Flores Dan-Flores changed the title Enable CUDA device for video encoder [wip] Enable CUDA device for video encoder Nov 6, 2025
@Dan-Flores Dan-Flores changed the title [wip] Enable CUDA device for video encoder Enable CUDA device for video encoder Nov 26, 2025

void VideoEncoder::initializeEncoder(
const VideoStreamOptions& videoStreamOptions) {
if (videoStreamOptions.device.is_cuda()) {
Contributor

I was hoping we wouldn't need to support a device parameter anywhere. Any reason we can't just rely on the input frames' device?

Contributor Author

Per our offline discussion, I've updated this PR to drop the explicit device param and instead use whichever device the frames Tensor is on.
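
A minimal sketch of the resulting dispatch (the frames_ member and the GpuEncoder constructor here are illustrative assumptions, not the PR's exact code):

// Sketch: the frames tensor is the single source of truth for the device.
// If the frames live on CUDA, set up the GPU encoding path; otherwise keep
// the existing CPU path. No separate device parameter is consulted.
void VideoEncoder::initializeEncoder(
    const VideoStreamOptions& videoStreamOptions) {
  if (frames_.is_cuda()) {
    gpuEncoder_ = std::make_unique<GpuEncoder>(frames_.device());
  }
  // ... remaining CPU-side setup unchanged ...
}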

Comment on lines 541 to 549
if (device.type() != torch::kCPU) {
TORCH_CHECK(
frames.is_cuda(),
"When using CUDA encoding (device=",
device.str(),
"), frames must be on a CUDA device. Got frames on ",
frames.device().str(),
". Please move frames to a CUDA device: frames.to('cuda')");
}
Contributor

The error above should be an internal bug: if the frames are on CUDA while the device parameter is not, it means it wasn't set properly in the custom ops. So this should be an "internal bug" error message rather than a user-facing thing, since we assume users aren't using the C++ APIs.

But I think we should push the logic from https://github.com/meta-pytorch/torchcodec/pull/1008/files#r2565059217 further: we don't need a device parameter at all, anywhere. Having a device parameter duplicates the source of truth of the device and leads to potential bugs (like the one above that the TORCH_CHECK is preventing). So I think we shouldn't set options.device at all for encoding, and should always rely on the frames. We can add a comment in src/torchcodec/_core/StreamOptions.h to indicate that.
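
Something like this comment could capture that in src/torchcodec/_core/StreamOptions.h (wording is illustrative):

// For encoding, `device` is deliberately never set. The device is always
// inferred from the input frames tensor, which is the single source of
// truth; storing it here as well invites exactly the device/frames
// mismatches the encoder would otherwise have to TORCH_CHECK against.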

std::optional<std::string_view> preset = std::nullopt,
std::optional<std::vector<std::string>> extra_options = std::nullopt) {
VideoStreamOptions videoStreamOptions;
videoStreamOptions.device = frames.device();
Contributor

I think we don't need to do that; it duplicates the source of truth of the device. See the longer comment below.

Comment on lines 768 to 772
avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec);
if (gpuEncoder_) {
avCodec = gpuEncoder_->findEncoder(avFormatContext_->oformat->video_codec)
.value_or(avCodec);
}
Contributor

Is there actual value in having findEncoder? Did you notice any problem if we didn't use it? If not, then let's remove it.

If yes, then let's rename it to findCodec and add a comment similar to this one

// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
// we have to do this because of an FFmpeg bug where hardware decoding is not
// appropriately set, so we just go off and find the matching codec for the CUDA
// device
std::optional<const AVCodec*> CudaDeviceInterface::findCodec(

and also indicate that this findCodec function exists for reasons similar to the one I linked to.
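
For reference, a hedged sketch of what the encoder-side counterpart could look like, mirroring the decoder-side workaround (the GpuEncoder::findCodec signature is an assumption):

// Sketch: iterate all registered codecs and pick an encoder for this codec
// ID that advertises a CUDA hardware config.
std::optional<const AVCodec*> GpuEncoder::findCodec(const AVCodecID& codecId) {
  void* iterationState = nullptr;
  const AVCodec* codec = nullptr;
  while ((codec = av_codec_iterate(&iterationState)) != nullptr) {
    if (codec->id != codecId || !av_codec_is_encoder(codec)) {
      continue;
    }
    const AVCodecHWConfig* config = nullptr;
    for (int i = 0; (config = avcodec_get_hw_config(codec, i)) != nullptr; ++i) {
      if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
        return codec;
      }
    }
  }
  return std::nullopt;
}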

Contributor Author

Let's delete it. The intention was to find a hardware-enabled encoder if one was not specified, but as far as I can tell, the FFmpeg CLI does not support that for encoding, only for decoding via the -hwaccel flag.

avFrame->height = static_cast<int>(tensor.size(1));
avFrame->pts = frameIndex;

int ret = av_hwframe_get_buffer(
Contributor

Add a note here that we're letting FFmpeg allocate the CUDA memory. I think we should explore allocating the memory with pytorch instead, so that we can automatically rely on pytorch's CUDA memory allocator, which should be more efficient. There could be a TODO to investigate how to do that (this is related to my comment about setupEncodingContext above).
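
For context, the FFmpeg-owned allocation looks roughly like this (a sketch; the hwFramesCtx name is an assumption):

// Sketch: av_hwframe_get_buffer() hands out a CUDA buffer from the pool
// owned by the AVHWFramesContext, so this memory never goes through
// PyTorch's caching CUDA allocator.
AVFrame* avFrame = av_frame_alloc();
avFrame->format = AV_PIX_FMT_CUDA;
avFrame->width = static_cast<int>(tensor.size(2));
avFrame->height = static_cast<int>(tensor.size(1));
int ret = av_hwframe_get_buffer(hwFramesCtx, avFrame, /*flags=*/0);
TORCH_CHECK(ret >= 0, "av_hwframe_get_buffer failed");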

Contributor Author

Would an example of this be allocateEmptyHWCTensor?
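
Presumably something along those lines; a sketch of the pattern (the decoder-side allocateEmptyHWCTensor signature may differ):

// Sketch: allocate HWC uint8 memory on the target CUDA device through
// libtorch, so PyTorch's caching allocator owns it. Wiring the tensor's
// data_ptr() and strides into an AVFrame's data/linesize is the open part.
torch::Tensor allocateEmptyHWCTensor(int height, int width, const torch::Device& device) {
  return torch::empty(
      {height, width, 3},
      torch::TensorOptions().dtype(torch::kUInt8).device(device));
}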

VideoEncoder(frames, frame_rate=30).to_file(dest=dest, **common_params)
with open(dest, "rb") as f:
-    return torch.frombuffer(f.read(), dtype=torch.uint8)
+    return torch.frombuffer(f.read(), dtype=torch.uint8).clone()
Contributor

why was that needed?

Contributor Author

This was an attempt to fix a warning on this test, but since it doesn't actually prevent the warning, I'll clean it up.


test/test_encoders.py::TestVideoEncoder::test_contiguity[cuda-to_file]
  /home/dev/torchcodec/test/test_encoders.py:835: UserWarning: The given buffer is not writable, and PyTorch does not support non-writable tensors. 
This means you can write to the underlying (supposedly non-writable) buffer using the tensor. 
You may want to copy the buffer to protect its data or make it writable before converting it to a tensor. 
This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_new.cpp:1581.)
    return torch.frombuffer(f.read(), dtype=torch.uint8)

Comment on lines +1328 to +1329
if b"No NVENC capable devices found" in e.stderr:
pytest.skip("NVENC not available on this system")
Contributor

Add a TODO to make sure our CI never ever skips those tests. I.e. we should have a mechanism in place that makes sure our CI fails here, instead of skipping these tests. This should be the first follow-up of this PR.


@pytest.mark.needs_cuda
@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")
@pytest.mark.parametrize("pixel_format", ("nv12", "yuv420p"))
Contributor

I'm surprised to see this. I thought nvenc only supports NV12 output. Is that not the case?

Contributor Author

Thanks for pointing this out; I took another look at my code:

  • h264_nvenc supports multiple output pixel formats: yuv420p nv12 p010le yuv444p p016le nv16 ...

  • The bottleneck is that GpuEncoder::convertTensorToAVFrame always uses nppiRGBToNV12_8u_ColorTwist32f_C3P2R_Ctx, which only handles nv12.

  • Since nv12 and yuv420p use the same chroma subsampling, the results appeared to be correct.

I'll add a TODO to honor the user's selected pixel format. There are other NPP functions we can use depending on the target pixel format, or I could investigate using filtergraph's scale_cuda to handle the conversion, as is done in maybeConvertAVFrameToNV12OrRGB24. A rough shape for that dispatch is sketched below.
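
A hedged sketch of that dispatch (whether nppiRGBToYUV420_8u_C3P3R matches the encoder's expected color conversion is an assumption to verify):

// Sketch: pick the conversion based on the encoder's pixel format instead
// of hard-coding NV12.
switch (avCodecContext->pix_fmt) {
  case AV_PIX_FMT_NV12:
    // Existing path: semi-planar, interleaved UV (2 planes).
    // nppiRGBToNV12_8u_ColorTwist32f_C3P2R_Ctx(...)
    break;
  case AV_PIX_FMT_YUV420P:
    // Fully planar Y, U, V (3 planes), e.g. nppiRGBToYUV420_8u_C3P3R(...),
    // or route through a scale_cuda filtergraph.
    break;
  default:
    TORCH_CHECK(false, "Unsupported pixel format for CUDA encoding");
}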
