@@ -511,4 +511,283 @@ void AudioEncoder::flushBuffers() {
511
511
512
512
encodeFrame (autoAVPacket, UniqueAVFrame (nullptr ));
513
513
}
514
+
515
+ namespace {
516
+
517
+ torch::Tensor validateFrames (const torch::Tensor& frames) {
518
+ TORCH_CHECK (
519
+ frames.dtype () == torch::kUInt8 ,
520
+ " frames must have uint8 dtype, got " ,
521
+ frames.dtype ());
522
+ TORCH_CHECK (
523
+ frames.dim () == 4 ,
524
+ " frames must have 4 dimensions (N, C, H, W), got " ,
525
+ frames.dim ());
526
+ TORCH_CHECK (
527
+ frames.sizes ()[1 ] == 3 ,
528
+ " frame must have 3 channels (R, G, B), got " ,
529
+ frames.sizes ()[1 ]);
530
+ // TODO-VideoEncoder: Investigate if non-contiguous frames can be accepted
531
+ return frames.contiguous ();
532
+ }
533
+
534
+ } // namespace
535
+
536
+ VideoEncoder::~VideoEncoder () {
537
+ if (avFormatContext_ && avFormatContext_->pb ) {
538
+ avio_flush (avFormatContext_->pb );
539
+ avio_close (avFormatContext_->pb );
540
+ avFormatContext_->pb = nullptr ;
541
+ }
542
+ }
543
+
544
+ VideoEncoder::VideoEncoder (
545
+ const torch::Tensor& frames,
546
+ int frameRate,
547
+ std::string_view fileName,
548
+ const VideoStreamOptions& videoStreamOptions)
549
+ : frames_(validateFrames(frames)), inFrameRate_(frameRate) {
550
+ setFFmpegLogLevel ();
551
+
552
+ // Allocate output format context
553
+ AVFormatContext* avFormatContext = nullptr ;
554
+ int status = avformat_alloc_output_context2 (
555
+ &avFormatContext, nullptr , nullptr , fileName.data ());
556
+
557
+ TORCH_CHECK (
558
+ avFormatContext != nullptr ,
559
+ " Couldn't allocate AVFormatContext. " ,
560
+ " The destination file is " ,
561
+ fileName,
562
+ " , check the desired extension? " ,
563
+ getFFMPEGErrorStringFromErrorCode (status));
564
+ avFormatContext_.reset (avFormatContext);
565
+
566
+ status = avio_open (&avFormatContext_->pb , fileName.data (), AVIO_FLAG_WRITE);
567
+ TORCH_CHECK (
568
+ status >= 0 ,
569
+ " avio_open failed. The destination file is " ,
570
+ fileName,
571
+ " , make sure it's a valid path? " ,
572
+ getFFMPEGErrorStringFromErrorCode (status));
573
+ // TODO-VideoEncoder: Add tests for above fileName related checks
574
+
575
+ initializeEncoder (videoStreamOptions);
576
+ }
577
+
578
+ void VideoEncoder::initializeEncoder (
579
+ const VideoStreamOptions& videoStreamOptions) {
580
+ const AVCodec* avCodec =
581
+ avcodec_find_encoder (avFormatContext_->oformat ->video_codec );
582
+ TORCH_CHECK (avCodec != nullptr , " Video codec not found" );
583
+
584
+ AVCodecContext* avCodecContext = avcodec_alloc_context3 (avCodec);
585
+ TORCH_CHECK (avCodecContext != nullptr , " Couldn't allocate codec context." );
586
+ avCodecContext_.reset (avCodecContext);
587
+
588
+ // Set encoding options
589
+ // TODO-VideoEncoder: Allow bitrate to be set
590
+ std::optional<int > desiredBitRate = videoStreamOptions.bitRate ;
591
+ if (desiredBitRate.has_value ()) {
592
+ TORCH_CHECK (
593
+ *desiredBitRate >= 0 , " bit_rate=" , *desiredBitRate, " must be >= 0." );
594
+ }
595
+ avCodecContext_->bit_rate = desiredBitRate.value_or (0 );
596
+
597
+ // Store dimension order and input pixel format
598
+ // TODO-VideoEncoder: Remove assumption that tensor in NCHW format
599
+ auto sizes = frames_.sizes ();
600
+ inPixelFormat_ = AV_PIX_FMT_GBRP;
601
+ inHeight_ = sizes[2 ];
602
+ inWidth_ = sizes[3 ];
603
+
604
+ // Use specified dimensions or input dimensions
605
+ // TODO-VideoEncoder: Allow height and width to be set
606
+ outWidth_ = videoStreamOptions.width .value_or (inWidth_);
607
+ outHeight_ = videoStreamOptions.height .value_or (inHeight_);
608
+
609
+ // Use YUV420P as default output format
610
+ // TODO-VideoEncoder: Enable other pixel formats
611
+ outPixelFormat_ = AV_PIX_FMT_YUV420P;
612
+
613
+ // Configure codec parameters
614
+ avCodecContext_->codec_id = avCodec->id ;
615
+ avCodecContext_->width = outWidth_;
616
+ avCodecContext_->height = outHeight_;
617
+ avCodecContext_->pix_fmt = outPixelFormat_;
618
+ // TODO-VideoEncoder: Verify that frame_rate and time_base are correct
619
+ avCodecContext_->time_base = {1 , inFrameRate_};
620
+ avCodecContext_->framerate = {inFrameRate_, 1 };
621
+
622
+ // TODO-VideoEncoder: Allow GOP size and max B-frames to be set
623
+ if (videoStreamOptions.gopSize .has_value ()) {
624
+ avCodecContext_->gop_size = *videoStreamOptions.gopSize ;
625
+ } else {
626
+ avCodecContext_->gop_size = 12 ; // Default GOP size
627
+ }
628
+
629
+ if (videoStreamOptions.maxBFrames .has_value ()) {
630
+ avCodecContext_->max_b_frames = *videoStreamOptions.maxBFrames ;
631
+ } else {
632
+ avCodecContext_->max_b_frames = 0 ; // No max B-frames to reduce compression
633
+ }
634
+
635
+ int status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
636
+ TORCH_CHECK (
637
+ status == AVSUCCESS,
638
+ " avcodec_open2 failed: " ,
639
+ getFFMPEGErrorStringFromErrorCode (status));
640
+
641
+ AVStream* avStream = avformat_new_stream (avFormatContext_.get (), nullptr );
642
+ TORCH_CHECK (avStream != nullptr , " Couldn't create new stream." );
643
+
644
+ // Set the stream time base to encode correct frame timestamps
645
+ avStream->time_base = avCodecContext_->time_base ;
646
+ status = avcodec_parameters_from_context (
647
+ avStream->codecpar , avCodecContext_.get ());
648
+ TORCH_CHECK (
649
+ status == AVSUCCESS,
650
+ " avcodec_parameters_from_context failed: " ,
651
+ getFFMPEGErrorStringFromErrorCode (status));
652
+ streamIndex_ = avStream->index ;
653
+ }
654
+
655
+ void VideoEncoder::encode () {
656
+ // To be on the safe side we enforce that encode() can only be called once
657
+ TORCH_CHECK (!encodeWasCalled_, " Cannot call encode() twice." );
658
+ encodeWasCalled_ = true ;
659
+
660
+ int status = avformat_write_header (avFormatContext_.get (), nullptr );
661
+ TORCH_CHECK (
662
+ status == AVSUCCESS,
663
+ " Error in avformat_write_header: " ,
664
+ getFFMPEGErrorStringFromErrorCode (status));
665
+
666
+ AutoAVPacket autoAVPacket;
667
+ int numFrames = frames_.sizes ()[0 ];
668
+ for (int i = 0 ; i < numFrames; ++i) {
669
+ torch::Tensor currFrame = frames_[i];
670
+ UniqueAVFrame avFrame = convertTensorToAVFrame (currFrame, i);
671
+ encodeFrame (autoAVPacket, avFrame);
672
+ }
673
+
674
+ flushBuffers ();
675
+
676
+ status = av_write_trailer (avFormatContext_.get ());
677
+ TORCH_CHECK (
678
+ status == AVSUCCESS,
679
+ " Error in av_write_trailer: " ,
680
+ getFFMPEGErrorStringFromErrorCode (status));
681
+ }
682
+
683
+ UniqueAVFrame VideoEncoder::convertTensorToAVFrame (
684
+ const torch::Tensor& frame,
685
+ int frameIndex) {
686
+ // Initialize and cache scaling context if it does not exist
687
+ if (!swsContext_) {
688
+ swsContext_.reset (sws_getContext (
689
+ inWidth_,
690
+ inHeight_,
691
+ inPixelFormat_,
692
+ outWidth_,
693
+ outHeight_,
694
+ outPixelFormat_,
695
+ SWS_BILINEAR,
696
+ nullptr ,
697
+ nullptr ,
698
+ nullptr ));
699
+ TORCH_CHECK (swsContext_ != nullptr , " Failed to create scaling context" );
700
+ }
701
+
702
+ UniqueAVFrame avFrame (av_frame_alloc ());
703
+ TORCH_CHECK (avFrame != nullptr , " Failed to allocate AVFrame" );
704
+
705
+ // Set output frame properties
706
+ avFrame->format = outPixelFormat_;
707
+ avFrame->width = outWidth_;
708
+ avFrame->height = outHeight_;
709
+ avFrame->pts = frameIndex;
710
+
711
+ int status = av_frame_get_buffer (avFrame.get (), 0 );
712
+ TORCH_CHECK (status >= 0 , " Failed to allocate frame buffer" );
713
+
714
+ // Need to convert/scale the frame
715
+ // Create temporary frame with input format
716
+ UniqueAVFrame inputFrame (av_frame_alloc ());
717
+ TORCH_CHECK (inputFrame != nullptr , " Failed to allocate input AVFrame" );
718
+
719
+ inputFrame->format = inPixelFormat_;
720
+ inputFrame->width = inWidth_;
721
+ inputFrame->height = inHeight_;
722
+
723
+ uint8_t * tensorData = static_cast <uint8_t *>(frame.data_ptr ());
724
+
725
+ // TODO-VideoEncoder: Reorder tensor if in NHWC format
726
+ int channelSize = inHeight_ * inWidth_;
727
+ // Reorder RGB -> GBR for AV_PIX_FMT_GBRP format
728
+ // TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format
729
+ inputFrame->data [0 ] = tensorData + channelSize;
730
+ inputFrame->data [1 ] = tensorData + (2 * channelSize);
731
+ inputFrame->data [2 ] = tensorData;
732
+
733
+ inputFrame->linesize [0 ] = inWidth_;
734
+ inputFrame->linesize [1 ] = inWidth_;
735
+ inputFrame->linesize [2 ] = inWidth_;
736
+
737
+ status = sws_scale (
738
+ swsContext_.get (),
739
+ inputFrame->data ,
740
+ inputFrame->linesize ,
741
+ 0 ,
742
+ inputFrame->height ,
743
+ avFrame->data ,
744
+ avFrame->linesize );
745
+ TORCH_CHECK (status == outHeight_, " sws_scale failed" );
746
+ return avFrame;
747
+ }
748
+
749
+ void VideoEncoder::encodeFrame (
750
+ AutoAVPacket& autoAVPacket,
751
+ const UniqueAVFrame& avFrame) {
752
+ auto status = avcodec_send_frame (avCodecContext_.get (), avFrame.get ());
753
+ TORCH_CHECK (
754
+ status == AVSUCCESS,
755
+ " Error while sending frame: " ,
756
+ getFFMPEGErrorStringFromErrorCode (status));
757
+
758
+ while (true ) {
759
+ ReferenceAVPacket packet (autoAVPacket);
760
+ status = avcodec_receive_packet (avCodecContext_.get (), packet.get ());
761
+ if (status == AVERROR (EAGAIN) || status == AVERROR_EOF) {
762
+ if (status == AVERROR_EOF) {
763
+ // Flush remaining buffered packets
764
+ status = av_interleaved_write_frame (avFormatContext_.get (), nullptr );
765
+ TORCH_CHECK (
766
+ status == AVSUCCESS,
767
+ " Failed to flush packet: " ,
768
+ getFFMPEGErrorStringFromErrorCode (status));
769
+ }
770
+ return ;
771
+ }
772
+ TORCH_CHECK (
773
+ status >= 0 ,
774
+ " Error receiving packet: " ,
775
+ getFFMPEGErrorStringFromErrorCode (status));
776
+
777
+ packet->stream_index = streamIndex_;
778
+
779
+ status = av_interleaved_write_frame (avFormatContext_.get (), packet.get ());
780
+ TORCH_CHECK (
781
+ status == AVSUCCESS,
782
+ " Error in av_interleaved_write_frame: " ,
783
+ getFFMPEGErrorStringFromErrorCode (status));
784
+ }
785
+ }
786
+
787
+ void VideoEncoder::flushBuffers () {
788
+ AutoAVPacket autoAVPacket;
789
+ // Send null frame to signal end of input
790
+ encodeFrame (autoAVPacket, UniqueAVFrame (nullptr ));
791
+ }
792
+
514
793
} // namespace facebook::torchcodec
0 commit comments