Skip to content

Commit 2117716

Browse files
Dan-FloresDaniel Flores
andauthored
Add to_tensor support for VideoEncoder (#957)
Co-authored-by: Daniel Flores <[email protected]>
1 parent 3827dfe commit 2117716

File tree

8 files changed

+231
-84
lines changed

8 files changed

+231
-84
lines changed

src/torchcodec/_core/AVIOTensorContext.cpp

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,34 +18,34 @@ constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB
1818
int read(void* opaque, uint8_t* buf, int buf_size) {
1919
auto tensorContext = static_cast<detail::TensorContext*>(opaque);
2020
TORCH_CHECK(
21-
tensorContext->current <= tensorContext->data.numel(),
22-
"Tried to read outside of the buffer: current=",
23-
tensorContext->current,
21+
tensorContext->current_pos <= tensorContext->data.numel(),
22+
"Tried to read outside of the buffer: current_pos=",
23+
tensorContext->current_pos,
2424
", size=",
2525
tensorContext->data.numel());
2626

2727
int64_t numBytesRead = std::min(
2828
static_cast<int64_t>(buf_size),
29-
tensorContext->data.numel() - tensorContext->current);
29+
tensorContext->data.numel() - tensorContext->current_pos);
3030

3131
TORCH_CHECK(
3232
numBytesRead >= 0,
3333
"Tried to read negative bytes: numBytesRead=",
3434
numBytesRead,
3535
", size=",
3636
tensorContext->data.numel(),
37-
", current=",
38-
tensorContext->current);
37+
", current_pos=",
38+
tensorContext->current_pos);
3939

4040
if (numBytesRead == 0) {
4141
return AVERROR_EOF;
4242
}
4343

4444
std::memcpy(
4545
buf,
46-
tensorContext->data.data_ptr<uint8_t>() + tensorContext->current,
46+
tensorContext->data.data_ptr<uint8_t>() + tensorContext->current_pos,
4747
numBytesRead);
48-
tensorContext->current += numBytesRead;
48+
tensorContext->current_pos += numBytesRead;
4949
return numBytesRead;
5050
}
5151

@@ -54,7 +54,7 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
5454
auto tensorContext = static_cast<detail::TensorContext*>(opaque);
5555

5656
int64_t bufSize = static_cast<int64_t>(buf_size);
57-
if (tensorContext->current + bufSize > tensorContext->data.numel()) {
57+
if (tensorContext->current_pos + bufSize > tensorContext->data.numel()) {
5858
TORCH_CHECK(
5959
tensorContext->data.numel() * 2 <= MAX_TENSOR_SIZE,
6060
"We tried to allocate an output encoded tensor larger than ",
@@ -68,13 +68,17 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
6868
}
6969

7070
TORCH_CHECK(
71-
tensorContext->current + bufSize <= tensorContext->data.numel(),
71+
tensorContext->current_pos + bufSize <= tensorContext->data.numel(),
7272
"Re-allocation of the output tensor didn't work. ",
7373
"This should not happen, please report on TorchCodec bug tracker");
7474

7575
uint8_t* outputTensorData = tensorContext->data.data_ptr<uint8_t>();
76-
std::memcpy(outputTensorData + tensorContext->current, buf, bufSize);
77-
tensorContext->current += bufSize;
76+
std::memcpy(outputTensorData + tensorContext->current_pos, buf, bufSize);
77+
tensorContext->current_pos += bufSize;
78+
// Track the maximum position written so getOutputTensor's narrow() does not
79+
// truncate the file if final seek was backwards
80+
tensorContext->max_pos =
81+
std::max(tensorContext->current_pos, tensorContext->max_pos);
7882
return buf_size;
7983
}
8084

@@ -88,7 +92,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) {
8892
ret = tensorContext->data.numel();
8993
break;
9094
case SEEK_SET:
91-
tensorContext->current = offset;
95+
tensorContext->current_pos = offset;
9296
ret = offset;
9397
break;
9498
default:
@@ -101,7 +105,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) {
101105
} // namespace
102106

