10
10
#include < string>
11
11
#include " c10/core/SymIntArrayRef.h"
12
12
#include " c10/util/Exception.h"
13
+ #include " src/torchcodec/_core/AVIOFileLikeContext.h"
13
14
#include " src/torchcodec/_core/AVIOTensorContext.h"
14
15
#include " src/torchcodec/_core/Encoder.h"
15
16
#include " src/torchcodec/_core/SingleStreamDecoder.h"
@@ -35,9 +36,12 @@ TORCH_LIBRARY(torchcodec_ns, m) {
35
36
" encode_video_to_file(Tensor frames, int frame_rate, str filename) -> ()" );
36
37
m.def (
37
38
" encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor" );
39
+ m.def (
40
+ " _encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()" );
38
41
m.def (
39
42
" create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor" );
40
- m.def (" _convert_to_tensor(int decoder_ptr) -> Tensor" );
43
+ m.def (
44
+ " _create_from_file_like(int file_like_context, str? seek_mode=None) -> Tensor" );
41
45
m.def (
42
46
" _add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, (Tensor, Tensor, Tensor)? custom_frame_mappings=None, str? color_conversion_library=None) -> ()" );
43
47
m.def (
@@ -167,6 +171,18 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
167
171
return ss.str ();
168
172
}
169
173
174
+ SingleStreamDecoder::SeekMode seekModeFromString (std::string_view seekMode) {
175
+ if (seekMode == " exact" ) {
176
+ return SingleStreamDecoder::SeekMode::exact;
177
+ } else if (seekMode == " approximate" ) {
178
+ return SingleStreamDecoder::SeekMode::approximate;
179
+ } else if (seekMode == " custom_frame_mappings" ) {
180
+ return SingleStreamDecoder::SeekMode::custom_frame_mappings;
181
+ } else {
182
+ TORCH_CHECK (false , " Invalid seek mode: " + std::string (seekMode));
183
+ }
184
+ }
185
+
170
186
} // namespace
171
187
172
188
// ==============================
@@ -205,16 +221,32 @@ at::Tensor create_from_tensor(
205
221
realSeek = seekModeFromString (seek_mode.value ());
206
222
}
207
223
208
- auto contextHolder = std::make_unique<AVIOFromTensorContext>(video_tensor);
224
+ auto avioContextHolder =
225
+ std::make_unique<AVIOFromTensorContext>(video_tensor);
209
226
210
227
std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
211
- std::make_unique<SingleStreamDecoder>(std::move (contextHolder), realSeek);
228
+ std::make_unique<SingleStreamDecoder>(
229
+ std::move (avioContextHolder), realSeek);
212
230
return wrapDecoderPointerToTensor (std::move (uniqueDecoder));
213
231
}
214
232
215
- at::Tensor _convert_to_tensor (int64_t decoder_ptr) {
216
- auto decoder = reinterpret_cast <SingleStreamDecoder*>(decoder_ptr);
217
- std::unique_ptr<SingleStreamDecoder> uniqueDecoder (decoder);
233
+ at::Tensor _create_from_file_like (
234
+ int64_t file_like_context,
235
+ std::optional<std::string_view> seek_mode) {
236
+ auto fileLikeContext =
237
+ reinterpret_cast <AVIOFileLikeContext*>(file_like_context);
238
+ TORCH_CHECK (
239
+ fileLikeContext != nullptr , " file_like_context must be a valid pointer" );
240
+ std::unique_ptr<AVIOFileLikeContext> avioContextHolder (fileLikeContext);
241
+
242
+ SingleStreamDecoder::SeekMode realSeek = SingleStreamDecoder::SeekMode::exact;
243
+ if (seek_mode.has_value ()) {
244
+ realSeek = seekModeFromString (seek_mode.value ());
245
+ }
246
+
247
+ std::unique_ptr<SingleStreamDecoder> uniqueDecoder =
248
+ std::make_unique<SingleStreamDecoder>(
249
+ std::move (avioContextHolder), realSeek);
218
250
return wrapDecoderPointerToTensor (std::move (uniqueDecoder));
219
251
}
220
252
@@ -456,6 +488,36 @@ at::Tensor encode_audio_to_tensor(
456
488
.encodeToTensor ();
457
489
}
458
490
491
+ void _encode_audio_to_file_like (
492
+ const at::Tensor& samples,
493
+ int64_t sample_rate,
494
+ std::string_view format,
495
+ int64_t file_like_context,
496
+ std::optional<int64_t > bit_rate = std::nullopt ,
497
+ std::optional<int64_t > num_channels = std::nullopt ,
498
+ std::optional<int64_t > desired_sample_rate = std::nullopt ) {
499
+ auto fileLikeContext =
500
+ reinterpret_cast <AVIOFileLikeContext*>(file_like_context);
501
+ TORCH_CHECK (
502
+ fileLikeContext != nullptr , " file_like_context must be a valid pointer" );
503
+ std::unique_ptr<AVIOFileLikeContext> avioContextHolder (fileLikeContext);
504
+
505
+ AudioStreamOptions audioStreamOptions;
506
+ audioStreamOptions.bitRate = validateOptionalInt64ToInt (bit_rate, " bit_rate" );
507
+ audioStreamOptions.numChannels =
508
+ validateOptionalInt64ToInt (num_channels, " num_channels" );
509
+ audioStreamOptions.sampleRate =
510
+ validateOptionalInt64ToInt (desired_sample_rate, " desired_sample_rate" );
511
+
512
+ AudioEncoder encoder (
513
+ samples,
514
+ validateInt64ToInt (sample_rate, " sample_rate" ),
515
+ format,
516
+ std::move (avioContextHolder),
517
+ audioStreamOptions);
518
+ encoder.encode ();
519
+ }
520
+
459
521
// For testing only. We need to implement this operation as a core library
460
522
// function because what we're testing is round-tripping pts values as
461
523
// double-precision floating point numbers from C++ to Python and back to C++.
@@ -709,7 +771,7 @@ void scan_all_streams_to_update_metadata(at::Tensor& decoder) {
709
771
TORCH_LIBRARY_IMPL (torchcodec_ns, BackendSelect, m) {
710
772
m.impl (" create_from_file" , &create_from_file);
711
773
m.impl (" create_from_tensor" , &create_from_tensor);
712
- m.impl (" _convert_to_tensor " , &_convert_to_tensor );
774
+ m.impl (" _create_from_file_like " , &_create_from_file_like );
713
775
m.impl (
714
776
" _get_json_ffmpeg_library_versions" , &_get_json_ffmpeg_library_versions);
715
777
}
@@ -718,6 +780,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
718
780
m.impl (" encode_audio_to_file" , &encode_audio_to_file);
719
781
m.impl (" encode_video_to_file" , &encode_video_to_file);
720
782
m.impl (" encode_audio_to_tensor" , &encode_audio_to_tensor);
783
+ m.impl (" _encode_audio_to_file_like" , &_encode_audio_to_file_like);
721
784
m.impl (" seek_to_pts" , &seek_to_pts);
722
785
m.impl (" add_video_stream" , &add_video_stream);
723
786
m.impl (" _add_video_stream" , &_add_video_stream);
0 commit comments