Merged
138 changes: 83 additions & 55 deletions benchmarks/decoders/benchmark_transforms.py
@@ -5,14 +5,11 @@

import torch
from torch import Tensor
-from torchcodec._core import add_video_stream, create_from_file, get_frames_by_pts
+from torchcodec.decoders import VideoDecoder
from torchvision.transforms import v2

-DEFAULT_NUM_EXP = 20


-def bench(f, *args, num_exp=DEFAULT_NUM_EXP, warmup=1) -> Tensor:
+def bench(f, *args, num_exp, warmup=1) -> Tensor:
    for _ in range(warmup):
        f(*args)
@@ -45,37 +42,55 @@ def report_stats(times: Tensor, unit: str = "ms", prefix: str = "") -> float:


def torchvision_resize(
-    path: Path, pts_seconds: list[float], dims: tuple[int, int]
-) -> None:
-    decoder = create_from_file(str(path), seek_mode="approximate")
-    add_video_stream(decoder)
-    raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds)
-    return v2.functional.resize(raw_frames, size=dims)
+    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
+) -> Tensor:
+    decoder = VideoDecoder(
+        path, seek_mode="approximate", num_ffmpeg_threads=num_threads
+    )
+    raw_frames = decoder.get_frames_played_at(pts_seconds)
+    transformed_frames = v2.Resize(size=dims)(raw_frames.data)
+    assert len(transformed_frames) == len(pts_seconds)
+    return transformed_frames


def torchvision_crop(
-    path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int
-) -> None:
-    decoder = create_from_file(str(path), seek_mode="approximate")
-    add_video_stream(decoder)
-    raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds)
-    return v2.functional.crop(raw_frames, top=y, left=x, height=dims[0], width=dims[1])
-
-
-def decoder_native_resize(
-    path: Path, pts_seconds: list[float], dims: tuple[int, int]
-) -> None:
-    decoder = create_from_file(str(path), seek_mode="approximate")
-    add_video_stream(decoder, transform_specs=f"resize, {dims[0]}, {dims[1]}")
-    return get_frames_by_pts(decoder, timestamps=pts_seconds)[0]
-
-
-def decoder_native_crop(
-    path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int
-) -> None:
-    decoder = create_from_file(str(path), seek_mode="approximate")
-    add_video_stream(decoder, transform_specs=f"crop, {dims[0]}, {dims[1]}, {x}, {y}")
-    return get_frames_by_pts(decoder, timestamps=pts_seconds)[0]
+    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
+) -> Tensor:
+    decoder = VideoDecoder(
+        path, seek_mode="approximate", num_ffmpeg_threads=num_threads
+    )
+    raw_frames = decoder.get_frames_played_at(pts_seconds)
+    transformed_frames = v2.CenterCrop(size=dims)(raw_frames.data)
+    assert len(transformed_frames) == len(pts_seconds)
+    return transformed_frames


+def decoder_resize(
+    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
+) -> Tensor:
+    decoder = VideoDecoder(
+        path,
+        transforms=[v2.Resize(size=dims)],
+        seek_mode="approximate",
+        num_ffmpeg_threads=num_threads,
+    )
+    transformed_frames = decoder.get_frames_played_at(pts_seconds).data
+    assert len(transformed_frames) == len(pts_seconds)
+    return transformed_frames
+
+
+def decoder_crop(
+    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
+) -> Tensor:
+    decoder = VideoDecoder(
+        path,
+        transforms=[v2.CenterCrop(size=dims)],
+        seek_mode="approximate",
+        num_ffmpeg_threads=num_threads,
+    )
+    transformed_frames = decoder.get_frames_played_at(pts_seconds).data
+    assert len(transformed_frames) == len(pts_seconds)
+    return transformed_frames


def main():
@@ -84,9 +99,27 @@ def main():
    parser.add_argument(
        "--num-exp",
        type=int,
-        default=DEFAULT_NUM_EXP,
+        default=5,
        help="number of runs to average over",
    )
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="number of threads to use; 0 means FFmpeg decides",
+    )
Comment on lines +105 to +110

Contributor:
We might want to also call torch.set_num_threads(args.num_threads) when num_threads != 0? In the current conditions of the benchmark, I think torch's resize isn't multi-threaded, so this should have no effect. But there are code paths where it is multithreaded over the batch dimension, depending on the input dtype and the interpolation mode (example: https://github.com/pytorch/pytorch/blame/afb173d9b9440d804b5f77d0c291e53c720d1fcf/aten/src/ATen/native/cpu/UpSampleKernel.cpp#L2024C18-L2024C18).

Contributor Author:
If it were as clean as just providing this value to the Torch APIs, I think we should just do it now. But because of the weirdness with 0 (I don't think torch.set_num_threads() has an equivalent for automatic deciding), I think we may want to control it with another flag or do some logic (n_cpus // 2). For those reasons, I'd rather punt on that until we want to get numbers for those scenarios.
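A minimal sketch of the fallback the author floats above; configure_torch_threads is a hypothetical helper, not part of this PR, and the n_cpus // 2 heuristic is only the option mentioned in the reply:

import os

import torch


def configure_torch_threads(num_threads: int) -> None:
    # torch.set_num_threads() has no sentinel meaning "decide automatically",
    # so when the benchmark's 0 ("FFmpeg decides") value comes in, fall back
    # to the n_cpus // 2 heuristic suggested in the reply above.
    if num_threads != 0:
        torch.set_num_threads(num_threads)
    else:
        torch.set_num_threads(max(1, (os.cpu_count() or 2) // 2))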

+    parser.add_argument(
+        "--total-frame-fractions",
+        nargs="+",
+        type=float,
+        default=[0.005, 0.01, 0.05, 0.1],
+    )
+    parser.add_argument(
+        "--input-dimension-fractions",
+        nargs="+",
+        type=float,
+        default=[0.5, 0.25, 0.125],
+    )

    args = parser.parse_args()
    path = Path(args.path)
@@ -100,10 +133,7 @@ def main():

    input_height = metadata.height
    input_width = metadata.width
-    fraction_of_total_frames_to_sample = [0.005, 0.01, 0.05, 0.1]
-    fraction_of_input_dimensions = [0.5, 0.25, 0.125]

-    for num_fraction in fraction_of_total_frames_to_sample:
+    for num_fraction in args.total_frame_fractions:
        num_frames_to_sample = math.ceil(metadata.num_frames * num_fraction)
        print(
            f"Sampling {num_fraction * 100}%, {num_frames_to_sample}, of {metadata.num_frames} frames"
@@ -112,51 +142,49 @@ def main():
            i * duration / num_frames_to_sample for i in range(num_frames_to_sample)
        ]

-        for dims_fraction in fraction_of_input_dimensions:
+        for dims_fraction in args.input_dimension_fractions:
            dims = (int(input_height * dims_fraction), int(input_width * dims_fraction))

            times = bench(
-                torchvision_resize, path, uniform_timestamps, dims, num_exp=args.num_exp
+                torchvision_resize,
+                path,
+                uniform_timestamps,
+                dims,
+                args.num_threads,
+                num_exp=args.num_exp,
            )
            report_stats(times, prefix=f"torchvision_resize({dims})")

            times = bench(
-                decoder_native_resize,
+                decoder_resize,
                path,
                uniform_timestamps,
                dims,
+                args.num_threads,
                num_exp=args.num_exp,
            )
-            report_stats(times, prefix=f"decoder_native_resize({dims})")
-            print()
+            report_stats(times, prefix=f"decoder_resize({dims})")

-            center_x = (input_height - dims[0]) // 2
-            center_y = (input_width - dims[1]) // 2
            times = bench(
                torchvision_crop,
                path,
                uniform_timestamps,
                dims,
-                center_x,
-                center_y,
+                args.num_threads,
                num_exp=args.num_exp,
            )
-            report_stats(
-                times, prefix=f"torchvision_crop({dims}, {center_x}, {center_y})"
-            )
+            report_stats(times, prefix=f"torchvision_crop({dims})")

            times = bench(
-                decoder_native_crop,
+                decoder_crop,
                path,
                uniform_timestamps,
                dims,
-                center_x,
-                center_y,
+                args.num_threads,
                num_exp=args.num_exp,
            )
-            report_stats(
-                times, prefix=f"decoder_native_crop({dims}, {center_x}, {center_y})"
-            )
+            report_stats(times, prefix=f"decoder_crop({dims})")

        print()
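For reference, a possible invocation of the updated benchmark; the positional video-path argument is inferred from args.path (its add_argument call is collapsed above), so its exact spelling is an assumption:

python benchmarks/decoders/benchmark_transforms.py video.mp4 --num-exp 5 --num-threads 1 --total-frame-fractions 0.005 0.01 --input-dimension-fractions 0.5 0.25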

