@@ -18,34 +18,34 @@ constexpr int64_t MAX_TENSOR_SIZE = 320'000'000; // 320 MB
 int read(void* opaque, uint8_t* buf, int buf_size) {
   auto tensorContext = static_cast<detail::TensorContext*>(opaque);
   TORCH_CHECK(
-      tensorContext->current <= tensorContext->data.numel(),
-      "Tried to read outside of the buffer: current=",
-      tensorContext->current,
+      tensorContext->current_pos <= tensorContext->data.numel(),
+      "Tried to read outside of the buffer: current_pos=",
+      tensorContext->current_pos,
       ", size=",
       tensorContext->data.numel());
 
   int64_t numBytesRead = std::min(
       static_cast<int64_t>(buf_size),
-      tensorContext->data.numel() - tensorContext->current);
+      tensorContext->data.numel() - tensorContext->current_pos);
 
   TORCH_CHECK(
       numBytesRead >= 0,
       "Tried to read negative bytes: numBytesRead=",
       numBytesRead,
       ", size=",
       tensorContext->data.numel(),
-      ", current=",
-      tensorContext->current);
+      ", current_pos=",
+      tensorContext->current_pos);
 
   if (numBytesRead == 0) {
     return AVERROR_EOF;
   }
 
   std::memcpy(
       buf,
-      tensorContext->data.data_ptr<uint8_t>() + tensorContext->current,
+      tensorContext->data.data_ptr<uint8_t>() + tensorContext->current_pos,
       numBytesRead);
-  tensorContext->current += numBytesRead;
+  tensorContext->current_pos += numBytesRead;
   return numBytesRead;
 }
 
@@ -54,7 +54,7 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
   auto tensorContext = static_cast<detail::TensorContext*>(opaque);
 
   int64_t bufSize = static_cast<int64_t>(buf_size);
-  if (tensorContext->current + bufSize > tensorContext->data.numel()) {
+  if (tensorContext->current_pos + bufSize > tensorContext->data.numel()) {
     TORCH_CHECK(
         tensorContext->data.numel() * 2 <= MAX_TENSOR_SIZE,
         "We tried to allocate an output encoded tensor larger than ",
@@ -68,13 +68,17 @@ int write(void* opaque, const uint8_t* buf, int buf_size) {
   }
 
   TORCH_CHECK(
-      tensorContext->current + bufSize <= tensorContext->data.numel(),
+      tensorContext->current_pos + bufSize <= tensorContext->data.numel(),
       "Re-allocation of the output tensor didn't work. ",
       "This should not happen, please report on TorchCodec bug tracker");
 
   uint8_t* outputTensorData = tensorContext->data.data_ptr<uint8_t>();
-  std::memcpy(outputTensorData + tensorContext->current, buf, bufSize);
-  tensorContext->current += bufSize;
+  std::memcpy(outputTensorData + tensorContext->current_pos, buf, bufSize);
+  tensorContext->current_pos += bufSize;
+  // Track the maximum position written so getOutputTensor's narrow() does not
+  // truncate the file if final seek was backwards
+  tensorContext->max_pos =
+      std::max(tensorContext->current_pos, tensorContext->max_pos);
   return buf_size;
 }
 
@@ -88,7 +92,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) {
       ret = tensorContext->data.numel();
       break;
     case SEEK_SET:
-      tensorContext->current = offset;
+      tensorContext->current_pos = offset;
       ret = offset;
       break;
     default:
@@ -101,7 +105,7 @@ int64_t seek(void* opaque, int64_t offset, int whence) {
 } // namespace
 
 AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
-    : tensorContext_{data, 0} {
+    : tensorContext_{data, 0, 0} {
   TORCH_CHECK(data.numel() > 0, "data must not be empty");
   TORCH_CHECK(data.is_contiguous(), "data must be contiguous");
   TORCH_CHECK(data.scalar_type() == torch::kUInt8, "data must be kUInt8");
@@ -110,14 +114,17 @@ AVIOFromTensorContext::AVIOFromTensorContext(torch::Tensor data)
 }
 
 AVIOToTensorContext::AVIOToTensorContext()
-    : tensorContext_{torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}), 0} {
+    : tensorContext_{
+          torch::empty({INITIAL_TENSOR_SIZE}, {torch::kUInt8}),
+          0,
+          0} {
   createAVIOContext(
       nullptr, &write, &seek, &tensorContext_, /*isForWriting=*/true);
 }
 
 torch::Tensor AVIOToTensorContext::getOutputTensor() {
   return tensorContext_.data.narrow(
-      /*dim=*/0, /*start=*/0, /*length=*/tensorContext_.current);
+      /*dim=*/0, /*start=*/0, /*length=*/tensorContext_.max_pos);
 }
 
 } // namespace facebook::torchcodec
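
The new max_pos field exists because a muxer can finish by seeking backwards (for example to patch a header) and writing only a few bytes there, which leaves current_pos well short of the end of the encoded data; narrowing the output to current_pos would then truncate the file. The following is a minimal standalone sketch of that bookkeeping, not TorchCodec code: BufferContext, writeBytes, and seekTo are hypothetical names, and a std::vector stands in for the output tensor.

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for detail::TensorContext: a growable byte buffer,
// the current write position, and the high-water mark of bytes written.
struct BufferContext {
  std::vector<uint8_t> data;
  int64_t current_pos = 0;
  int64_t max_pos = 0;
};

// Analogue of the write() callback: copy buf_size bytes at current_pos,
// advance it, and keep max_pos as the furthest byte ever written.
void writeBytes(BufferContext& ctx, const uint8_t* buf, int64_t buf_size) {
  if (ctx.current_pos + buf_size > static_cast<int64_t>(ctx.data.size())) {
    ctx.data.resize(ctx.current_pos + buf_size);
  }
  std::memcpy(ctx.data.data() + ctx.current_pos, buf, buf_size);
  ctx.current_pos += buf_size;
  ctx.max_pos = std::max(ctx.current_pos, ctx.max_pos);
}

// Analogue of seek() with SEEK_SET: moves current_pos, never max_pos.
void seekTo(BufferContext& ctx, int64_t offset) {
  ctx.current_pos = offset;
}

int main() {
  BufferContext ctx;
  const uint8_t body[1000] = {0};
  const uint8_t header[16] = {1};

  writeBytes(ctx, body, sizeof(body));     // current_pos == max_pos == 1000
  seekTo(ctx, 0);                          // muxer seeks back to patch a header
  writeBytes(ctx, header, sizeof(header)); // current_pos == 16, max_pos == 1000
  // Truncating the output to current_pos would drop 984 bytes of encoded data;
  // truncating to max_pos keeps the full 1000-byte file.
  return ctx.max_pos == 1000 ? 0 : 1;
}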