55
66import torch
77from torch import Tensor
8- from torchcodec ._core import add_video_stream , create_from_file , get_frames_by_pts
98from torchcodec .decoders import VideoDecoder
109from torchvision .transforms import v2
1110
12- DEFAULT_NUM_EXP = 20
1311
14-
15- def bench (f , * args , num_exp = DEFAULT_NUM_EXP , warmup = 1 ) -> Tensor :
12+ def bench (f , * args , num_exp , warmup = 1 ) -> Tensor :
1613
1714 for _ in range (warmup ):
1815 f (* args )
@@ -45,37 +42,55 @@ def report_stats(times: Tensor, unit: str = "ms", prefix: str = "") -> float:
4542
4643
def torchvision_resize(
    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
) -> Tensor:
    """Decode frames, then resize them on the CPU with torchvision.

    Benchmark baseline: decoding and resizing are two separate passes, so
    every frame is first materialized at full resolution.

    Args:
        path: Video file to decode.
        pts_seconds: Presentation timestamps (seconds) of the frames to fetch.
        dims: Target (height, width) for the resize.
        num_threads: FFmpeg thread count; 0 lets FFmpeg decide.

    Returns:
        A batched uint8 tensor of the resized frames, one per timestamp.
    """
    video = VideoDecoder(
        path, seek_mode="approximate", num_ffmpeg_threads=num_threads
    )
    frame_batch = video.get_frames_played_at(pts_seconds)
    resized = v2.Resize(size=dims)(frame_batch.data)
    # Sanity check: the decoder must hand back exactly one frame per timestamp.
    assert len(resized) == len(pts_seconds)
    return resized
5555
def torchvision_crop(
    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
) -> Tensor:
    """Decode frames, then center-crop them on the CPU with torchvision.

    Benchmark baseline: the crop runs as a second pass over fully decoded
    frames, mirroring :func:`torchvision_resize`.

    Args:
        path: Video file to decode.
        pts_seconds: Presentation timestamps (seconds) of the frames to fetch.
        dims: Crop size as (height, width), taken from the frame center.
        num_threads: FFmpeg thread count; 0 lets FFmpeg decide.

    Returns:
        A batched uint8 tensor of the cropped frames, one per timestamp.
    """
    video = VideoDecoder(
        path, seek_mode="approximate", num_ffmpeg_threads=num_threads
    )
    frame_batch = video.get_frames_played_at(pts_seconds)
    cropped = v2.CenterCrop(size=dims)(frame_batch.data)
    # Sanity check: one output frame per requested timestamp.
    assert len(cropped) == len(pts_seconds)
    return cropped
66+
67+
def decoder_resize(
    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
) -> Tensor:
    """Decode frames with the resize applied inside the decoder itself.

    The ``transforms=[v2.Resize(...)]`` argument pushes the resize into the
    decode pipeline, so frames never exist at full resolution in Python —
    this is the variant being compared against :func:`torchvision_resize`.

    Args:
        path: Video file to decode.
        pts_seconds: Presentation timestamps (seconds) of the frames to fetch.
        dims: Target (height, width) for the in-decoder resize.
        num_threads: FFmpeg thread count; 0 lets FFmpeg decide.

    Returns:
        A batched uint8 tensor of the resized frames, one per timestamp.
    """
    decoder = VideoDecoder(
        path,
        transforms=[v2.Resize(size=dims)],
        seek_mode="approximate",
        num_ffmpeg_threads=num_threads,
    )
    transformed_frames = decoder.get_frames_played_at(pts_seconds).data
    assert len(transformed_frames) == len(pts_seconds)
    # Fix: `transformed_frames` is already the `.data` tensor; the original
    # returned `transformed_frames.data`, a redundant second `.data` access
    # that was also inconsistent with `decoder_crop`.
    return transformed_frames
80+
81+
def decoder_crop(
    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
) -> Tensor:
    """Decode frames with the center-crop applied inside the decoder itself.

    Passing ``transforms=[v2.CenterCrop(...)]`` lets the decoder crop during
    decode — the variant being compared against :func:`torchvision_crop`.

    Args:
        path: Video file to decode.
        pts_seconds: Presentation timestamps (seconds) of the frames to fetch.
        dims: Crop size as (height, width), taken from the frame center.
        num_threads: FFmpeg thread count; 0 lets FFmpeg decide.

    Returns:
        A batched uint8 tensor of the cropped frames, one per timestamp.
    """
    video = VideoDecoder(
        path,
        transforms=[v2.CenterCrop(size=dims)],
        seek_mode="approximate",
        num_ffmpeg_threads=num_threads,
    )
    cropped = video.get_frames_played_at(pts_seconds).data
    # One output frame per requested timestamp.
    assert len(cropped) == len(pts_seconds)
    return cropped
