Skip to content

Commit fe738a6

Browse files
Dan-FloresDaniel Flores
andauthored
Add to_file_like support for VideoEncoder (#958)
Co-authored-by: Daniel Flores <[email protected]>
1 parent 262c457 commit fe738a6

File tree

4 files changed

+178
-7
lines changed

4 files changed

+178
-7
lines changed

src/torchcodec/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
encode_audio_to_file_like,
2727
encode_audio_to_tensor,
2828
encode_video_to_file,
29+
encode_video_to_file_like,
2930
encode_video_to_tensor,
3031
get_ffmpeg_library_versions,
3132
get_frame_at_index,

src/torchcodec/_core/custom_ops.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4040
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
4141
m.def(
4242
"encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor");
43+
m.def(
44+
"_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()");
4345
m.def(
4446
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
4547
m.def(
@@ -628,6 +630,30 @@ at::Tensor encode_video_to_tensor(
628630
.encodeToTensor();
629631
}
630632

633+
void _encode_video_to_file_like(
634+
const at::Tensor& frames,
635+
int64_t frame_rate,
636+
std::string_view format,
637+
int64_t file_like_context,
638+
std::optional<int64_t> crf = std::nullopt) {
639+
auto fileLikeContext =
640+
reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
641+
TORCH_CHECK(
642+
fileLikeContext != nullptr, "file_like_context must be a valid pointer");
643+
std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
644+
645+
VideoStreamOptions videoStreamOptions;
646+
videoStreamOptions.crf = crf;
647+
648+
VideoEncoder encoder(
649+
frames,
650+
validateInt64ToInt(frame_rate, "frame_rate"),
651+
format,
652+
std::move(avioContextHolder),
653+
videoStreamOptions);
654+
encoder.encode();
655+
}
656+
631657
// For testing only. We need to implement this operation as a core library
632658
// function because what we're testing is round-tripping pts values as
633659
// double-precision floating point numbers from C++ to Python and back to C++.
@@ -892,6 +918,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
892918
m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);
893919
m.impl("encode_video_to_file", &encode_video_to_file);
894920
m.impl("encode_video_to_tensor", &encode_video_to_tensor);
921+
m.impl("_encode_video_to_file_like", &_encode_video_to_file_like);
895922
m.impl("seek_to_pts", &seek_to_pts);
896923
m.impl("add_video_stream", &add_video_stream);
897924
m.impl("_add_video_stream", &_add_video_stream);

src/torchcodec/_core/ops.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def load_torchcodec_shared_libraries():
104104
encode_video_to_tensor = torch._dynamo.disallow_in_graph(
105105
torch.ops.torchcodec_ns.encode_video_to_tensor.default
106106
)
107+
_encode_video_to_file_like = torch._dynamo.disallow_in_graph(
108+
torch.ops.torchcodec_ns._encode_video_to_file_like.default
109+
)
107110
create_from_tensor = torch._dynamo.disallow_in_graph(
108111
torch.ops.torchcodec_ns.create_from_tensor.default
109112
)
@@ -203,6 +206,33 @@ def encode_audio_to_file_like(
203206
)
204207

205208

209+
def encode_video_to_file_like(
210+
frames: torch.Tensor,
211+
frame_rate: int,
212+
format: str,
213+
file_like: Union[io.RawIOBase, io.BufferedIOBase],
214+
crf: Optional[int] = None,
215+
) -> None:
216+
"""Encode video frames to a file-like object.
217+
218+
Args:
219+
frames: Video frames tensor
220+
frame_rate: Frame rate in frames per second
221+
format: Video format (e.g., "mp4", "mov", "mkv")
222+
file_like: File-like object that supports write() and seek() methods
223+
crf: Optional constant rate factor for encoding quality
224+
"""
225+
assert _pybind_ops is not None
226+
227+
_encode_video_to_file_like(
228+
frames,
229+
frame_rate,
230+
format,
231+
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
232+
crf,
233+
)
234+
235+
206236
def get_frames_at_indices(
207237
decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, list[int]]
208238
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -302,6 +332,17 @@ def encode_video_to_tensor_abstract(
302332
return torch.empty([], dtype=torch.long)
303333

304334

335+
@register_fake("torchcodec_ns::_encode_video_to_file_like")
336+
def _encode_video_to_file_like_abstract(
337+
frames: torch.Tensor,
338+
frame_rate: int,
339+
format: str,
340+
file_like_context: int,
341+
crf: Optional[int] = None,
342+
) -> None:
343+
return
344+
345+
305346
@register_fake("torchcodec_ns::create_from_tensor")
306347
def create_from_tensor_abstract(
307348
video_tensor: torch.Tensor, seek_mode: Optional[str]

test/test_ops.py

Lines changed: 109 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
create_from_tensor,
2929
encode_audio_to_file,
3030
encode_video_to_file,
31+
encode_video_to_file_like,
3132
encode_video_to_tensor,
3233
get_ffmpeg_library_versions,
3334
get_frame_at_index,
@@ -1151,7 +1152,7 @@ def test_bad_input(self, tmp_path):
11511152

11521153

11531154
class TestVideoEncoderOps:
1154-
1155+
# TODO-VideoEncoder: Test encoding against different memory layouts (ex. test_contiguity)
11551156
# TODO-VideoEncoder: Parametrize test after moving to test_encoders
11561157
def test_bad_input(self, tmp_path):
11571158
output_file = str(tmp_path / ".mp4")
@@ -1219,7 +1220,7 @@ def decode(self, source=None) -> torch.Tensor:
12191220
@pytest.mark.parametrize(
12201221
"format", ("mov", "mp4", "mkv", pytest.param("webm", marks=pytest.mark.slow))
12211222
)
1222-
@pytest.mark.parametrize("method", ("to_file", "to_tensor"))
1223+
@pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
12231224
def test_video_encoder_round_trip(self, tmp_path, format, method):
12241225
# Test that decode(encode(decode(frames))) == decode(frames)
12251226
ffmpeg_version = get_ffmpeg_major_version()
@@ -1246,11 +1247,22 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
12461247
**params,
12471248
)
12481249
round_trip_frames = self.decode(encoded_path).data
1249-
else: # to_tensor
1250+
elif method == "to_tensor":
12501251
encoded_tensor = encode_video_to_tensor(
12511252
source_frames, format=format, **params
12521253
)
12531254
round_trip_frames = self.decode(encoded_tensor).data
1255+
elif method == "to_file_like":
1256+
file_like = io.BytesIO()
1257+
encode_video_to_file_like(
1258+
frames=source_frames,
1259+
format=format,
1260+
file_like=file_like,
1261+
**params,
1262+
)
1263+
round_trip_frames = self.decode(file_like.getvalue()).data
1264+
else:
1265+
raise ValueError(f"Unknown method: {method}")
12541266

12551267
assert source_frames.shape == round_trip_frames.shape
12561268
assert source_frames.dtype == round_trip_frames.dtype
@@ -1279,8 +1291,9 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
12791291
pytest.param("webm", marks=pytest.mark.slow),
12801292
),
12811293
)
1282-
def test_against_to_file(self, tmp_path, format):
1283-
# Test that to_file and to_tensor produce the same results
1294+
@pytest.mark.parametrize("method", ("to_tensor", "to_file_like"))
1295+
def test_against_to_file(self, tmp_path, format, method):
1296+
# Test that to_file, to_tensor, and to_file_like produce the same results
12841297
ffmpeg_version = get_ffmpeg_major_version()
12851298
if format == "webm" and (
12861299
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
@@ -1292,11 +1305,24 @@ def test_against_to_file(self, tmp_path, format):
12921305

12931306
encoded_file = tmp_path / f"output.{format}"
12941307
encode_video_to_file(frames=source_frames, filename=str(encoded_file), **params)
1295-
encoded_tensor = encode_video_to_tensor(source_frames, format=format, **params)
1308+
1309+
if method == "to_tensor":
1310+
encoded_output = encode_video_to_tensor(
1311+
source_frames, format=format, **params
1312+
)
1313+
else: # to_file_like
1314+
file_like = io.BytesIO()
1315+
encode_video_to_file_like(
1316+
frames=source_frames,
1317+
file_like=file_like,
1318+
format=format,
1319+
**params,
1320+
)
1321+
encoded_output = file_like.getvalue()
12961322

12971323
torch.testing.assert_close(
12981324
self.decode(encoded_file).data,
1299-
self.decode(encoded_tensor).data,
1325+
self.decode(encoded_output).data,
13001326
atol=0,
13011327
rtol=0,
13021328
)
@@ -1379,6 +1405,82 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
13791405
ff_frame, enc_frame, percentage=percentage, atol=2
13801406
)
13811407

1408+
def test_to_file_like_custom_file_object(self):
1409+
"""Test with a custom file-like object that implements write and seek."""
1410+
1411+
class CustomFileObject:
1412+
def __init__(self):
1413+
self._file = io.BytesIO()
1414+
1415+
def write(self, data):
1416+
return self._file.write(data)
1417+
1418+
def seek(self, offset, whence=0):
1419+
return self._file.seek(offset, whence)
1420+
1421+
def get_encoded_data(self):
1422+
return self._file.getvalue()
1423+
1424+
source_frames = self.decode(TEST_SRC_2_720P.path).data
1425+
file_like = CustomFileObject()
1426+
encode_video_to_file_like(
1427+
source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like
1428+
)
1429+
decoded_samples = self.decode(file_like.get_encoded_data())
1430+
1431+
torch.testing.assert_close(
1432+
decoded_samples.data,
1433+
source_frames,
1434+
atol=2,
1435+
rtol=0,
1436+
)
1437+
1438+
def test_to_file_like_real_file(self, tmp_path):
1439+
"""Test to_file_like with a real file opened in binary write mode."""
1440+
source_frames = self.decode(TEST_SRC_2_720P.path).data
1441+
file_path = tmp_path / "test_file_like.mp4"
1442+
1443+
with open(file_path, "wb") as file_like:
1444+
encode_video_to_file_like(
1445+
source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like
1446+
)
1447+
decoded_samples = self.decode(str(file_path))
1448+
1449+
torch.testing.assert_close(
1450+
decoded_samples.data,
1451+
source_frames,
1452+
atol=2,
1453+
rtol=0,
1454+
)
1455+
1456+
def test_to_file_like_bad_methods(self):
1457+
source_frames = self.decode(TEST_SRC_2_720P.path).data
1458+
1459+
class NoWriteMethod:
1460+
def seek(self, offset, whence=0):
1461+
return 0
1462+
1463+
with pytest.raises(
1464+
RuntimeError, match="File like object must implement a write method"
1465+
):
1466+
encode_video_to_file_like(
1467+
source_frames,
1468+
frame_rate=30,
1469+
format="mp4",
1470+
file_like=NoWriteMethod(),
1471+
)
1472+
1473+
class NoSeekMethod:
1474+
def write(self, data):
1475+
return len(data)
1476+
1477+
with pytest.raises(
1478+
RuntimeError, match="File like object must implement a seek method"
1479+
):
1480+
encode_video_to_file_like(
1481+
source_frames, frame_rate=30, format="mp4", file_like=NoSeekMethod()
1482+
)
1483+
13821484

13831485
if __name__ == "__main__":
13841486
pytest.main()

0 commit comments

Comments
 (0)