@@ -78,7 +78,7 @@ def _validate_params(*, decoder, num_frames_per_clip, policy):
7878
7979def _validate_params_index_based (* , num_clips , num_indices_between_frames ):
8080 if num_clips <= 0 :
81- raise ValueError (f"num_clips ({ num_clips } ) must be strictly positive " )
81+ raise ValueError (f"num_clips ({ num_clips } ) must be > 0 " )
8282
8383 if num_indices_between_frames <= 0 :
8484 raise ValueError (
@@ -339,14 +339,24 @@ def clips_at_regular_indices(
339339def _validate_params_time_based (
340340 * ,
341341 decoder ,
342+ num_clips ,
342343 seconds_between_clip_starts ,
343344 seconds_between_frames ,
344345):
345- if seconds_between_clip_starts <= 0 :
346+
347+ if (num_clips is None and seconds_between_clip_starts is None ) or (
348+ num_clips is not None and seconds_between_clip_starts is not None
349+ ):
350+ raise ValueError ("This is internal only and should never happen." )
351+
352+ if seconds_between_clip_starts is not None and seconds_between_clip_starts <= 0 :
346353 raise ValueError (
347354 f"seconds_between_clip_starts ({ seconds_between_clip_starts } ) must be > 0"
348355 )
349356
357+ if num_clips is not None and num_clips <= 0 :
358+ raise ValueError (f"num_clips ({ num_clips } ) must be > 0" )
359+
350360 if decoder .metadata .average_fps is None :
351361 raise ValueError (
352362 "Could not infer average fps from video metadata. "
@@ -480,6 +490,13 @@ def _decode_all_clips_timestamps(
480490 and frame_pts_seconds == all_clips_timestamps_sorted [i - 1 ]
481491 ):
482492 # Avoid decoding the same frame twice.
493+ # Unfortunately this is unlikely to lead to speed-up as-is: it's
494+ # pretty unlikely that 2 pts will be the same since pts are float
495+ # continuous values. Theoretically the dedup can still happen, but
496+ # it would be much more efficient to implement it at the frame index
497+ # level. We should do that once we implement that in C++.
498+ # See also https://github.com/pytorch/torchcodec/issues/256.
499+ #
483500 # IMPORTANT: this is only correct because a copy of the frame will
484501 # happen within `_to_framebatch` when we call torch.stack.
485502 # If a copy isn't made, the same underlying memory will be used for
@@ -498,15 +515,17 @@ def _decode_all_clips_timestamps(
498515 return [_to_framebatch (clip ) for clip in all_clips ]
499516
500517
501- def clips_at_regular_timestamps (
518+ def _generic_time_based_sampler (
519+ kind : Literal ["random" , "regular" ],
502520 decoder ,
503521 * ,
504- seconds_between_clip_starts : float ,
505- num_frames_per_clip : int = 1 ,
506- seconds_between_frames : Optional [float ] = None ,
522+ num_clips : Optional [int ], # mutually exclusive with seconds_between_clip_starts
523+ seconds_between_clip_starts : Optional [float ],
524+ num_frames_per_clip : int ,
525+ seconds_between_frames : Optional [float ],
507526 # None means "beginning", which may not always be 0
508- sampling_range_start : Optional [float ] = None ,
509- sampling_range_end : Optional [float ] = None , # interval is [start, end).
527+ sampling_range_start : Optional [float ],
528+ sampling_range_end : Optional [float ], # interval is [start, end).
510529 policy : str = "repeat_last" ,
511530) -> List [FrameBatch ]:
512531 # Note: *everywhere*, sampling_range_end denotes the upper bound of where a
@@ -521,6 +540,7 @@ def clips_at_regular_timestamps(
521540
522541 seconds_between_frames = _validate_params_time_based (
523542 decoder = decoder ,
543+ num_clips = num_clips ,
524544 seconds_between_clip_starts = seconds_between_clip_starts ,
525545 seconds_between_frames = seconds_between_frames ,
526546 )
@@ -534,11 +554,21 @@ def clips_at_regular_timestamps(
534554 end_stream_seconds = decoder .metadata .end_stream_seconds ,
535555 )
536556
537- clip_start_seconds = torch .arange (
538- sampling_range_start ,
539- sampling_range_end , # excluded
540- seconds_between_clip_starts ,
541- )
557+ if kind == "random" :
558+ assert num_clips is not None # appease type-checker
559+ sampling_range_width = sampling_range_end - sampling_range_start
560+ # torch.rand() returns in [0, 1)
561+ # which ensures all clip starts are < sampling_range_end
562+ clip_start_seconds = (
563+ torch .rand (num_clips ) * sampling_range_width + sampling_range_start
564+ )
565+ else :
566+ assert seconds_between_clip_starts is not None # appease type-checker
567+ clip_start_seconds = torch .arange (
568+ sampling_range_start ,
569+ sampling_range_end , # excluded
570+ seconds_between_clip_starts ,
571+ )
542572
543573 all_clips_timestamps = _build_all_clips_timestamps (
544574 clip_start_seconds = clip_start_seconds ,
@@ -553,3 +583,51 @@ def clips_at_regular_timestamps(
553583 all_clips_timestamps = all_clips_timestamps ,
554584 num_frames_per_clip = num_frames_per_clip ,
555585 )
586+
587+
def clips_at_random_timestamps(
    decoder,
    *,
    num_clips: int = 1,
    num_frames_per_clip: int = 1,
    seconds_between_frames: Optional[float] = None,
    # None means "beginning", which may not always be 0
    sampling_range_start: Optional[float] = None,
    sampling_range_end: Optional[float] = None,  # interval is [start, end).
    policy: str = "repeat_last",
) -> List[FrameBatch]:
    """Sample ``num_clips`` clips whose start timestamps are drawn at random.

    Public entry point for the "random" flavour of the time-based sampler.
    All the work is delegated to ``_generic_time_based_sampler``; this
    wrapper only fixes ``kind="random"`` and disables the regular-spacing
    option, which is mutually exclusive with ``num_clips``.

    Args:
        decoder: The decoder to sample from; forwarded unchanged.
        num_clips: Number of clips to sample.
        num_frames_per_clip: Number of frames in each clip.
        seconds_between_frames: Spacing between frames of a clip, in
            seconds. ``None`` is forwarded as-is to the generic sampler.
        sampling_range_start: Lower bound of the range clip starts are
            sampled from. ``None`` means "beginning", which may not
            always be 0.
        sampling_range_end: Upper bound (excluded) of the range clip
            starts are sampled from.
        policy: Forwarded unchanged to the generic sampler.

    Returns:
        Whatever ``_generic_time_based_sampler`` returns (annotated as a
        list of ``FrameBatch``).
    """
    # Everything except `kind` and `decoder` is keyword-only in the generic
    # sampler, so gather those options in one place before delegating.
    sampler_options = dict(
        num_clips=num_clips,
        # Random sampling is driven by num_clips only.
        seconds_between_clip_starts=None,
        num_frames_per_clip=num_frames_per_clip,
        seconds_between_frames=seconds_between_frames,
        sampling_range_start=sampling_range_start,
        sampling_range_end=sampling_range_end,
        policy=policy,
    )
    return _generic_time_based_sampler("random", decoder, **sampler_options)
610+
611+
def clips_at_regular_timestamps(
    decoder,
    *,
    seconds_between_clip_starts: float,
    num_frames_per_clip: int = 1,
    seconds_between_frames: Optional[float] = None,
    # None means "beginning", which may not always be 0
    sampling_range_start: Optional[float] = None,
    sampling_range_end: Optional[float] = None,  # interval is [start, end).
    policy: str = "repeat_last",
) -> List[FrameBatch]:
    """Sample clips whose start timestamps are regularly spaced.

    Public entry point for the "regular" flavour of the time-based sampler.
    All the work is delegated to ``_generic_time_based_sampler``; this
    wrapper only fixes ``kind="regular"`` and passes ``num_clips=None``,
    since the number of clips is mutually exclusive with
    ``seconds_between_clip_starts``.

    Args:
        decoder: The decoder to sample from; forwarded unchanged.
        seconds_between_clip_starts: Spacing, in seconds, between two
            consecutive clip starts.
        num_frames_per_clip: Number of frames in each clip.
        seconds_between_frames: Spacing between frames of a clip, in
            seconds. ``None`` is forwarded as-is to the generic sampler.
        sampling_range_start: Lower bound of the range clip starts are
            taken from. ``None`` means "beginning", which may not always
            be 0.
        sampling_range_end: Upper bound (excluded) of the range clip
            starts are taken from.
        policy: Forwarded unchanged to the generic sampler.

    Returns:
        Whatever ``_generic_time_based_sampler`` returns (annotated as a
        list of ``FrameBatch``).
    """
    # Everything except `kind` and `decoder` is keyword-only in the generic
    # sampler, so gather those options in one place before delegating.
    sampler_options = dict(
        # Regular sampling is driven by the spacing, not by a clip count.
        num_clips=None,
        seconds_between_clip_starts=seconds_between_clip_starts,
        num_frames_per_clip=num_frames_per_clip,
        seconds_between_frames=seconds_between_frames,
        sampling_range_start=sampling_range_start,
        sampling_range_end=sampling_range_end,
        policy=policy,
    )
    return _generic_time_based_sampler("regular", decoder, **sampler_options)
0 commit comments