@@ -38,6 +38,14 @@ def __init__(self):
     def get_frames_from_video(self, video_file, pts_list):
         pass
 
+    @abc.abstractmethod
+    def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
+        pass
+
+    @abc.abstractmethod
+    def decode_and_transform(self, video_file, pts_list, height, width, device):
+        pass
+
 
 class DecordAccurate(AbstractDecoder):
     def __init__(self):
@@ -89,8 +97,10 @@ def __init__(self, backend):
         self._backend = backend
         self._print_each_iteration_time = False
         import torchvision  # noqa: F401
+        from torchvision.transforms import v2 as transforms_v2
 
         self.torchvision = torchvision
+        self.transforms_v2 = transforms_v2
 
     def get_frames_from_video(self, video_file, pts_list):
         self.torchvision.set_video_backend(self._backend)
@@ -111,6 +121,20 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
             frames.append(frame["data"].permute(1, 2, 0))
         return frames
 
+    def decode_and_transform(self, video_file, pts_list, height, width, device):
+        self.torchvision.set_video_backend(self._backend)
+        reader = self.torchvision.io.VideoReader(video_file, "video")
+        frames = []
+        for pts in pts_list:
+            reader.seek(pts)
+            frame = next(reader)
+            frames.append(frame["data"].permute(1, 2, 0))
+        frames = [
+            self.transforms_v2.functional.resize(frame.to(device), (height, width))
+            for frame in frames
+        ]
+        return frames
+
 
 class TorchCodecCore(AbstractDecoder):
     def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
@@ -239,6 +263,10 @@ def __init__(self, num_ffmpeg_threads=None, device="cpu"):
         )
         self._device = device
 
+        from torchvision.transforms import v2 as transforms_v2
+
+        self.transforms_v2 = transforms_v2
+
     def get_frames_from_video(self, video_file, pts_list):
         decoder = VideoDecoder(
             video_file, num_ffmpeg_threads=self._num_ffmpeg_threads, device=self._device
@@ -258,6 +286,14 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
                 break
         return frames
 
+    def decode_and_transform(self, video_file, pts_list, height, width, device):
+        decoder = VideoDecoder(
+            video_file, num_ffmpeg_threads=self._num_ffmpeg_threads, device=self._device
+        )
+        frames = decoder.get_frames_played_at(pts_list)
+        frames = self.transforms_v2.functional.resize(frames.data, (height, width))
+        return frames
+
 
 @torch.compile(fullgraph=True, backend="eager")
 def compiled_seek_and_next(decoder, pts):
@@ -299,7 +335,9 @@ def __init__(self):
 
         self.torchaudio = torchaudio
 
-        pass
+        from torchvision.transforms import v2 as transforms_v2
+
+        self.transforms_v2 = transforms_v2
 
     def get_frames_from_video(self, video_file, pts_list):
         stream_reader = self.torchaudio.io.StreamReader(src=video_file)
@@ -325,6 +363,21 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
 
         return frames
 
+    def decode_and_transform(self, video_file, pts_list, height, width, device):
+        stream_reader = self.torchaudio.io.StreamReader(src=video_file)
+        stream_reader.add_basic_video_stream(frames_per_chunk=1)
+        frames = []
+        for pts in pts_list:
+            stream_reader.seek(pts)
+            stream_reader.fill_buffer()
+            clip = stream_reader.pop_chunks()
+            frames.append(clip[0][0])
+        frames = [
+            self.transforms_v2.functional.resize(frame.to(device), (height, width))
+            for frame in frames
+        ]
+        return frames
+
 
 def create_torchcodec_decoder_from_file(video_file):
     video_decoder = create_from_file(video_file)
@@ -443,7 +496,7 @@ def plot_data(df_data, plot_path):
 
         # Set the title for the subplot
         base_video = Path(video).name.removesuffix(".mp4")
-        ax.set_title(f"{base_video}\n{vcount} x {vtype}", fontsize=11)
+        ax.set_title(f"{base_video}\n{vtype}", fontsize=11)
 
         # Plot bars with error bars
         ax.barh(
@@ -486,6 +539,14 @@ class BatchParameters:
     batch_size: int
 
 
+@dataclass
+class DataLoaderInspiredWorkloadParameters:
+    batch_parameters: BatchParameters
+    resize_height: int
+    resize_width: int
+    resize_device: str
+
+
 def run_batch_using_threads(
     function,
     *args,
@@ -525,6 +586,7 @@ def run_benchmarks(
     num_sequential_frames_from_start: list[int],
     min_runtime_seconds: float,
     benchmark_video_creation: bool,
+    dataloader_parameters: DataLoaderInspiredWorkloadParameters = None,
     batch_parameters: BatchParameters = None,
 ) -> list[dict[str, str | float | int]]:
     # Ensure that we have the same seed across benchmark runs.
@@ -550,6 +612,39 @@ def run_benchmarks(
         for decoder_name, decoder in decoder_dict.items():
             print(f"video={video_file_path}, decoder={decoder_name}")
 
+            if dataloader_parameters:
+                bp = dataloader_parameters.batch_parameters
+                dataloader_result = benchmark.Timer(
+                    stmt="run_batch_using_threads(decoder.decode_and_transform, video_file, pts_list, height, width, device, batch_parameters=batch_parameters)",
+                    globals={
+                        "video_file": str(video_file_path),
+                        "pts_list": uniform_pts_list,
+                        "decoder": decoder,
+                        "run_batch_using_threads": run_batch_using_threads,
+                        "batch_parameters": dataloader_parameters.batch_parameters,
+                        "height": dataloader_parameters.resize_height,
+                        "width": dataloader_parameters.resize_width,
+                        "device": dataloader_parameters.resize_device,
+                    },
+                    label=f"video={video_file_path} {metadata_label}",
+                    sub_label=decoder_name,
+                    description=f"dataloader[threads={bp.num_threads} batch_size={bp.batch_size}] {num_samples} decode_and_transform()",
+                )
+                results.append(
+                    dataloader_result.blocked_autorange(
+                        min_run_time=min_runtime_seconds
+                    )
+                )
+                df_data.append(
+                    convert_result_to_df_item(
+                        results[-1],
+                        decoder_name,
+                        video_file_path,
+                        num_samples * dataloader_parameters.batch_parameters.batch_size,
+                        f"dataloader[threads={bp.num_threads} batch_size={bp.batch_size}] {num_samples} x decode_and_transform()",
+                    )
+                )
+
             for kind, pts_list in [
                 ("uniform", uniform_pts_list),
                 ("random", random_pts_list),
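Usage sketch (not part of the diff above): this mirrors the Timer stmt by routing the new decode_and_transform through run_batch_using_threads. It assumes a decoder instance, video_file path, and pts_list already exist, that num_threads and batch_size are the only required BatchParameters fields, and that the numeric values are illustrative only.

    # Hypothetical wiring; values below are illustrative, not taken from the PR.
    workload = DataLoaderInspiredWorkloadParameters(
        batch_parameters=BatchParameters(num_threads=8, batch_size=40),
        resize_height=256,
        resize_width=256,
        resize_device="cpu",
    )
    # Decode the requested timestamps on a thread pool, then resize each frame
    # to 256x256 on the chosen device, as decode_and_transform does per batch.
    frames = run_batch_using_threads(
        decoder.decode_and_transform,
        video_file,
        pts_list,
        workload.resize_height,
        workload.resize_width,
        workload.resize_device,
        batch_parameters=workload.batch_parameters,
    )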