Skip to content

Commit 67b4381

Browse files
Dan-FloresDaniel Flores
andauthored
VideoEncoder first pass, round trip test (#866)
Co-authored-by: Daniel Flores <[email protected]>
1 parent ab10088 commit 67b4381

File tree

7 files changed

+414
-1
lines changed

7 files changed

+414
-1
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,4 +511,283 @@ void AudioEncoder::flushBuffers() {
511511

512512
encodeFrame(autoAVPacket, UniqueAVFrame(nullptr));
513513
}
514+
515+
namespace {
516+
517+
torch::Tensor validateFrames(const torch::Tensor& frames) {
518+
TORCH_CHECK(
519+
frames.dtype() == torch::kUInt8,
520+
"frames must have uint8 dtype, got ",
521+
frames.dtype());
522+
TORCH_CHECK(
523+
frames.dim() == 4,
524+
"frames must have 4 dimensions (N, C, H, W), got ",
525+
frames.dim());
526+
TORCH_CHECK(
527+
frames.sizes()[1] == 3,
528+
"frame must have 3 channels (R, G, B), got ",
529+
frames.sizes()[1]);
530+
// TODO-VideoEncoder: Investigate if non-contiguous frames can be accepted
531+
return frames.contiguous();
532+
}
533+
534+
} // namespace
535+
536+
VideoEncoder::~VideoEncoder() {
537+
if (avFormatContext_ && avFormatContext_->pb) {
538+
avio_flush(avFormatContext_->pb);
539+
avio_close(avFormatContext_->pb);
540+
avFormatContext_->pb = nullptr;
541+
}
542+
}
543+
544+
VideoEncoder::VideoEncoder(
545+
const torch::Tensor& frames,
546+
int frameRate,
547+
std::string_view fileName,
548+
const VideoStreamOptions& videoStreamOptions)
549+
: frames_(validateFrames(frames)), inFrameRate_(frameRate) {
550+
setFFmpegLogLevel();
551+
552+
// Allocate output format context
553+
AVFormatContext* avFormatContext = nullptr;
554+
int status = avformat_alloc_output_context2(
555+
&avFormatContext, nullptr, nullptr, fileName.data());
556+
557+
TORCH_CHECK(
558+
avFormatContext != nullptr,
559+
"Couldn't allocate AVFormatContext. ",
560+
"The destination file is ",
561+
fileName,
562+
", check the desired extension? ",
563+
getFFMPEGErrorStringFromErrorCode(status));
564+
avFormatContext_.reset(avFormatContext);
565+
566+
status = avio_open(&avFormatContext_->pb, fileName.data(), AVIO_FLAG_WRITE);
567+
TORCH_CHECK(
568+
status >= 0,
569+
"avio_open failed. The destination file is ",
570+
fileName,
571+
", make sure it's a valid path? ",
572+
getFFMPEGErrorStringFromErrorCode(status));
573+
// TODO-VideoEncoder: Add tests for above fileName related checks
574+
575+
initializeEncoder(videoStreamOptions);
576+
}
577+
578+
void VideoEncoder::initializeEncoder(
579+
const VideoStreamOptions& videoStreamOptions) {
580+
const AVCodec* avCodec =
581+
avcodec_find_encoder(avFormatContext_->oformat->video_codec);
582+
TORCH_CHECK(avCodec != nullptr, "Video codec not found");
583+
584+
AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
585+
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
586+
avCodecContext_.reset(avCodecContext);
587+
588+
// Set encoding options
589+
// TODO-VideoEncoder: Allow bitrate to be set
590+
std::optional<int> desiredBitRate = videoStreamOptions.bitRate;
591+
if (desiredBitRate.has_value()) {
592+
TORCH_CHECK(
593+
*desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0.");
594+
}
595+
avCodecContext_->bit_rate = desiredBitRate.value_or(0);
596+
597+
// Store dimension order and input pixel format
598+
// TODO-VideoEncoder: Remove assumption that tensor in NCHW format
599+
auto sizes = frames_.sizes();
600+
inPixelFormat_ = AV_PIX_FMT_GBRP;
601+
inHeight_ = sizes[2];
602+
inWidth_ = sizes[3];
603+
604+
// Use specified dimensions or input dimensions
605+
// TODO-VideoEncoder: Allow height and width to be set
606+
outWidth_ = videoStreamOptions.width.value_or(inWidth_);
607+
outHeight_ = videoStreamOptions.height.value_or(inHeight_);
608+
609+
// Use YUV420P as default output format
610+
// TODO-VideoEncoder: Enable other pixel formats
611+
outPixelFormat_ = AV_PIX_FMT_YUV420P;
612+
613+
// Configure codec parameters
614+
avCodecContext_->codec_id = avCodec->id;
615+
avCodecContext_->width = outWidth_;
616+
avCodecContext_->height = outHeight_;
617+
avCodecContext_->pix_fmt = outPixelFormat_;
618+
// TODO-VideoEncoder: Verify that frame_rate and time_base are correct
619+
avCodecContext_->time_base = {1, inFrameRate_};
620+
avCodecContext_->framerate = {inFrameRate_, 1};
621+
622+
// TODO-VideoEncoder: Allow GOP size and max B-frames to be set
623+
if (videoStreamOptions.gopSize.has_value()) {
624+
avCodecContext_->gop_size = *videoStreamOptions.gopSize;
625+
} else {
626+
avCodecContext_->gop_size = 12; // Default GOP size
627+
}
628+
629+
if (videoStreamOptions.maxBFrames.has_value()) {
630+
avCodecContext_->max_b_frames = *videoStreamOptions.maxBFrames;
631+
} else {
632+
avCodecContext_->max_b_frames = 0; // No max B-frames to reduce compression
633+
}
634+
635+
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
636+
TORCH_CHECK(
637+
status == AVSUCCESS,
638+
"avcodec_open2 failed: ",
639+
getFFMPEGErrorStringFromErrorCode(status));
640+
641+
AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr);
642+
TORCH_CHECK(avStream != nullptr, "Couldn't create new stream.");
643+
644+
// Set the stream time base to encode correct frame timestamps
645+
avStream->time_base = avCodecContext_->time_base;
646+
status = avcodec_parameters_from_context(
647+
avStream->codecpar, avCodecContext_.get());
648+
TORCH_CHECK(
649+
status == AVSUCCESS,
650+
"avcodec_parameters_from_context failed: ",
651+
getFFMPEGErrorStringFromErrorCode(status));
652+
streamIndex_ = avStream->index;
653+
}
654+
655+
void VideoEncoder::encode() {
656+
// To be on the safe side we enforce that encode() can only be called once
657+
TORCH_CHECK(!encodeWasCalled_, "Cannot call encode() twice.");
658+
encodeWasCalled_ = true;
659+
660+
int status = avformat_write_header(avFormatContext_.get(), nullptr);
661+
TORCH_CHECK(
662+
status == AVSUCCESS,
663+
"Error in avformat_write_header: ",
664+
getFFMPEGErrorStringFromErrorCode(status));
665+
666+
AutoAVPacket autoAVPacket;
667+
int numFrames = frames_.sizes()[0];
668+
for (int i = 0; i < numFrames; ++i) {
669+
torch::Tensor currFrame = frames_[i];
670+
UniqueAVFrame avFrame = convertTensorToAVFrame(currFrame, i);
671+
encodeFrame(autoAVPacket, avFrame);
672+
}
673+
674+
flushBuffers();
675+
676+
status = av_write_trailer(avFormatContext_.get());
677+
TORCH_CHECK(
678+
status == AVSUCCESS,
679+
"Error in av_write_trailer: ",
680+
getFFMPEGErrorStringFromErrorCode(status));
681+
}
682+
683+
UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
684+
const torch::Tensor& frame,
685+
int frameIndex) {
686+
// Initialize and cache scaling context if it does not exist
687+
if (!swsContext_) {
688+
swsContext_.reset(sws_getContext(
689+
inWidth_,
690+
inHeight_,
691+
inPixelFormat_,
692+
outWidth_,
693+
outHeight_,
694+
outPixelFormat_,
695+
SWS_BILINEAR,
696+
nullptr,
697+
nullptr,
698+
nullptr));
699+
TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context");
700+
}
701+
702+
UniqueAVFrame avFrame(av_frame_alloc());
703+
TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame");
704+
705+
// Set output frame properties
706+
avFrame->format = outPixelFormat_;
707+
avFrame->width = outWidth_;
708+
avFrame->height = outHeight_;
709+
avFrame->pts = frameIndex;
710+
711+
int status = av_frame_get_buffer(avFrame.get(), 0);
712+
TORCH_CHECK(status >= 0, "Failed to allocate frame buffer");
713+
714+
// Need to convert/scale the frame
715+
// Create temporary frame with input format
716+
UniqueAVFrame inputFrame(av_frame_alloc());
717+
TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame");
718+
719+
inputFrame->format = inPixelFormat_;
720+
inputFrame->width = inWidth_;
721+
inputFrame->height = inHeight_;
722+
723+
uint8_t* tensorData = static_cast<uint8_t*>(frame.data_ptr());
724+
725+
// TODO-VideoEncoder: Reorder tensor if in NHWC format
726+
int channelSize = inHeight_ * inWidth_;
727+
// Reorder RGB -> GBR for AV_PIX_FMT_GBRP format
728+
// TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format
729+
inputFrame->data[0] = tensorData + channelSize;
730+
inputFrame->data[1] = tensorData + (2 * channelSize);
731+
inputFrame->data[2] = tensorData;
732+
733+
inputFrame->linesize[0] = inWidth_;
734+
inputFrame->linesize[1] = inWidth_;
735+
inputFrame->linesize[2] = inWidth_;
736+
737+
status = sws_scale(
738+
swsContext_.get(),
739+
inputFrame->data,
740+
inputFrame->linesize,
741+
0,
742+
inputFrame->height,
743+
avFrame->data,
744+
avFrame->linesize);
745+
TORCH_CHECK(status == outHeight_, "sws_scale failed");
746+
return avFrame;
747+
}
748+
749+
void VideoEncoder::encodeFrame(
750+
AutoAVPacket& autoAVPacket,
751+
const UniqueAVFrame& avFrame) {
752+
auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
753+
TORCH_CHECK(
754+
status == AVSUCCESS,
755+
"Error while sending frame: ",
756+
getFFMPEGErrorStringFromErrorCode(status));
757+
758+
while (true) {
759+
ReferenceAVPacket packet(autoAVPacket);
760+
status = avcodec_receive_packet(avCodecContext_.get(), packet.get());
761+
if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) {
762+
if (status == AVERROR_EOF) {
763+
// Flush remaining buffered packets
764+
status = av_interleaved_write_frame(avFormatContext_.get(), nullptr);
765+
TORCH_CHECK(
766+
status == AVSUCCESS,
767+
"Failed to flush packet: ",
768+
getFFMPEGErrorStringFromErrorCode(status));
769+
}
770+
return;
771+
}
772+
TORCH_CHECK(
773+
status >= 0,
774+
"Error receiving packet: ",
775+
getFFMPEGErrorStringFromErrorCode(status));
776+
777+
packet->stream_index = streamIndex_;
778+
779+
status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
780+
TORCH_CHECK(
781+
status == AVSUCCESS,
782+
"Error in av_interleaved_write_frame: ",
783+
getFFMPEGErrorStringFromErrorCode(status));
784+
}
785+
}
786+
787+
void VideoEncoder::flushBuffers() {
788+
AutoAVPacket autoAVPacket;
789+
// Send null frame to signal end of input
790+
encodeFrame(autoAVPacket, UniqueAVFrame(nullptr));
791+
}
792+
514793
} // namespace facebook::torchcodec

