@@ -27,16 +27,18 @@ See the License for the specific language governing permissions and
27
27
limitations under the License.
28
28
==============================================================================*/
29
29
30
+ #include < cstdint>
30
31
#include < iterator>
32
+ #include < limits>
31
33
#include < vector>
32
34
33
- #include " tensorflow_text/core/kernels/sentencepiece/optimized_encoder.h"
34
- #include " tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer.h"
35
35
#include " tensorflow/core/framework/op.h"
36
36
#include " tensorflow/core/framework/op_kernel.h"
37
37
#include " tensorflow/core/framework/shape_inference.h"
38
38
#include " tensorflow/core/framework/tensor.h"
39
39
#include " tensorflow/core/platform/errors.h"
40
+ #include " tensorflow_text/core/kernels/sentencepiece/optimized_encoder.h"
41
+ #include " tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer.h"
40
42
41
43
namespace tensorflow {
42
44
namespace text {
@@ -50,7 +52,7 @@ class TFSentencepieceOp : public tensorflow::OpKernel {
50
52
const auto & input_values_tensor = ctx->input (kInputIndex );
51
53
const auto input_values_flat =
52
54
input_values_tensor.flat <tensorflow::tstring>();
53
- const int num_of_input_values = input_values_flat.size ();
55
+ const int64_t num_of_input_values = input_values_flat.size ();
54
56
55
57
const auto & add_bos_tensor = ctx->input (kAddBOSInput );
56
58
const bool add_bos = add_bos_tensor.scalar <bool >()();
@@ -74,20 +76,26 @@ class TFSentencepieceOp : public tensorflow::OpKernel {
74
76
}
75
77
tensorflow::Tensor* output_values_tensor = nullptr ;
76
78
tensorflow::Tensor* output_splits_tensor = nullptr ;
77
-
79
+ OP_REQUIRES (ctx, encoded.size () < std::numeric_limits<int32_t >::max (),
80
+ errors::InvalidArgument (
81
+ " Encoded input must contain less than 2^31 characters." ));
82
+ OP_REQUIRES (
83
+ ctx, splits.size () + 1 < std::numeric_limits<int32_t >::max (),
84
+ errors::InvalidArgument (" Splits tensor is limited to 2^31-1 values." ));
78
85
OP_REQUIRES_OK (
79
- ctx, ctx->allocate_output (0 , {( int16_t ) encoded.size ()},
86
+ ctx, ctx->allocate_output (0 , {static_cast < int32_t >( encoded.size () )},
80
87
&output_values_tensor));
81
- OP_REQUIRES_OK (ctx, ctx->allocate_output (1 , {(int16_t )splits.size () + 1 },
82
- &output_splits_tensor));
88
+ OP_REQUIRES_OK (
89
+ ctx, ctx->allocate_output (1 , {static_cast <int32_t >(splits.size ()) + 1 },
90
+ &output_splits_tensor));
83
91
84
92
auto values_tensor_flat = output_values_tensor->vec <int32>();
85
93
auto splits_tensor_flat = output_splits_tensor->vec <int32>();
86
- for (int i = 0 ; i < encoded.size (); ++i) {
94
+ for (int32_t i = 0 ; i < encoded.size (); ++i) {
87
95
values_tensor_flat (i) = encoded[i];
88
96
}
89
97
splits_tensor_flat (0 ) = 0 ;
90
- for (int i = 0 ; i < splits.size (); ++i) {
98
+ for (int32_t i = 0 ; i < splits.size (); ++i) {
91
99
splits_tensor_flat (i + 1 ) = splits[i];
92
100
}
93
101
}
0 commit comments