Commit ab026c8

Refactor pybind_ops to only deal with file like context holders (#889)
1 parent adc6299 commit ab026c8

6 files changed: +138 additions, -112 deletions

src/torchcodec/_core/CMakeLists.txt

Lines changed: 12 additions & 1 deletion
@@ -175,12 +175,23 @@ function(make_torchcodec_libraries
     # stray initialization of py::objects. The rest of the object code must
     # match. See:
     # https://pybind11.readthedocs.io/en/stable/faq.html#someclass-declared-with-greater-visibility-than-the-type-of-its-field-someclass-member-wattributes
-    if(NOT WIN32)
+    # We have to do this for both pybind_ops and custom_ops because they include
+    # some of the same headers.
+    #
+    # Note that this is Linux only. It's not necessary on Windows, and on Mac
+    # hiding visibility can actually break dynamic casts across shared libraries
+    # because the type infos don't get exported.
+    if(LINUX)
         target_compile_options(
             ${pybind_ops_library_name}
             PUBLIC
             "-fvisibility=hidden"
         )
+        target_compile_options(
+            ${custom_ops_library_name}
+            PUBLIC
+            "-fvisibility=hidden"
+        )
     endif()

     # The value we use here must match the value we return from

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 0 additions & 12 deletions
@@ -1704,16 +1704,4 @@ FrameDims getHeightAndWidthFromOptionsOrAVFrame(
       videoStreamOptions.width.value_or(avFrame->width));
 }

-SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
-  if (seekMode == "exact") {
-    return SingleStreamDecoder::SeekMode::exact;
-  } else if (seekMode == "approximate") {
-    return SingleStreamDecoder::SeekMode::approximate;
-  } else if (seekMode == "custom_frame_mappings") {
-    return SingleStreamDecoder::SeekMode::custom_frame_mappings;
-  } else {
-    TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode));
-  }
-}
-
 } // namespace facebook::torchcodec

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 0 additions & 2 deletions
@@ -375,6 +375,4 @@ std::ostream& operator<<(
     std::ostream& os,
     const SingleStreamDecoder::DecodeStats& stats);

-SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode);
-
 } // namespace facebook::torchcodec

src/torchcodec/_core/custom_ops.cpp

Lines changed: 70 additions & 7 deletions
@@ -10,6 +10,7 @@
 #include <string>
 #include "c10/core/SymIntArrayRef.h"
 #include "c10/util/Exception.h"
+#include "src/torchcodec/_core/AVIOFileLikeContext.h"
 #include "src/torchcodec/_core/AVIOTensorContext.h"
 #include "src/torchcodec/_core/Encoder.h"
 #include "src/torchcodec/_core/SingleStreamDecoder.h"
@@ -35,9 +36,12 @@ TORCH_LIBRARY(torchcodec_ns, m) {
       "encode_video_to_file(Tensor frames, int frame_rate, str filename) -> ()");
   m.def(
       "encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
+  m.def(
+      "_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
   m.def(
       "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor");
-  m.def("_convert_to_tensor(int decoder_ptr) -> Tensor");
+  m.def(
+      "_create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor");
   m.def(
       "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()");
   m.def(
@@ -167,6 +171,18 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
   return ss.str();
 }

+SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
+  if (seekMode == "exact") {
+    return SingleStreamDecoder::SeekMode::exact;
+  } else if (seekMode == "approximate") {
+    return SingleStreamDecoder::SeekMode::approximate;
+  } else if (seekMode == "custom_frame_mappings") {
+    return SingleStreamDecoder::SeekMode::custom_frame_mappings;
+  } else {
+    TORCH_CHECK(false, "Invalid seek mode: " + std::string(seekMode));
+  }
+}
+
 } // namespace

 // ==============================
@@ -205,16 +221,32 @@ at::Tensor create_from_tensor(
     realSeek = seekModeFromString(seek_mode.value());
   }

-  auto contextHolder = std::make_unique<AVIOFromTensorContext>(video_tensor);
+  auto avioContextHolder =
+      std::make_unique<AVIOFromTensorContext>(video_tensor);

   std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
-      std::make_unique<SingleStreamDecoder>(std::move(contextHolder), realSeek);
+      std::make_unique<SingleStreamDecoder>(
+          std::move(avioContextHolder), realSeek);
   return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
 }

-at::Tensor _convert_to_tensor(int64_t decoder_ptr) {
-  auto decoder = reinterpret_cast<SingleStreamDecoder*>(decoder_ptr);
-  std::unique_ptr<SingleStreamDecoder> uniqueDecoder(decoder);
+at::Tensor _create_from_file_like(
+    int64_t file_like_context,
+    std::optional<std::string_view> seek_mode) {
+  auto fileLikeContext =
+      reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
+  TORCH_CHECK(
+      fileLikeContext != nullptr, "file_like_context must be a valid pointer");
+  std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
+
+  SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
+  if (seek_mode.has_value()) {
+    realSeek = seekModeFromString(seek_mode.value());
+  }
+
+  std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
+      std::make_unique<SingleStreamDecoder>(
+          std::move(avioContextHolder), realSeek);
   return wrapDecoderPointerToTensor(std::move(uniqueDecoder));
 }

@@ -456,6 +488,36 @@ at::Tensor encode_audio_to_tensor(
       .encodeToTensor();
 }

+void _encode_audio_to_file_like(
+    const at::Tensor& samples,
+    int64_t sample_rate,
+    std::string_view format,
+    int64_t file_like_context,
+    std::optional<int64_t> bit_rate = std::nullopt,
+    std::optional<int64_t> num_channels = std::nullopt,
+    std::optional<int64_t> desired_sample_rate = std::nullopt) {
+  auto fileLikeContext =
+      reinterpret_cast<AVIOFileLikeContext*>(file_like_context);
+  TORCH_CHECK(
+      fileLikeContext != nullptr, "file_like_context must be a valid pointer");
+  std::unique_ptr<AVIOFileLikeContext> avioContextHolder(fileLikeContext);
+
+  AudioStreamOptions audioStreamOptions;
+  audioStreamOptions.bitRate = validateOptionalInt64ToInt(bit_rate, "bit_rate");
+  audioStreamOptions.numChannels =
+      validateOptionalInt64ToInt(num_channels, "num_channels");
+  audioStreamOptions.sampleRate =
+      validateOptionalInt64ToInt(desired_sample_rate, "desired_sample_rate");
+
+  AudioEncoder encoder(
+      samples,
+      validateInt64ToInt(sample_rate, "sample_rate"),
+      format,
+      std::move(avioContextHolder),
+      audioStreamOptions);
+  encoder.encode();
+}
+
 // For testing only. We need to implement this operation as a core library
 // function because what we're testing is round-tripping pts values as
 // double-precision floating point numbers from C++ to Python and back to C++.
@@ -709,7 +771,7 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
 TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
   m.impl("create_from_file", &create_from_file);
   m.impl("create_from_tensor", &create_from_tensor);
-  m.impl("_convert_to_tensor", &_convert_to_tensor);
+  m.impl("_create_from_file_like", &_create_from_file_like);
   m.impl(
      "_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions);
 }
@@ -718,6 +780,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
   m.impl("encode_audio_to_file", &encode_audio_to_file);
   m.impl("encode_video_to_file", &encode_video_to_file);
   m.impl("encode_audio_to_tensor", &encode_audio_to_tensor);
+  m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like);
   m.impl("seek_to_pts", &seek_to_pts);
   m.impl("add_video_stream", &add_video_stream);
   m.impl("_add_video_stream", &_add_video_stream);

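The pattern shared by the two new ops above: the file-like context is allocated on the Python side, its address crosses the op boundary as a plain integer, and the C++ op reinterpret_casts that integer back to an AVIOFileLikeContext* and immediately wraps it in a std::unique_ptr, transferring ownership to the decoder or encoder. Below is a minimal Python-side sketch of that handoff, using the op and helper names from this diff; the import path, the _pybind_ops access, and the explicit .default call are illustrative assumptions, not part of the commit.

import io
import torch
from torchcodec._core import ops  # assumed import path

# Open the input as a file-like object (read mode).
file_like = io.BytesIO(open("video.mp4", "rb").read())

# 1) The pybind extension allocates an AVIOFileLikeContext and returns its
#    address as an int; False means the context is not for writing.
ctx_addr = ops._pybind_ops.create_file_like_context(file_like, False)

# 2) The custom op casts the int back to a pointer and takes ownership of the
#    context; the returned tensor wraps the new decoder pointer.
decoder = torch.ops.torchcodec_ns._create_from_file_like.default(ctx_addr, "exact")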
src/torchcodec/_core/ops.py

Lines changed: 34 additions & 31 deletions
@@ -98,11 +98,14 @@ def load_torchcodec_shared_libraries():
 encode_audio_to_tensor = torch._dynamo.disallow_in_graph(
     torch.ops.torchcodec_ns.encode_audio_to_tensor.default
 )
+_encode_audio_to_file_like = torch._dynamo.disallow_in_graph(
+    torch.ops.torchcodec_ns._encode_audio_to_file_like.default
+)
 create_from_tensor = torch._dynamo.disallow_in_graph(
     torch.ops.torchcodec_ns.create_from_tensor.default
 )
-_convert_to_tensor = torch._dynamo.disallow_in_graph(
-    torch.ops.torchcodec_ns._convert_to_tensor.default
+_create_from_file_like = torch._dynamo.disallow_in_graph(
+    torch.ops.torchcodec_ns._create_from_file_like.default
 )
 add_video_stream = torch.ops.torchcodec_ns.add_video_stream.default
 _add_video_stream = torch.ops.torchcodec_ns._add_video_stream.default
@@ -151,7 +154,12 @@ def create_from_file_like(
     file_like: Union[io.RawIOBase, io.BufferedReader], seek_mode: Optional[str] = None
 ) -> torch.Tensor:
     assert _pybind_ops is not None
-    return _convert_to_tensor(_pybind_ops.create_from_file_like(file_like, seek_mode))
+    return _create_from_file_like(
+        _pybind_ops.create_file_like_context(
+            file_like, False  # False means not for writing
+        ),
+        seek_mode,
+    )


 def encode_audio_to_file_like(
@@ -179,36 +187,16 @@ def encode_audio_to_file_like(
     if samples.dtype != torch.float32:
         raise ValueError(f"samples must have dtype torch.float32, got {samples.dtype}")

-    # We're having the same problem as with the decoder's create_from_file_like:
-    # We should be able to pass a tensor directly, but this leads to a pybind
-    # error. In order to work around this, we pass the pointer to the tensor's
-    # data, and its shape, in order to re-construct it in C++. For this to work:
-    # - the tensor must be float32
-    # - the tensor must be contiguous, which is why we call contiguous().
-    #   In theory we could avoid this restriction by also passing the strides?
-    # - IMPORTANT: the input samples tensor and its underlying data must be
-    #   alive during the call.
-    #
-    # A more elegant solution would be to cast the tensor into a py::object, but
-    # casting the py::object backk to a tensor in C++ seems to lead to the same
-    # pybing error.
-
-    samples = samples.contiguous()
-    _pybind_ops.encode_audio_to_file_like(
-        samples.data_ptr(),
-        list(samples.shape),
+    _encode_audio_to_file_like(
+        samples,
         sample_rate,
         format,
-        file_like,
+        _pybind_ops.create_file_like_context(file_like, True),  # True means for writing
         bit_rate,
         num_channels,
         desired_sample_rate,
     )

-    # This check is useless but it's critical to keep it to ensures that samples
-    # is still alive during the call to encode_audio_to_file_like.
-    assert samples.is_contiguous()
-

 # ==============================
 # Abstract impl for the operators. Needed by torch.compile.
@@ -218,6 +206,13 @@ def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.
     return torch.empty([], dtype=torch.long)


+@register_fake("torchcodec_ns::_create_from_file_like")
+def _create_from_file_like_abstract(
+    file_like: int, seek_mode: Optional[str]
+) -> torch.Tensor:
+    return torch.empty([], dtype=torch.long)
+
+
 @register_fake("torchcodec_ns::encode_audio_to_file")
 def encode_audio_to_file_abstract(
     samples: torch.Tensor,
@@ -251,18 +246,26 @@ def encode_audio_to_tensor_abstract(
     return torch.empty([], dtype=torch.long)


+@register_fake("torchcodec_ns::_encode_audio_to_file_like")
+def _encode_audio_to_file_like_abstract(
+    samples: torch.Tensor,
+    sample_rate: int,
+    format: str,
+    file_like_context: int,
+    bit_rate: Optional[int] = None,
+    num_channels: Optional[int] = None,
+    desired_sample_rate: Optional[int] = None,
+) -> None:
+    return
+
+
 @register_fake("torchcodec_ns::create_from_tensor")
 def create_from_tensor_abstract(
     video_tensor: torch.Tensor, seek_mode: Optional[str]
 ) -> torch.Tensor:
     return torch.empty([], dtype=torch.long)


-@register_fake("torchcodec_ns::_convert_to_tensor")
-def _convert_to_tensor_abstract(decoder_ptr: int) -> torch.Tensor:
-    return torch.empty([], dtype=torch.long)
-
-
 @register_fake("torchcodec_ns::_add_video_stream")
 def _add_video_stream_abstract(
     decoder: torch.Tensor,

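Taken together, the public helpers in ops.py keep their signatures; only the plumbing underneath changes. A hedged usage sketch follows (the file name and sample layout are illustrative assumptions; error handling omitted):

import io
import torch
from torchcodec._core import ops  # assumed import path

# Decoding: create_from_file_like now builds a read-mode file-like context and
# hands it to the _create_from_file_like custom op.
with open("video.mp4", "rb") as f:
    decoder = ops.create_from_file_like(io.BytesIO(f.read()), seek_mode="approximate")

# Encoding: the float32 samples tensor is passed to the custom op directly, so
# the old data_ptr()/shape/contiguous() workaround is no longer needed.
samples = torch.rand(2, 16_000, dtype=torch.float32)  # assumed (channels, samples) layout
out = io.BytesIO()
ops.encode_audio_to_file_like(samples, 16_000, "wav", out)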