src/torchcodec/_core/Encoder.h

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ class AudioEncoder {
5757
bool encodeWasCalled_ = false;
5858
int64_t lastEncodedAVFramePts_ = 0;
5959
};
60-
} // namespace facebook::torchcodec
6160

6261
/* clang-format off */
6362
//
@@ -121,3 +120,44 @@ class AudioEncoder {
121120
//
122121
//
123122
/* clang-format on */
123+
124+
class VideoEncoder {
125+
public:
126+
~VideoEncoder();
127+
128+
VideoEncoder(
129+
const torch::Tensor& frames,
130+
int frameRate,
131+
std::string_view fileName,
132+
const VideoStreamOptions& videoStreamOptions);
133+
134+
void encode();
135+
136+
private:
137+
void initializeEncoder(const VideoStreamOptions& videoStreamOptions);
138+
UniqueAVFrame convertTensorToAVFrame(
139+
const torch::Tensor& frame,
140+
int frameIndex);
141+
void encodeFrame(AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame);
142+
void flushBuffers();
143+
144+
UniqueEncodingAVFormatContext avFormatContext_;
145+
UniqueAVCodecContext avCodecContext_;
146+
int streamIndex_;
147+
UniqueSwsContext swsContext_;
148+
149+
const torch::Tensor frames_;
150+
int inFrameRate_;
151+
152+
int inWidth_ = -1;
153+
int inHeight_ = -1;
154+
AVPixelFormat inPixelFormat_ = AV_PIX_FMT_NONE;
155+
156+
int outWidth_ = -1;
157+
int outHeight_ = -1;
158+
AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE;
159+
160+
bool encodeWasCalled_ = false;
161+
};
162+
163+
} // namespace facebook::torchcodec

