Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/torchcodec/_core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,23 @@ function(make_torchcodec_libraries
# stray initialization of py::objects. The rest of the object code must
# match. See:
# https://pybind11.readthedocs.io/en/stable/faq.html#someclass-declared-with-greater-visibility-than-the-type-of-its-field-someclass-member-wattributes
if(NOT WIN32)
# We have to do this for both pybind_ops and custom_ops because they include
# some of the same headers.
#
# Note that this is Linux only. It's not necessary on Windows, and on Mac
# hiding visibility can actually break dynamic casts across shared libraries
# because the type infos don't get exported.
if(LINUX)
target_compile_options(
${pybind_ops_library_name}
PUBLIC
"-fvisibility=hidden"
)
target_compile_options(
${custom_ops_library_name}
PUBLIC
"-fvisibility=hidden"
)
endif()

# The value we use here must match the value we return from
Expand Down
12 changes: 0 additions & 12 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1704,16 +1704,4 @@ FrameDims getHeightAndWidthFromOptionsOrAVFrame(
videoStreamOptions.width.value_or(avFrame->width));
}

SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
if (seekMode == "exact") {
return SingleStreamDecoder::SeekMode::exact;
} else if (seekMode == "approximate") {
return SingleStreamDecoder::SeekMode::approximate;
} else if (seekMode == "custom_frame_mappings") {
return SingleStreamDecoder::SeekMode::custom_frame_mappings;
} else {
TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode));
}
}

} // namespace facebook::torchcodec
2 changes: 0 additions & 2 deletions src/torchcodec/_core/SingleStreamDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,4 @@ std::ostream& operator<<(
std::ostream& os,
const SingleStreamDecoder::DecodeStats& stats);

SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode);

} // namespace facebook::torchcodec
77 changes: 70 additions & 7 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <string>
#include "c10/core/SymIntArrayRef.h"
#include "c10/util/Exception.h"
#include "src/torchcodec/_core/AVIOFileLikeContext.h"
#include "src/torchcodec/_core/AVIOTensorContext.h"
#include "src/torchcodec/_core/Encoder.h"
#include "src/torchcodec/_core/SingleStreamDecoder.h"
Expand All @@ -35,9 +36,12 @@ TORCH_LIBRARY(torchcodec_ns, m) {
"encode_video_to_file(Tensor frames, int frame_rate, str filename) -> ()");
m.def(
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
m.def(
"_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
m.def(
"create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
m.def(
"_create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor");
m.def(
"_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()");
m.def(
Expand Down Expand Up @@ -167,6 +171,18 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
return ss.str();
}

// Translates the Python-facing seek-mode string into the decoder's SeekMode
// enum. Raises (via TORCH_CHECK) on any unrecognized value.
SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
  if (seekMode == "exact") {
    return SingleStreamDecoder::SeekMode::exact;
  }
  if (seekMode == "approximate") {
    return SingleStreamDecoder::SeekMode::approximate;
  }
  if (seekMode == "custom_frame_mappings") {
    return SingleStreamDecoder::SeekMode::custom_frame_mappings;
  }
  TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode));
}

} // namespace

// ==============================
Expand Down Expand Up @@ -205,16 +221,32 @@ at::Tensor create_from_tensor(
realSeek = seekModeFromString(seek_mode.value());
}

auto contextHolder = std::make_unique<AVIOFromTensorContext>(video_tensor);
auto avioContextHolder =
std::make_unique<AVIOFromTensorContext>(video_tensor);

std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
std::make_unique<SingleStreamDecoder>(std::move(contextHolder), realSeek);
std::make_unique<SingleStreamDecoder>(
std::move(avioContextHolder), realSeek);
return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
}