103107
AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
104-
: tensorContext_{data, 0} {
108+
: tensorContext_{data, 0, 0} {
105109
TORCH_CHECK(data.numel() > 0, "data must not be empty");
106110
TORCH_CHECK(data.is_contiguous(), "data must be contiguous");
107111
TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8");
@@ -110,14 +114,17 @@ AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
110114
}
111115

112116
AVIOToTensorContext::AVIOToTensorContext()
113-
: tensorContext_{torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), 0} {
117+
: tensorContext_{
118+
torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}),
119+
0,
120+
0} {
114121
createAVIOContext(
115122
nullptr, &write, &seek, &tensorContext_, /*isForWriting=*/true);
116123
}
117124

118125
torch::Tensor AVIOToTensorContext::getOutputTensor() {
119126
return tensorContext_.data.narrow(
120-
/*dim=*/0, /*start=*/0, /*length=*/tensorContext_.current);
127+
/*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max_pos);
121128
}
122129

123130
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOTensorContext.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ namespace detail {
1515

1616
struct TensorContext {
1717
torch::Tensor data;
18-
int64_t current;
18+
int64_t current_pos;
19+
int64_t max_pos;
1920
};
2021

2122
} // namespace detail

src/torchcodec/_core/Encoder.cpp

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@
44
#include "src/torchcodec/_core/Encoder.h"
55
#include "torch/types.h"
66

7-
extern "C" {
8-
#include <libavutil/pixdesc.h>
9-
}
10-
117
namespace facebook::torchcodec {
128

139
namespace {
@@ -542,10 +538,17 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
542538
} // namespace
543539

544540
VideoEncoder::~VideoEncoder() {
541+
// TODO-VideoEncoder: Unify destructor with ~AudioEncoder()
545542
if (avFormatContext_ && avFormatContext_->pb) {
546-
avio_flush(avFormatContext_->pb);
547-
avio_close(avFormatContext_->pb);
548-
avFormatContext_->pb = nullptr;
543+
if (avFormatContext_->pb->error == 0) {
544+
avio_flush(avFormatContext_->pb);
545+
}
546+
if (!avioContextHolder_) {
547+
if (avFormatContext_->pb->error == 0) {
548+
avio_close(avFormatContext_->pb);
549+
}
550+
avFormatContext_->pb = nullptr;
551+
}
549552
}
550553
}
551554

@@ -581,6 +584,36 @@ VideoEncoder::VideoEncoder(
581584
initializeEncoder(videoStreamOptions);
582585
}
583586

587+
VideoEncoder::VideoEncoder(
588+
const torch::Tensor& frames,
589+
int frameRate,
590+
std::string_view formatName,
591+
std::unique_ptr<AVIOContextHolder> avioContextHolder,
592+
const VideoStreamOptions& videoStreamOptions)
593+
: frames_(validateFrames(frames)),
594+
inFrameRate_(frameRate),
595+
avioContextHolder_(std::move(avioContextHolder)) {
596+
setFFmpegLogLevel();
597+
// Map mkv -> matroska when used as format name
598+
formatName = (formatName == "mkv") ? "matroska" : formatName;
599+
AVFormatContext* avFormatContext = nullptr;
600+
int status = avformat_alloc_output_context2(
601+
&avFormatContext, nullptr, formatName.data(), nullptr);
602+
603+
TORCH_CHECK(
604+
avFormatContext != nullptr,
605+
"Couldn't allocate AVFormatContext. ",
606+
"Check the desired format? Got format=",
607+
formatName,
608+
". ",
609+
getFFMPEGErrorStringFromErrorCode(status));
610+
avFormatContext_.reset(avFormatContext);
611+
612+
avFormatContext_->pb = avioContextHolder_->getAVIOContext();
613+
614+
initializeEncoder(videoStreamOptions);
615+
}
616+
584617
void VideoEncoder::initializeEncoder(
585618
const VideoStreamOptions& videoStreamOptions) {
586619
const AVCodec* avCodec =
@@ -751,6 +784,17 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
751784
return avFrame;
752785
}
753786

787+
torch::Tensor VideoEncoder::encodeToTensor() {
788+
TORCH_CHECK(
789+
avioContextHolder_ != nullptr,
790+
"Cannot encode to tensor, avio tensor context doesn't exist.");
791+
encode();
792+
auto avioToTensorContext =
793+
dynamic_cast<AVIOToTensorContext*>(avioContextHolder_.get());
794+
TORCH_CHECK(avioToTensorContext != nullptr, "Invalid AVIO context holder.");
795+
return avioToTensorContext->getOutputTensor();
796+
}
797+
754798
void VideoEncoder::encodeFrame(
755799
AutoAVPacket& autoAVPacket,
756800
const UniqueAVFrame& avFrame) {

src/torchcodec/_core/Encoder.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,17 @@ class VideoEncoder {
141141
std::string_view fileName,
142142
const VideoStreamOptions& videoStreamOptions);
143143

144+
VideoEncoder(
145+
const torch::Tensor& frames,
146+
int frameRate,
147+
std::string_view formatName,
148+
std::unique_ptr<AVIOContextHolder> avioContextHolder,
149+
const VideoStreamOptions& videoStreamOptions);
150+
144151
void encode();
145152

153+
torch::Tensor encodeToTensor();
154+
146155
private:
147156
void initializeEncoder(const VideoStreamOptions& videoStreamOptions);
148157
UniqueAVFrame convertTensorToAVFrame(
@@ -167,6 +176,8 @@ class VideoEncoder {
167176
int outHeight_ = -1;
168177
AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE;
169178

179+
std::unique_ptr<AVIOContextHolder> avioContextHolder_;
180+
170181
bool encodeWasCalled_ = false;
171182
};
172183

src/torchcodec/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
encode_audio_to_file_like,
2727
encode_audio_to_tensor,
2828
encode_video_to_file,
29+
encode_video_to_tensor,
2930
get_ffmpeg_library_versions,
3031
get_frame_at_index,
3132
get_frame_at_pts,

src/torchcodec/_core/custom_ops.cpp

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,14 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3232
m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor");
3333
m.def(
3434
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
35-
m.def(
36-
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
3735
m.def(
3836
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
3937
m.def(
4038
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
39+
m.def(
40+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
41+
m.def(
42+
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
4143
m.def(
4244
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4345
m.def(
@@ -498,21 +500,6 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
498500
return makeOpsAudioFramesOutput(result);
499501
}
500502

501-
void encode_video_to_file(
502-
const at::Tensor& frames,
503-
int64_t frame_rate,
504-
std::string_view file_name,
505-
std::optional<int64_t> crf = std::nullopt) {
506-
VideoStreamOptions videoStreamOptions;
507-
videoStreamOptions.crf = crf;
508-
VideoEncoder(
509-
frames,
510-
validateInt64ToInt(frame_rate, "frame_rate"),
511-
file_name,
512-
videoStreamOptions)
513-
.encode();
514-
}
515-
516503
void encode_audio_to_file(
517504
const at::Tensor& samples,
518505
int64_t sample_rate,
@@ -587,6 +574,38 @@ void _encode_audio_to_file_like(
587574
encoder.encode();
588575
}
589576

577+
void encode_video_to_file(
578+
const at::Tensor& frames,
579+
int64_t frame_rate,
580+
std::string_view file_name,
581+
std::optional<int64_t> crf = std::nullopt) {
582+
VideoStreamOptions videoStreamOptions;
583+
videoStreamOptions.crf = crf;
584+
VideoEncoder(
585+
frames,
586+
validateInt64ToInt(frame_rate, "frame_rate"),
587+
file_name,
588+
videoStreamOptions)
589+
.encode();
590+
}
591+
592+
at::Tensor encode_video_to_tensor(
593+
const at::Tensor& frames,
594+
int64_t frame_rate,
595+
std::string_view format,
596+
std::optional<int64_t> crf = std::nullopt) {
597+
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
598+
VideoStreamOptions videoStreamOptions;
599+
videoStreamOptions.crf = crf;
600+
return VideoEncoder(
601+
frames,
602+
validateInt64ToInt(frame_rate, "frame_rate"),
603+
format,
604+
std::move(avioContextHolder),
605+
videoStreamOptions)
606+
.encodeToTensor();
607+
}
608+
590609
// For testing only. We need to implement this operation as a core library
591610
// function because what we're testing is round-tripping pts values as
592611
// double-precision floating point numbers from C++ to Python and back to C++.
@@ -847,9 +866,10 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
847866

848867
TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
849868
m.impl("encode_audio_to_file", &encode_audio_to_file);
850-
m.impl("encode_video_to_file", &encode_video_to_file);
851869
m.impl("encode_audio_to_tensor", &encode_audio_to_tensor);
852870
m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);
871+
m.impl("encode_video_to_file", &encode_video_to_file);
872+
m.impl("encode_video_to_tensor", &encode_video_to_tensor);
853873
m.impl("seek_to_pts", &seek_to_pts);
854874
m.impl("add_video_stream", &add_video_stream);
855875
m.impl("_add_video_stream", &_add_video_stream);

src/torchcodec/_core/ops.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,18 @@ def load_torchcodec_shared_libraries():
9292
encode_audio_to_file = torch._dynamo.disallow_in_graph(
9393
torch.ops.torchcodec_ns.encode_audio_to_file.default
9494
)
95-
encode_video_to_file = torch._dynamo.disallow_in_graph(
96-
torch.ops.torchcodec_ns.encode_video_to_file.default
97-
)
9895
encode_audio_to_tensor = torch._dynamo.disallow_in_graph(
9996
torch.ops.torchcodec_ns.encode_audio_to_tensor.default
10097
)
10198
_encode_audio_to_file_like = torch._dynamo.disallow_in_graph(
10299
torch.ops.torchcodec_ns._encode_audio_to_file_like.default
103100
)
101+
encode_video_to_file = torch._dynamo.disallow_in_graph(
102+
torch.ops.torchcodec_ns.encode_video_to_file.default
103+
)
104+
encode_video_to_tensor = torch._dynamo.disallow_in_graph(
105+
torch.ops.torchcodec_ns.encode_video_to_tensor.default
106+
)
104107
create_from_tensor = torch._dynamo.disallow_in_graph(
105108
torch.ops.torchcodec_ns.create_from_tensor.default
106109
)
@@ -254,16 +257,6 @@ def encode_audio_to_file_abstract(
254257
return
255258

256259

257-
@register_fake("torchcodec_ns::encode_video_to_file")
258-
def encode_video_to_file_abstract(
259-
frames: torch.Tensor,
260-
frame_rate: int,
261-
filename: str,
262-
crf: Optional[int] = None,
263-
) -> None:
264-
return
265-
266-
267260
@register_fake("torchcodec_ns::encode_audio_to_tensor")
268261
def encode_audio_to_tensor_abstract(
269262
samples: torch.Tensor,
@@ -289,6 +282,26 @@ def _encode_audio_to_file_like_abstract(
289282
return
290283

291284

285+
@register_fake("torchcodec_ns::encode_video_to_file")
286+
def encode_video_to_file_abstract(
287+
frames: torch.Tensor,
288+
frame_rate: int,
289+
filename: str,
290+
crf: Optional[int],
291+
) -> None:
292+
return
293+
294+
295+
@register_fake("torchcodec_ns::encode_video_to_tensor")
296+
def encode_video_to_tensor_abstract(
297+
frames: torch.Tensor,
298+
frame_rate: int,
299+
format: str,
300+
crf: Optional[int],
301+
) -> torch.Tensor:
302+
return torch.empty([], dtype=torch.long)
303+
304+
292305
@register_fake("torchcodec_ns::create_from_tensor")
293306
def create_from_tensor_abstract(
294307
video_tensor: torch.Tensor, seek_mode: Optional[str]

0 commit comments

Comments
 (0)