Skip to content

Commit cff9492

Browse files
authored
Add Random time-based sampler (#255)
1 parent 6cec00c commit cff9492

File tree

4 files changed

+218
-59
lines changed

4 files changed

+218
-59
lines changed

benchmarks/samplers/benchmark_samplers.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33

44
import torch
55
from torchcodec.decoders import VideoDecoder
6-
from torchcodec.samplers import clips_at_random_indices
6+
from torchcodec.samplers import (
7+
clips_at_random_indices,
8+
clips_at_random_timestamps,
9+
clips_at_regular_indices,
10+
clips_at_regular_timestamps,
11+
)
712

813

914
def bench(f, *args, num_exp=100, warmup=0, **kwargs):
@@ -34,19 +39,51 @@ def report_stats(times, unit="ms"):
3439
return med
3540

3641

37-
def sample(num_clips):
42+
def sample(sampler, **kwargs):
3843
decoder = VideoDecoder(VIDEO_PATH)
39-
clips_at_random_indices(
44+
sampler(
4045
decoder,
41-
num_clips=num_clips,
4246
num_frames_per_clip=10,
43-
num_indices_between_frames=2,
47+
**kwargs,
4448
)
4549

4650

4751
VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
52+
NUM_EXP = 30
53+
54+
for num_clips in (1, 50):
55+
print("-" * 10)
56+
print(f"{num_clips = }")
57+
58+
print("clips_at_random_indices ", end="")
59+
times = bench(
60+
sample, clips_at_random_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
61+
)
62+
report_stats(times, unit="ms")
63+
64+
print("clips_at_regular_indices ", end="")
65+
times = bench(
66+
sample, clips_at_regular_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
67+
)
68+
report_stats(times, unit="ms")
4869

49-
times = bench(sample, num_clips=1, num_exp=30, warmup=2)
50-
report_stats(times, unit="ms")
51-
times = bench(sample, num_clips=50, num_exp=30, warmup=2)
52-
report_stats(times, unit="ms")
70+
print("clips_at_random_timestamps ", end="")
71+
times = bench(
72+
sample,
73+
clips_at_random_timestamps,
74+
num_clips=num_clips,
75+
num_exp=NUM_EXP,
76+
warmup=2,
77+
)
78+
report_stats(times, unit="ms")
79+
80+
print("clips_at_regular_timestamps ", end="")
81+
seconds_between_clip_starts = 13 / num_clips # approximate. video is 13s long
82+
times = bench(
83+
sample,
84+
clips_at_regular_timestamps,
85+
seconds_between_clip_starts=seconds_between_clip_starts,
86+
num_exp=NUM_EXP,
87+
warmup=2,
88+
)
89+
report_stats(times, unit="ms")
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from ._implem import (
22
clips_at_random_indices,
3+
clips_at_random_timestamps,
34
clips_at_regular_indices,
45
clips_at_regular_timestamps,
56
)

src/torchcodec/samplers/_implem.py

Lines changed: 91 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _validate_params(*, decoder, num_frames_per_clip, policy):
7878

7979
def _validate_params_index_based(*, num_clips, num_indices_between_frames):
8080
if num_clips <= 0:
81-
raise ValueError(f"num_clips ({num_clips}) must be strictly positive")
81+
raise ValueError(f"num_clips ({num_clips}) must be > 0")
8282

8383
if num_indices_between_frames <= 0:
8484
raise ValueError(
@@ -339,14 +339,24 @@ def clips_at_regular_indices(
339339
def _validate_params_time_based(
340340
*,
341341
decoder,
342+
num_clips,
342343
seconds_between_clip_starts,
343344
seconds_between_frames,
344345
):
345-
if seconds_between_clip_starts <= 0:
346+
347+
if (num_clips is None and seconds_between_clip_starts is None) or (
348+
num_clips is not None and seconds_between_clip_starts is not None
349+
):
350+
raise ValueError("This is internal only and should never happen.")
351+
352+
if seconds_between_clip_starts is not None and seconds_between_clip_starts <= 0:
346353
raise ValueError(
347354
f"seconds_between_clip_starts ({seconds_between_clip_starts}) must be > 0"
348355
)
349356

357+
if num_clips is not None and num_clips <= 0:
358+
raise ValueError(f"num_clips ({num_clips}) must be > 0")
359+
350360
if decoder.metadata.average_fps is None:
351361
raise ValueError(
352362
"Could not infer average fps from video metadata. "
@@ -480,6 +490,13 @@ def _decode_all_clips_timestamps(
480490
and frame_pts_seconds == all_clips_timestamps_sorted[i - 1]
481491
):
482492
# Avoid decoding the same frame twice.
493+
# Unfortunatly this is unlikely to lead to speed-up as-is: it's
494+
# pretty unlikely that 2 pts will be the same since pts are float
495+
# contiguous values. Theoretically the dedup can still happen, but
496+
# it would be much more efficient to implement it at the frame index
497+
# level. We should do that once we implement that in C++.
498+
# See also https://github.com/pytorch/torchcodec/issues/256.
499+
#
483500
# IMPORTANT: this is only correct because a copy of the frame will
484501
# happen within `_to_framebatch` when we call torch.stack.
485502
# If a copy isn't made, the same underlying memory will be used for
@@ -498,15 +515,17 @@ def _decode_all_clips_timestamps(
498515
return [_to_framebatch(clip) for clip in all_clips]
499516

500517

501-
def clips_at_regular_timestamps(
518+
def _generic_time_based_sampler(
519+
kind: Literal["random", "regular"],
502520
decoder,
503521
*,
504-
seconds_between_clip_starts: float,
505-
num_frames_per_clip: int = 1,
506-
seconds_between_frames: Optional[float] = None,
522+
num_clips: Optional[int], # mutually exclusive with seconds_between_clip_starts
523+
seconds_between_clip_starts: Optional[float],
524+
num_frames_per_clip: int,
525+
seconds_between_frames: Optional[float],
507526
# None means "begining", which may not always be 0
508-
sampling_range_start: Optional[float] = None,
509-
sampling_range_end: Optional[float] = None, # interval is [start, end).
527+
sampling_range_start: Optional[float],
528+
sampling_range_end: Optional[float], # interval is [start, end).
510529
policy: str = "repeat_last",
511530
) -> List[FrameBatch]:
512531
# Note: *everywhere*, sampling_range_end denotes the upper bound of where a
@@ -521,6 +540,7 @@ def clips_at_regular_timestamps(
521540

522541
seconds_between_frames = _validate_params_time_based(
523542
decoder=decoder,
543+
num_clips=num_clips,
524544
seconds_between_clip_starts=seconds_between_clip_starts,
525545
seconds_between_frames=seconds_between_frames,
526546
)
@@ -534,11 +554,21 @@ def clips_at_regular_timestamps(
534554
end_stream_seconds=decoder.metadata.end_stream_seconds,
535555
)
536556

537-
clip_start_seconds = torch.arange(
538-
sampling_range_start,
539-
sampling_range_end, # excluded
540-
seconds_between_clip_starts,
541-
)
557+
if kind == "random":
558+
assert num_clips is not None # appease type-checker
559+
sampling_range_width = sampling_range_end - sampling_range_start
560+
# torch.rand() returns in [0, 1)
561+
# which ensures all clip starts are < sampling_range_end
562+
clip_start_seconds = (
563+
torch.rand(num_clips) * sampling_range_width + sampling_range_start
564+
)
565+
else:
566+
assert seconds_between_clip_starts is not None # appease type-checker
567+
clip_start_seconds = torch.arange(
568+
sampling_range_start,
569+
sampling_range_end, # excluded
570+
seconds_between_clip_starts,
571+
)
542572

543573
all_clips_timestamps = _build_all_clips_timestamps(
544574
clip_start_seconds=clip_start_seconds,
@@ -553,3 +583,51 @@ def clips_at_regular_timestamps(
553583
all_clips_timestamps=all_clips_timestamps,
554584
num_frames_per_clip=num_frames_per_clip,
555585
)
586+
587+
588+
def clips_at_random_timestamps(
589+
decoder,
590+
*,
591+
num_clips: int = 1,
592+
num_frames_per_clip: int = 1,
593+
seconds_between_frames: Optional[float] = None,
594+
# None means "begining", which may not always be 0
595+
sampling_range_start: Optional[float] = None,
596+
sampling_range_end: Optional[float] = None, # interval is [start, end).
597+
policy: str = "repeat_last",
598+
) -> List[FrameBatch]:
599+
return _generic_time_based_sampler(
600+
kind="random",
601+
decoder=decoder,
602+
num_clips=num_clips,
603+
seconds_between_clip_starts=None,
604+
num_frames_per_clip=num_frames_per_clip,
605+
seconds_between_frames=seconds_between_frames,
606+
sampling_range_start=sampling_range_start,
607+
sampling_range_end=sampling_range_end,
608+
policy=policy,
609+
)
610+
611+
612+
def clips_at_regular_timestamps(
613+
decoder,
614+
*,
615+
seconds_between_clip_starts: float,
616+
num_frames_per_clip: int = 1,
617+
seconds_between_frames: Optional[float] = None,
618+
# None means "begining", which may not always be 0
619+
sampling_range_start: Optional[float] = None,
620+
sampling_range_end: Optional[float] = None, # interval is [start, end).
621+
policy: str = "repeat_last",
622+
) -> List[FrameBatch]:
623+
return _generic_time_based_sampler(
624+
kind="regular",
625+
decoder=decoder,
626+
num_clips=None,
627+
seconds_between_clip_starts=seconds_between_clip_starts,
628+
num_frames_per_clip=num_frames_per_clip,
629+
seconds_between_frames=seconds_between_frames,
630+
sampling_range_start=sampling_range_start,
631+
sampling_range_end=sampling_range_end,
632+
policy=policy,
633+
)

0 commit comments

Comments
 (0)