at::Tensor _convert_to_tensor(int64_t decoder_ptr) {
auto decoder = reinterpret_cast<SingleStreamDecoder*>(decoder_ptr);
std::unique_ptr<SingleStreamDecoder> uniqueDecoder(decoder);
// Builds a SingleStreamDecoder from an AVIOFileLikeContext that was
// heap-allocated on the Python side; its address is smuggled through the
// dispatcher as an int64. Ownership of the context transfers to C++ here.
at::Tensor _create_from_file_like(
    int64_t file_like_context,
    std::optional<std::string_view> seek_mode) {
  auto* fileLikeContext =
      reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
  TORCH_CHECK(
      fileLikeContext != nullptr, "file_like_context must be a valid pointer");
  // Take ownership only after validating the pointer.
  std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);

  // Default to exact seeking when no mode is supplied.
  auto realSeek = seek_mode.has_value()
      ? seekModeFromString(seek_mode.value())
      : SingleStreamDecoder::SeekMode::exact;

  auto uniqueDecoder = std::make_unique<SingleStreamDecoder>(
      std::move(avioContextHolder), realSeek);
  return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
}

Expand Down Expand Up @@ -456,6 +488,36 @@ at::Tensor encode_audio_to_tensor(
.encodeToTensor();
}

// Encodes float32 audio samples through an AVIOFileLikeContext created on the
// Python side (address passed as an int64). Ownership of the context
// transfers to the encoder via unique_ptr.
void _encode_audio_to_file_like(
    const at::Tensor& samples,
    int64_t sample_rate,
    std::string_view format,
    int64_t file_like_context,
    std::optional<int64_t> bit_rate = std::nullopt,
    std::optional<int64_t> num_channels = std::nullopt,
    std::optional<int64_t> desired_sample_rate = std::nullopt) {
  auto* avioContext =
      reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
  TORCH_CHECK(
      avioContext != nullptr, "file_like_context must be a valid pointer");
  // Take ownership only after validating the pointer.
  std::unique_ptr<AVIOFileLikeContext> avioContextHolder(avioContext);

  // Narrow the int64 operator arguments down to the int-sized fields that
  // AudioStreamOptions expects, with range validation.
  AudioStreamOptions audioStreamOptions;
  audioStreamOptions.bitRate = validateOptionalInt64ToInt(bit_rate, "bit_rate");
  audioStreamOptions.numChannels =
      validateOptionalInt64ToInt(num_channels, "num_channels");
  audioStreamOptions.sampleRate =
      validateOptionalInt64ToInt(desired_sample_rate, "desired_sample_rate");

  AudioEncoder encoder(
      samples,
      validateInt64ToInt(sample_rate, "sample_rate"),
      format,
      std::move(avioContextHolder),
      audioStreamOptions);
  encoder.encode();
}

// For testing only. We need to implement this operation as a core library
// function because what we're testing is round-tripping pts values as
// double-precision floating point numbers from C++ to Python and back to C++.
Expand Down Expand Up @@ -709,7 +771,7 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
m.impl("create_from_file", &create_from_file);
m.impl("create_from_tensor", &create_from_tensor);
m.impl("_convert_to_tensor", &_convert_to_tensor);
m.impl("_create_from_file_like", &_create_from_file_like);
m.impl(
"_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions);
}
Expand All @@ -718,6 +780,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl("encode_audio_to_file", &encode_audio_to_file);
m.impl("encode_video_to_file", &encode_video_to_file);
m.impl("encode_audio_to_tensor", &encode_audio_to_tensor);
m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);
m.impl("seek_to_pts", &seek_to_pts);
m.impl("add_video_stream", &add_video_stream);
m.impl("_add_video_stream", &_add_video_stream);
Expand Down
65 changes: 34 additions & 31 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,14 @@ def load_torchcodec_shared_libraries():
encode_audio_to_tensor = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns.encode_audio_to_tensor.default
)
_encode_audio_to_file_like = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns._encode_audio_to_file_like.default
)
create_from_tensor = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns.create_from_tensor.default
)
_convert_to_tensor = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns._convert_to_tensor.default
_create_from_file_like = torch._dynamo.disallow_in_graph(
torch.ops.torchcodec_ns._create_from_file_like.default
)
add_video_stream = torch.ops.torchcodec_ns.add_video_stream.default
_add_video_stream = torch.ops.torchcodec_ns._add_video_stream.default
Expand Down Expand Up @@ -151,7 +154,12 @@ def create_from_file_like(
file_like: Union[io.RawIOBase, io.BufferedReader], seek_mode: Optional[str] = None
) -> torch.Tensor:
assert _pybind_ops is not None
return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode))
return _create_from_file_like(
_pybind_ops.create_file_like_context(
file_like, False # False means not for writing
),
seek_mode,
)