src/torchcodec/_core/StreamOptions.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ struct VideoStreamOptions {
3838
std::optional<ColorConversionLibrary> colorConversionLibrary;
3939
// By default we use CPU for decoding for both C++ and python users.
4040
torch::Device device = torch::kCPU;
41+
42+
// Encoding options
43+
std::optional<int> bitRate;
44+
std::optional<int> gopSize;
45+
std::optional<int> maxBFrames;
4146
};
4247

4348
struct AudioStreamOptions {

src/torchcodec/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
encode_audio_to_file,
2626
encode_audio_to_file_like,
2727
encode_audio_to_tensor,
28+
encode_video_to_file,
2829
get_ffmpeg_library_versions,
2930
get_frame_at_index,
3031
get_frame_at_pts,

src/torchcodec/_core/custom_ops.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3131
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
3232
m.def(
3333
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
34+
m.def(
35+
"encode_video_to_file(Tensor frames, int frame_rate, str filename) -> ()");
3436
m.def(
3537
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
3638
m.def(
@@ -397,6 +399,19 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
397399
return makeOpsAudioFramesOutput(result);
398400
}
399401

402+
void encode_video_to_file(
403+
const at::Tensor& frames,
404+
int64_t frame_rate,
405+
std::string_view file_name) {
406+
VideoStreamOptions videoStreamOptions;
407+
VideoEncoder(
408+
frames,
409+
validateInt64ToInt(frame_rate, "frame_rate"),
410+
file_name,
411+
videoStreamOptions)
412+
.encode();
413+
}
414+
400415
void encode_audio_to_file(
401416
const at::Tensor& samples,
402417
int64_t sample_rate,
@@ -701,6 +716,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
701716

702717
TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
703718
m.impl("encode_audio_to_file", &encode_audio_to_file);
719+
m.impl("encode_video_to_file", &encode_video_to_file);
704720
m.impl("encode_audio_to_tensor", &encode_audio_to_tensor);
705721
m.impl("seek_to_pts", &seek_to_pts);
706722
m.impl("add_video_stream", &add_video_stream);

0 commit comments

Comments
 (0)