Skip to content

Commit a713765

Browse files
authored
Fix audio seeks bugs (#599)
1 parent 57899ee commit a713765

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -890,16 +890,15 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
890890
std::optional<double> stopSecondsOptional) {
891891
validateActiveStream(AVMEDIA_TYPE_AUDIO);
892892

893-
double stopSeconds =
894-
stopSecondsOptional.value_or(std::numeric_limits<double>::max());
895-
896-
TORCH_CHECK(
897-
startSeconds <= stopSeconds,
898-
"Start seconds (" + std::to_string(startSeconds) +
899-
") must be less than or equal to stop seconds (" +
900-
std::to_string(stopSeconds) + ").");
893+
if (stopSecondsOptional.has_value()) {
894+
TORCH_CHECK(
895+
startSeconds <= *stopSecondsOptional,
896+
"Start seconds (" + std::to_string(startSeconds) +
897+
") must be less than or equal to stop seconds (" +
898+
std::to_string(*stopSecondsOptional) + ").");
899+
}
901900

902-
if (startSeconds == stopSeconds) {
901+
if (stopSecondsOptional.has_value() && startSeconds == *stopSecondsOptional) {
903902
// For consistency with video
904903
return AudioFramesOutput{torch::empty({0, 0}), 0.0};
905904
}
@@ -912,7 +911,7 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
912911
// If we need to seek backwards, then we have to seek back to the beginning
913912
// of the stream.
914913
// See [Audio Decoding Design].
915-
setCursorPtsInSecondsInternal(INT64_MIN);
914+
setCursor(INT64_MIN);
916915
}
917916

918917
// TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec +
@@ -921,7 +920,9 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
921920
std::vector<torch::Tensor> frames;
922921

923922
std::optional<double> firstFramePtsSeconds = std::nullopt;
924-
auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase);
923+
auto stopPts = stopSecondsOptional.has_value()
924+
? secondsToClosestPts(*stopSecondsOptional, streamInfo.timeBase)
925+
: INT64_MAX;
925926
auto finished = false;
926927
while (!finished) {
927928
try {
@@ -971,13 +972,13 @@ void VideoDecoder::setCursorPtsInSeconds(double seconds) {
971972
// We don't allow public audio decoding APIs to seek, see [Audio Decoding
972973
// Design]
973974
validateActiveStream(AVMEDIA_TYPE_VIDEO);
974-
setCursorPtsInSecondsInternal(seconds);
975+
setCursor(
976+
secondsToClosestPts(seconds, streamInfos_[activeStreamIndex_].timeBase));
975977
}
976978

977-
void VideoDecoder::setCursorPtsInSecondsInternal(double seconds) {
979+
void VideoDecoder::setCursor(int64_t pts) {
978980
cursorWasJustSet_ = true;
979-
cursor_ =
980-
secondsToClosestPts(seconds, streamInfos_[activeStreamIndex_].timeBase);
981+
cursor_ = pts;
981982
}
982983

983984
/*

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ class VideoDecoder {
367367
// DECODING APIS AND RELATED UTILS
368368
// --------------------------------------------------------------------------
369369

370-
void setCursorPtsInSecondsInternal(double seconds);
370+
void setCursor(int64_t pts);
371371
bool canWeAvoidSeeking() const;
372372

373373
void maybeSeekToBeforeDesiredPts();

0 commit comments

Comments
 (0)