Commit 0574243

Fix transform benchmarks (#1118)
1 parent 877eb0c commit 0574243


benchmarks/decoders/benchmark_transforms.py

Lines changed: 83 additions & 55 deletions
@@ -5,14 +5,11 @@
 
 import torch
 from torch import Tensor
-from torchcodec._core import add_video_stream, create_from_file, get_frames_by_pts
 from torchcodec.decoders import VideoDecoder
 from torchvision.transforms import v2
 
-DEFAULT_NUM_EXP = 20
 
-
-def bench(f, *args, num_exp=DEFAULT_NUM_EXP, warmup=1) -> Tensor:
+def bench(f, *args, num_exp, warmup=1) -> Tensor:
 
     for _ in range(warmup):
         f(*args)
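
The body of bench is only partially shown in this hunk. A minimal sketch of a timing helper matching the new signature, assuming each run is timed individually and the samples are returned as a tensor for report_stats (the repository's actual body may differ):

import timeit

import torch
from torch import Tensor


def bench(f, *args, num_exp, warmup=1) -> Tensor:
    # Untimed warmup runs so one-time setup costs don't skew the samples.
    for _ in range(warmup):
        f(*args)
    times = []
    for _ in range(num_exp):
        start = timeit.default_timer()
        f(*args)
        times.append(timeit.default_timer() - start)
    # report_stats(times, unit="ms") suggests millisecond samples; converting
    # here is an assumption.
    return torch.tensor(times) * 1000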
@@ -45,37 +42,55 @@ def report_stats(times: Tensor, unit: str = "ms", prefix: str = "") -> float:
 
 
 def torchvision_resize(
-    path: Path, pts_seconds: list[float], dims: tuple[int, int]
-) -> None:
-    decoder = create_from_file(str(path), seek_mode="approximate")
-    add_video_stream(decoder)
-    raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds)
-    return v2.functional.resize(raw_frames, size=dims)
+    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
+) -> Tensor:
+    decoder = VideoDecoder(
+        path, seek_mode="approximate", num_ffmpeg_threads=num_threads
+    )
+    raw_frames = decoder.get_frames_played_at(pts_seconds)
+    transformed_frames = v2.Resize(size=dims)(raw_frames.data)
+    assert len(transformed_frames) == len(pts_seconds)
+    return transformed_frames
 
 
 def torchvision_crop(
-    path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int
-) -> None:
-    decoder = create_from_file(str(path), seek_mode="approximate")
-    add_video_stream(decoder)
-    raw_frames, *_ = get_frames_by_pts(decoder, timestamps=pts_seconds)
-    return v2.functional.crop(raw_frames, top=y, left=x, height=dims[0], width=dims[1])
-
-
-def decoder_native_resize(
-    path: Path, pts_seconds: list[float], dims: tuple[int, int]
-) -> None:
-    decoder = create_from_file(str(path), seek_mode="approximate")
-    add_video_stream(decoder, transform_specs=f"resize, {dims[0]}, {dims[1]}")
-    return get_frames_by_pts(decoder, timestamps=pts_seconds)[0]
-
-
-def decoder_native_crop(
-    path: Path, pts_seconds: list[float], dims: tuple[int, int], x: int, y: int
-) -> None:
-    decoder = create_from_file(str(path), seek_mode="approximate")
-    add_video_stream(decoder, transform_specs=f"crop, {dims[0]}, {dims[1]}, {x}, {y}")
-    return get_frames_by_pts(decoder, timestamps=pts_seconds)[0]
+    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
+) -> Tensor:
+    decoder = VideoDecoder(
+        path, seek_mode="approximate", num_ffmpeg_threads=num_threads
+    )
+    raw_frames = decoder.get_frames_played_at(pts_seconds)
+    transformed_frames = v2.CenterCrop(size=dims)(raw_frames.data)
+    assert len(transformed_frames) == len(pts_seconds)
+    return transformed_frames
+
+
+def decoder_resize(
+    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
+) -> Tensor:
+    decoder = VideoDecoder(
+        path,
+        transforms=[v2.Resize(size=dims)],
+        seek_mode="approximate",
+        num_ffmpeg_threads=num_threads,
+    )
+    transformed_frames = decoder.get_frames_played_at(pts_seconds).data
+    assert len(transformed_frames) == len(pts_seconds)
+    return transformed_frames
+
+
+def decoder_crop(
+    path: Path, pts_seconds: list[float], dims: tuple[int, int], num_threads: int
+) -> Tensor:
+    decoder = VideoDecoder(
+        path,
+        transforms=[v2.CenterCrop(size=dims)],
+        seek_mode="approximate",
+        num_ffmpeg_threads=num_threads,
+    )
+    transformed_frames = decoder.get_frames_played_at(pts_seconds).data
+    assert len(transformed_frames) == len(pts_seconds)
+    return transformed_frames
 
 
 def main():
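
The substantive change: both benchmark families now go through the public VideoDecoder API, and the decoder-native variants pass torchvision v2 transforms to the decoder instead of the old private transform_specs strings. A minimal side-by-side sketch of the two paths being compared, using only calls that appear in the diff (the path, timestamps, and dimensions are made up):

from torchcodec.decoders import VideoDecoder
from torchvision.transforms import v2

timestamps = [0.0, 1.0, 2.0]  # hypothetical seek targets, in seconds
dims = (270, 480)             # hypothetical (height, width) target

# Post-decode transform: decode full-size frames, then resize with torchvision.
decoder = VideoDecoder("video.mp4", seek_mode="approximate", num_ffmpeg_threads=1)
resized_late = v2.Resize(size=dims)(decoder.get_frames_played_at(timestamps).data)

# Decoder-native transform: the decoder applies the resize while decoding,
# so frames come back already at the target size.
decoder = VideoDecoder(
    "video.mp4",
    transforms=[v2.Resize(size=dims)],
    seek_mode="approximate",
    num_ffmpeg_threads=1,
)
resized_early = decoder.get_frames_played_at(timestamps).data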
@@ -84,9 +99,27 @@ def main():
     parser.add_argument(
         "--num-exp",
         type=int,
-        default=DEFAULT_NUM_EXP,
+        default=5,
         help="number of runs to average over",
     )
+    parser.add_argument(
+        "--num-threads",
+        type=int,
+        default=1,
+        help="number of threads to use; 0 means FFmpeg decides",
+    )
+    parser.add_argument(
+        "--total-frame-fractions",
+        nargs="+",
+        type=float,
+        default=[0.005, 0.01, 0.05, 0.1],
+    )
+    parser.add_argument(
+        "--input-dimension-fractions",
+        nargs="+",
+        type=float,
+        default=[0.5, 0.25, 0.125],
+    )
 
     args = parser.parse_args()
     path = Path(args.path)
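
With the new flags, a typical invocation might look like this (the video path argument is defined outside this hunk; --path is an assumption based on args.path):

python benchmarks/decoders/benchmark_transforms.py \
    --path video.mp4 \
    --num-exp 5 \
    --num-threads 1 \
    --total-frame-fractions 0.01 0.1 \
    --input-dimension-fractions 0.5 0.25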
@@ -100,10 +133,7 @@ def main():
 
     input_height = metadata.height
     input_width = metadata.width
-    fraction_of_total_frames_to_sample = [0.005, 0.01, 0.05, 0.1]
-    fraction_of_input_dimensions = [0.5, 0.25, 0.125]
-
-    for num_fraction in fraction_of_total_frames_to_sample:
+    for num_fraction in args.total_frame_fractions:
         num_frames_to_sample = math.ceil(metadata.num_frames * num_fraction)
         print(
             f"Sampling {num_fraction * 100}%, {num_frames_to_sample}, of {metadata.num_frames} frames"
@@ -112,51 +142,49 @@ def main():
             i * duration / num_frames_to_sample for i in range(num_frames_to_sample)
         ]
 
-        for dims_fraction in fraction_of_input_dimensions:
+        for dims_fraction in args.input_dimension_fractions:
             dims = (int(input_height * dims_fraction), int(input_width * dims_fraction))
 
             times = bench(
-                torchvision_resize, path, uniform_timestamps, dims, num_exp=args.num_exp
+                torchvision_resize,
+                path,
+                uniform_timestamps,
+                dims,
+                args.num_threads,
+                num_exp=args.num_exp,
             )
             report_stats(times, prefix=f"torchvision_resize({dims})")
 
             times = bench(
-                decoder_native_resize,
+                decoder_resize,
                 path,
                 uniform_timestamps,
                 dims,
+                args.num_threads,
                 num_exp=args.num_exp,
             )
-            report_stats(times, prefix=f"decoder_native_resize({dims})")
-            print()
+            report_stats(times, prefix=f"decoder_resize({dims})")
 
-            center_x = (input_height - dims[0]) // 2
-            center_y = (input_width - dims[1]) // 2
             times = bench(
                 torchvision_crop,
                 path,
                 uniform_timestamps,
                 dims,
-                center_x,
-                center_y,
+                args.num_threads,
                 num_exp=args.num_exp,
             )
-            report_stats(
-                times, prefix=f"torchvision_crop({dims}, {center_x}, {center_y})"
-            )
+            report_stats(times, prefix=f"torchvision_crop({dims})")
 
             times = bench(
-                decoder_native_crop,
+                decoder_crop,
                 path,
                 uniform_timestamps,
                 dims,
-                center_x,
-                center_y,
+                args.num_threads,
                 num_exp=args.num_exp,
             )
-            report_stats(
-                times, prefix=f"decoder_native_crop({dims}, {center_x}, {center_y})"
-            )
+            report_stats(times, prefix=f"decoder_crop({dims})")
+
             print()
 
 
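As a sanity check on the sampling math these loops rely on (standalone arithmetic; the video's frame count and duration are made up):

import math

num_frames, duration = 3000, 100.0  # hypothetical video
num_fraction = 0.01                 # sample 1% of frames
num_frames_to_sample = math.ceil(num_frames * num_fraction)  # 30
uniform_timestamps = [
    i * duration / num_frames_to_sample for i in range(num_frames_to_sample)
]
# 30 evenly spaced seek targets: 0.0, 3.33..., 6.67..., up to ~96.67 seconds.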