def encode_audio_to_file_like(
Expand Down Expand Up @@ -179,36 +187,16 @@ def encode_audio_to_file_like(
if samples.dtype != torch.float32:
raise ValueError(f"samples must have dtype torch.float32, got {samples.dtype}")

# We're having the same problem as with the decoder's create_from_file_like:
# We should be able to pass a tensor directly, but this leads to a pybind
# error. In order to work around this, we pass the pointer to the tensor's
# data, and its shape, in order to re-construct it in C++. For this to work:
# - the tensor must be float32
# - the tensor must be contiguous, which is why we call contiguous().
# In theory we could avoid this restriction by also passing the strides?
# - IMPORTANT: the input samples tensor and its underlying data must be
# alive during the call.
#
# A more elegant solution would be to cast the tensor into a py::object, but
# casting the py::object backk to a tensor in C++ seems to lead to the same
# pybing error.

samples = samples.contiguous()
_pybind_ops.encode_audio_to_file_like(
samples.data_ptr(),
list(samples.shape),
_encode_audio_to_file_like(
samples,
sample_rate,
format,
file_like,
_pybind_ops.create_file_like_context(file_like, True), # True means for writing
bit_rate,
num_channels,
desired_sample_rate,
)

# This check is useless but it's critical to keep it to ensures that samples
# is still alive during the call to encode_audio_to_file_like.
assert samples.is_contiguous()


# ==============================
# Abstract impl for the operators. Needed by torch.compile.
Expand All @@ -218,6 +206,13 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
return torch.empty([], dtype=torch.long)


@register_fake("torchcodec_ns::_create_from_file_like")
def _create_from_file_like_abstract(
    file_like_context: int, seek_mode: Optional[str]
) -> torch.Tensor:
    # Fake (meta) kernel for torch.compile: the real op returns a 0-dim int64
    # tensor wrapping the decoder pointer, so mirror that shape/dtype here.
    # NOTE: first parameter renamed from `file_like` to `file_like_context` to
    # match the registered schema ("int file_like_context") and the sibling
    # _encode_audio_to_file_like_abstract fake.
    return torch.empty([], dtype=torch.long)


@register_fake("torchcodec_ns::encode_audio_to_file")
def encode_audio_to_file_abstract(
samples: torch.Tensor,
Expand Down Expand Up @@ -251,18 +246,26 @@ def encode_audio_to_tensor_abstract(
return torch.empty([], dtype=torch.long)


@register_fake("torchcodec_ns::_encode_audio_to_file_like")
def _encode_audio_to_file_like_abstract(
    samples: torch.Tensor,
    sample_rate: int,
    format: str,
    file_like_context: int,
    bit_rate: Optional[int] = None,
    num_channels: Optional[int] = None,
    desired_sample_rate: Optional[int] = None,
) -> None:
    # Fake (meta) kernel for torch.compile: the real op writes to the
    # file-like object as a side effect and returns nothing, so there is no
    # tensor output to fabricate here.
    return


@register_fake("torchcodec_ns::create_from_tensor")
def create_from_tensor_abstract(
    video_tensor: torch.Tensor, seek_mode: Optional[str]
) -> torch.Tensor:
    # Fake (meta) kernel for torch.compile: the real op returns a 0-dim int64
    # tensor wrapping the decoder pointer, so mirror that shape/dtype here.
    return torch.empty([], dtype=torch.long)


@register_fake("torchcodec_ns::_convert_to_tensor")
def _convert_to_tensor_abstract(decoder_ptr: int) -> torch.Tensor:
return torch.empty([], dtype=torch.long)


@register_fake("torchcodec_ns::_add_video_stream")
def _add_video_stream_abstract(
decoder: torch.Tensor,
Expand Down
Loading
Loading