7994
8095
8196def main ():
@@ -84,9 +99,27 @@ def main():
8499 parser .add_argument (
85100 "--num-exp" ,
86101 type = int ,
87- default = DEFAULT_NUM_EXP ,
102+ default = 5 ,
88103 help = "number of runs to average over" ,
89104 )
105+ parser .add_argument (
106+ "--num-threads" ,
107+ type = int ,
108+ default = 1 ,
109+ help = "number of threads to use; 0 means FFmpeg decides" ,
110+ )
111+ parser .add_argument (
112+ "--total-frame-fractions" ,
113+ nargs = "+" ,
114+ type = float ,
115+ default = [0.005 , 0.01 , 0.05 , 0.1 ],
116+ )
117+ parser .add_argument (
118+ "--input-dimension-fractions" ,
119+ nargs = "+" ,
120+ type = float ,
121+ default = [0.5 , 0.25 , 0.125 ],
122+ )
90123
91124 args = parser .parse_args ()
92125 path = Path (args .path )
@@ -100,10 +133,7 @@ def main():
100133
101134 input_height = metadata .height
102135 input_width = metadata .width
103- fraction_of_total_frames_to_sample = [0.005 , 0.01 , 0.05 , 0.1 ]
104- fraction_of_input_dimensions = [0.5 , 0.25 , 0.125 ]
105-
106- for num_fraction in fraction_of_total_frames_to_sample :
136+ for num_fraction in args .total_frame_fractions :
107137 num_frames_to_sample = math .ceil (metadata .num_frames * num_fraction )
108138 print (
109139 f"Sampling { num_fraction * 100 } %, { num_frames_to_sample } , of { metadata .num_frames } frames"
@@ -112,51 +142,49 @@ def main():
112142 i * duration / num_frames_to_sample for i in range (num_frames_to_sample )
113143 ]
114144
115- for dims_fraction in fraction_of_input_dimensions :
145+ for dims_fraction in args . input_dimension_fractions :
116146 dims = (int (input_height * dims_fraction ), int (input_width * dims_fraction ))
117147
118148 times = bench (
119- torchvision_resize , path , uniform_timestamps , dims , num_exp = args .num_exp
149+ torchvision_resize ,
150+ path ,
151+ uniform_timestamps ,
152+ dims ,
153+ args .num_threads ,
154+ num_exp = args .num_exp ,
120155 )
121156 report_stats (times , prefix = f"torchvision_resize({ dims } )" )
122157
123158 times = bench (
124- decoder_native_resize ,
159+ decoder_resize ,
125160 path ,
126161 uniform_timestamps ,
127162 dims ,
163+ args .num_threads ,
128164 num_exp = args .num_exp ,
129165 )
130- report_stats (times , prefix = f"decoder_native_resize({ dims } )" )
131- print ()
166+ report_stats (times , prefix = f"decoder_resize({ dims } )" )
132167
133- center_x = (input_height - dims [0 ]) // 2
134- center_y = (input_width - dims [1 ]) // 2
135168 times = bench (
136169 torchvision_crop ,
137170 path ,
138171 uniform_timestamps ,
139172 dims ,
140- center_x ,
141- center_y ,
173+ args .num_threads ,
142174 num_exp = args .num_exp ,
143175 )
144- report_stats (
145- times , prefix = f"torchvision_crop({ dims } , { center_x } , { center_y } )"
146- )
176+ report_stats (times , prefix = f"torchvision_crop({ dims } )" )
147177
148178 times = bench (
149- decoder_native_crop ,
179+ decoder_crop ,
150180 path ,
151181 uniform_timestamps ,
152182 dims ,
153- center_x ,
154- center_y ,
183+ args .num_threads ,
155184 num_exp = args .num_exp ,
156185 )
157- report_stats (
158- times , prefix = f"decoder_native_crop({ dims } , { center_x } , { center_y } )"
159- )
186+ report_stats (times , prefix = f"decoder_crop({ dims } )" )
187+
160188 print ()
161189
162190
0 commit comments