2828 create_from_tensor ,
2929 encode_audio_to_file ,
3030 encode_video_to_file ,
31+ encode_video_to_file_like ,
3132 encode_video_to_tensor ,
3233 get_ffmpeg_library_versions ,
3334 get_frame_at_index ,
@@ -1151,7 +1152,7 @@ def test_bad_input(self, tmp_path):
11511152
11521153
11531154class TestVideoEncoderOps :
1154-
1155+ # TODO-VideoEncoder: Test encoding against different memory layouts (ex. test_contiguity)
11551156 # TODO-VideoEncoder: Parametrize test after moving to test_encoders
11561157 def test_bad_input (self , tmp_path ):
11571158 output_file = str (tmp_path / ".mp4" )
@@ -1219,7 +1220,7 @@ def decode(self, source=None) -> torch.Tensor:
12191220 @pytest .mark .parametrize (
12201221 "format" , ("mov" , "mp4" , "mkv" , pytest .param ("webm" , marks = pytest .mark .slow ))
12211222 )
1222- @pytest .mark .parametrize ("method" , ("to_file" , "to_tensor" ))
1223+ @pytest .mark .parametrize ("method" , ("to_file" , "to_tensor" , "to_file_like" ))
12231224 def test_video_encoder_round_trip (self , tmp_path , format , method ):
12241225 # Test that decode(encode(decode(frames))) == decode(frames)
12251226 ffmpeg_version = get_ffmpeg_major_version ()
@@ -1246,11 +1247,22 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
12461247 ** params ,
12471248 )
12481249 round_trip_frames = self .decode (encoded_path ).data
1249- else : # to_tensor
1250+ elif method == " to_tensor" :
12501251 encoded_tensor = encode_video_to_tensor (
12511252 source_frames , format = format , ** params
12521253 )
12531254 round_trip_frames = self .decode (encoded_tensor ).data
1255+ elif method == "to_file_like" :
1256+ file_like = io .BytesIO ()
1257+ encode_video_to_file_like (
1258+ frames = source_frames ,
1259+ format = format ,
1260+ file_like = file_like ,
1261+ ** params ,
1262+ )
1263+ round_trip_frames = self .decode (file_like .getvalue ()).data
1264+ else :
1265+ raise ValueError (f"Unknown method: { method } " )
12541266
12551267 assert source_frames .shape == round_trip_frames .shape
12561268 assert source_frames .dtype == round_trip_frames .dtype
@@ -1279,8 +1291,9 @@ def test_video_encoder_round_trip(self, tmp_path, format, method):
12791291 pytest .param ("webm" , marks = pytest .mark .slow ),
12801292 ),
12811293 )
1282- def test_against_to_file (self , tmp_path , format ):
1283- # Test that to_file and to_tensor produce the same results
1294+ @pytest .mark .parametrize ("method" , ("to_tensor" , "to_file_like" ))
1295+ def test_against_to_file (self , tmp_path , format , method ):
1296+ # Test that to_file, to_tensor, and to_file_like produce the same results
12841297 ffmpeg_version = get_ffmpeg_major_version ()
12851298 if format == "webm" and (
12861299 ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6 , 7 ))
@@ -1292,11 +1305,24 @@ def test_against_to_file(self, tmp_path, format):
12921305
12931306 encoded_file = tmp_path / f"output.{ format } "
12941307 encode_video_to_file (frames = source_frames , filename = str (encoded_file ), ** params )
1295- encoded_tensor = encode_video_to_tensor (source_frames , format = format , ** params )
1308+
1309+ if method == "to_tensor" :
1310+ encoded_output = encode_video_to_tensor (
1311+ source_frames , format = format , ** params
1312+ )
1313+ else : # to_file_like
1314+ file_like = io .BytesIO ()
1315+ encode_video_to_file_like (
1316+ frames = source_frames ,
1317+ file_like = file_like ,
1318+ format = format ,
1319+ ** params ,
1320+ )
1321+ encoded_output = file_like .getvalue ()
12961322
12971323 torch .testing .assert_close (
12981324 self .decode (encoded_file ).data ,
1299- self .decode (encoded_tensor ).data ,
1325+ self .decode (encoded_output ).data ,
13001326 atol = 0 ,
13011327 rtol = 0 ,
13021328 )
@@ -1379,6 +1405,82 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
13791405 ff_frame , enc_frame , percentage = percentage , atol = 2
13801406 )
13811407
1408+ def test_to_file_like_custom_file_object (self ):
1409+ """Test with a custom file-like object that implements write and seek."""
1410+
1411+ class CustomFileObject :
1412+ def __init__ (self ):
1413+ self ._file = io .BytesIO ()
1414+
1415+ def write (self , data ):
1416+ return self ._file .write (data )
1417+
1418+ def seek (self , offset , whence = 0 ):
1419+ return self ._file .seek (offset , whence )
1420+
1421+ def get_encoded_data (self ):
1422+ return self ._file .getvalue ()
1423+
1424+ source_frames = self .decode (TEST_SRC_2_720P .path ).data
1425+ file_like = CustomFileObject ()
1426+ encode_video_to_file_like (
1427+ source_frames , frame_rate = 30 , crf = 0 , format = "mp4" , file_like = file_like
1428+ )
1429+ decoded_samples = self .decode (file_like .get_encoded_data ())
1430+
1431+ torch .testing .assert_close (
1432+ decoded_samples .data ,
1433+ source_frames ,
1434+ atol = 2 ,
1435+ rtol = 0 ,
1436+ )
1437+
1438+ def test_to_file_like_real_file (self , tmp_path ):
1439+ """Test to_file_like with a real file opened in binary write mode."""
1440+ source_frames = self .decode (TEST_SRC_2_720P .path ).data
1441+ file_path = tmp_path / "test_file_like.mp4"
1442+
1443+ with open (file_path , "wb" ) as file_like :
1444+ encode_video_to_file_like (
1445+ source_frames , frame_rate = 30 , crf = 0 , format = "mp4" , file_like = file_like
1446+ )
1447+ decoded_samples = self .decode (str (file_path ))
1448+
1449+ torch .testing .assert_close (
1450+ decoded_samples .data ,
1451+ source_frames ,
1452+ atol = 2 ,
1453+ rtol = 0 ,
1454+ )
1455+
1456+ def test_to_file_like_bad_methods (self ):
1457+ source_frames = self .decode (TEST_SRC_2_720P .path ).data
1458+
1459+ class NoWriteMethod :
1460+ def seek (self , offset , whence = 0 ):
1461+ return 0
1462+
1463+ with pytest .raises (
1464+ RuntimeError , match = "File like object must implement a write method"
1465+ ):
1466+ encode_video_to_file_like (
1467+ source_frames ,
1468+ frame_rate = 30 ,
1469+ format = "mp4" ,
1470+ file_like = NoWriteMethod (),
1471+ )
1472+
1473+ class NoSeekMethod :
1474+ def write (self , data ):
1475+ return len (data )
1476+
1477+ with pytest .raises (
1478+ RuntimeError , match = "File like object must implement a seek method"
1479+ ):
1480+ encode_video_to_file_like (
1481+ source_frames , frame_rate = 30 , format = "mp4" , file_like = NoSeekMethod ()
1482+ )
1483+
13821484
13831485if __name__ == "__main__" :
13841486 pytest .main ()
0 commit comments