diff --git a/benchmarks/decoders/benchmark_transforms.py b/benchmarks/decoders/benchmark_transforms.py
index 75a49d63b..8342ef7f8 100644
--- a/benchmarks/decoders/benchmark_transforms.py
+++ b/benchmarks/decoders/benchmark_transforms.py
@@ -5,14 +5,11 @@
 import torch
 from torch import Tensor
 
-from torchcodec._core import add_video_stream, create_from_file, get_frames_by_pts
 from torchcodec.decoders import VideoDecoder
 from torchvision.transforms import v2
 
-DEFAULT_NUM_EXP = 20
-
 
-def bench(f, *args, num_exp=DEFAULT_NUM_EXP, warmup=1) -> Tensor:
+def bench(f, *args, num_exp, warmup=1) -> Tensor:
     for _ in range(warmup):
         f(*args)
 
@@ -45,37 +42,55 @@ def report_stats(times: Tensor, unit: str = "ms", prefix: str = "") -> float:
 
 
 def torchvision_resize(
-    path: Path, pts_seconds: list[float], dims: tuple[int, int]
-) -> None:
-    decoder = create_from_file(str(path), seek_mode="approximate")
-    add_video_stream(decoder)
-    raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds)
-    return v2.functional.resize(raw_frames, size=dims)
+    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
+) -> Tensor:
+    decoder = VideoDecoder(
+        path, seek_mode="approximate", num_ffmpeg_threads=num_threads
+    )
+    raw_frames = decoder.get_frames_played_at(pts_seconds)
+    transformed_frames = v2.Resize(size=dims)(raw_frames.data)
+    assert len(transformed_frames) == len(pts_seconds)
+    return transformed_frames
 
 
 def torchvision_crop(
-    path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int
-) -> None:
-    decoder = create_from_file(str(path), seek_mode="approximate")
-    add_video_stream(decoder)
-    raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds)
-    return v2.functional.crop(raw_frames, top=y, left=x, height=dims[0], width=dims[1])
-
-
-def decoder_native_resize(
-    path: Path, pts_seconds: list[float], dims: tuple[int, int]
-) -> None:
-    decoder = create_from_file(str(path), seek_mode="approximate")
-    add_video_stream(decoder, transform_specs=f"resize, {dims[0]}, {dims[1]}")
-    return get_frames_by_pts(decoder, timestamps=pts_seconds)[0]
-
-
-def decoder_native_crop(
-    path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int
-) -> None:
-    decoder = create_from_file(str(path), seek_mode="approximate")
-    add_video_stream(decoder, transform_specs=f"crop, {dims[0]}, {dims[1]}, {x}, {y}")
-    return get_frames_by_pts(decoder, timestamps=pts_seconds)[0]
+    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
+) -> Tensor:
+    decoder = VideoDecoder(
+        path, seek_mode="approximate", num_ffmpeg_threads=num_threads
+    )
+    raw_frames = decoder.get_frames_played_at(pts_seconds)
+    transformed_frames = v2.CenterCrop(size=dims)(raw_frames.data)
+    assert len(transformed_frames) == len(pts_seconds)
+    return transformed_frames
+
+
+def decoder_resize(
+    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
+) -> Tensor:
+    decoder = VideoDecoder(
+        path,
+        transforms=[v2.Resize(size=dims)],
+        seek_mode="approximate",
+        num_ffmpeg_threads=num_threads,
+    )
+    transformed_frames = decoder.get_frames_played_at(pts_seconds).data
+    assert len(transformed_frames) == len(pts_seconds)
+    return transformed_frames
+
+
+def decoder_crop(
+    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
+) -> Tensor:
+    decoder = VideoDecoder(
+        path,
+        transforms=[v2.CenterCrop(size=dims)],
+        seek_mode="approximate",
+        num_ffmpeg_threads=num_threads,
+    )
+    transformed_frames = decoder.get_frames_played_at(pts_seconds).data
+    assert len(transformed_frames) == len(pts_seconds)
+    return transformed_frames
 
 
 def main():
@@ -84,9 +99,27 @@ def main():
     parser.add_argument(
         "--num-exp",
         type=int,
-        default=DEFAULT_NUM_EXP,
+        default=5,
         help="number of runs to average over",
     )
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="number of threads to use; 0 means FFmpeg decides",
+    )
+    parser.add_argument(
+        "--total-frame-fractions",
+        nargs="+",
+        type=float,
+        default=[0.005, 0.01, 0.05, 0.1],
+    )
+    parser.add_argument(
+        "--input-dimension-fractions",
+        nargs="+",
+        type=float,
+        default=[0.5, 0.25, 0.125],
+    )
 
     args = parser.parse_args()
     path = Path(args.path)
@@ -100,10 +133,7 @@ def main():
     input_height = metadata.height
     input_width = metadata.width
 
-    fraction_of_total_frames_to_sample = [0.005, 0.01, 0.05, 0.1]
-    fraction_of_input_dimensions = [0.5, 0.25, 0.125]
-
-    for num_fraction in fraction_of_total_frames_to_sample:
+    for num_fraction in args.total_frame_fractions:
         num_frames_to_sample = math.ceil(metadata.num_frames * num_fraction)
         print(
             f"Sampling {num_fraction * 100}%, {num_frames_to_sample}, of {metadata.num_frames} frames"
@@ -112,51 +142,49 @@ def main():
             i * duration / num_frames_to_sample for i in range(num_frames_to_sample)
         ]
 
-        for dims_fraction in fraction_of_input_dimensions:
+        for dims_fraction in args.input_dimension_fractions:
             dims = (int(input_height * dims_fraction), int(input_width * dims_fraction))
 
             times = bench(
-                torchvision_resize, path, uniform_timestamps, dims, num_exp=args.num_exp
+                torchvision_resize,
+                path,
+                uniform_timestamps,
+                dims,
+                args.num_threads,
+                num_exp=args.num_exp,
             )
             report_stats(times, prefix=f"torchvision_resize({dims})")
 
             times = bench(
-                decoder_native_resize,
+                decoder_resize,
                 path,
                 uniform_timestamps,
                 dims,
+                args.num_threads,
                 num_exp=args.num_exp,
             )
-            report_stats(times, prefix=f"decoder_native_resize({dims})")
-            print()
+            report_stats(times, prefix=f"decoder_resize({dims})")
 
-            center_x = (input_height - dims[0]) // 2
-            center_y = (input_width - dims[1]) // 2
             times = bench(
                 torchvision_crop,
                 path,
                 uniform_timestamps,
                 dims,
-                center_x,
-                center_y,
+                args.num_threads,
                 num_exp=args.num_exp,
             )
-            report_stats(
-                times, prefix=f"torchvision_crop({dims}, {center_x}, {center_y})"
-            )
+            report_stats(times, prefix=f"torchvision_crop({dims})")
 
             times = bench(
-                decoder_native_crop,
+                decoder_crop,
                 path,
                 uniform_timestamps,
                 dims,
-                center_x,
-                center_y,
+                args.num_threads,
                 num_exp=args.num_exp,
            )
-            report_stats(
-                times, prefix=f"decoder_native_crop({dims}, {center_x}, {center_y})"
-            )
+            report_stats(times, prefix=f"decoder_crop({dims})")
+            print()
 
 
 if __name__ == "__main__":
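
Note: the two code paths this benchmark compares reduce to the short sketch below. It is a sketch, not part of the patch: it uses only APIs that appear in the diff above (VideoDecoder with seek_mode/num_ffmpeg_threads/transforms, get_frames_played_at, and torchvision.transforms.v2), while the file name, timestamps, and target size are made-up example values.

# Sketch of the two strategies benchmarked above. "video.mp4", the
# timestamps, and the (height, width) target are hypothetical examples.
from torchcodec.decoders import VideoDecoder
from torchvision.transforms import v2

pts = [0.0, 1.0, 2.0]  # example presentation timestamps, in seconds
size = (270, 480)      # example (height, width) output dimensions

# Strategy 1 (torchvision_resize): decode frames at full resolution,
# then resize the decoded batch with torchvision.
decoder = VideoDecoder("video.mp4", seek_mode="approximate", num_ffmpeg_threads=1)
resized_after = v2.Resize(size=size)(decoder.get_frames_played_at(pts).data)

# Strategy 2 (decoder_resize): pass the same transform to the decoder so
# frames are resized inside the decoding pipeline rather than afterwards.
decoder = VideoDecoder(
    "video.mp4",
    transforms=[v2.Resize(size=size)],
    seek_mode="approximate",
    num_ffmpeg_threads=1,
)
resized_during = decoder.get_frames_played_at(pts).data

# Both strategies yield a batch of the same shape; the benchmark times
# how long each takes end to end.
assert resized_after.shape == resized_during.shape

The difference being isolated is where the transform runs: strategy 2 never materializes full-resolution frame tensors on the Python side, which is the advantage the decoder_resize and decoder_crop timings are meant to quantify.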