Skip to content

Commit 93fff37

Browse files
authored
More robust key frame index setting (#489)
1 parent eb88cd6 commit 93fff37

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,7 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
608608
// we have scanned all packets and sorted by pts.
609609
FrameInfo frameInfo = {packet->pts};
610610
if (packet->flags & AV_PKT_FLAG_KEY) {
611+
frameInfo.isKeyFrame = true;
611612
streamInfos_[streamIndex].keyFrames.push_back(frameInfo);
612613
}
613614
streamInfos_[streamIndex].allFrames.push_back(frameInfo);
@@ -658,25 +659,23 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
658659
return frameInfo1.pts < frameInfo2.pts;
659660
});
660661

661-
size_t keyIndex = 0;
662+
size_t keyFrameIndex = 0;
662663
for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) {
663664
streamInfo.allFrames[i].frameIndex = i;
664-
665-
// For correctly encoded files, we shouldn't need to ensure that keyIndex
666-
// is less than the number of key frames. That is, the relationship
667-
// between the frames in allFrames and keyFrames should be such that
668-
// keyIndex is always a valid index into keyFrames. But we're being
669-
// defensive in case we encounter incorrectly encoded files.
670-
if (keyIndex < streamInfo.keyFrames.size() &&
671-
streamInfo.keyFrames[keyIndex].pts == streamInfo.allFrames[i].pts) {
672-
streamInfo.keyFrames[keyIndex].frameIndex = i;
673-
++keyIndex;
665+
if (streamInfo.allFrames[i].isKeyFrame) {
666+
TORCH_CHECK(
667+
keyFrameIndex < streamInfo.keyFrames.size(),
668+
"The allFrames vec claims it has MORE keyFrames than the keyFrames vec. There's a bug in torchcodec.");
669+
streamInfo.keyFrames[keyFrameIndex].frameIndex = i;
670+
++keyFrameIndex;
674671
}
675-
676672
if (i + 1 < streamInfo.allFrames.size()) {
677673
streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts;
678674
}
679675
}
676+
TORCH_CHECK(
677+
keyFrameIndex == streamInfo.keyFrames.size(),
678+
"The allFrames vec claims it has LESS keyFrames than the keyFrames vec. There's a bug in torchcodec.");
680679
}
681680

682681
scannedAllStreams_ = true;

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,13 +294,20 @@ class VideoDecoder {
294294
// FrameInfo structs with *increasing* nextPts values. That's a necessary
295295
// condition for the binary searches on those values to work properly (as
296296
// typically done during pts -> index conversions).
297+
// TODO: This field is unset (left to the default) for entries in the
298+
// keyFrames vec!
297299
int64_t nextPts = INT64_MAX;
298300

299301
// Note that frameIndex is ALWAYS the index into all of the frames in that
300302
// stream, even when the FrameInfo is part of the key frame index. Given a
301303
// FrameInfo for a key frame, the frameIndex allows us to know which frame
302304
// that is in the stream.
303305
int64_t frameIndex = 0;
306+
307+
// Indicates whether a frame is a key frame. It may appear redundant as it's
308+
// only true for FrameInfos in the keyFrames index, but it is needed to
309+
// correctly map frames between allFrames and keyFrames during the scan.
310+
bool isKeyFrame = false;
304311
};
305312

306313
struct FilterGraphContext {

0 commit comments

Comments
 (0)