Skip to content

Commit aa839b1

Browse files
cantoniostf-text-github-robot
authored andcommitted
Change input size from int16_t to int32_t to support large inputs.
PiperOrigin-RevId: 796484892
1 parent 1ef5f72 commit aa839b1

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer_kernel.cc

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,18 @@ See the License for the specific language governing permissions and
2727
limitations under the License.
2828
==============================================================================*/
2929

30+
#include <cstdint>
3031
#include <iterator>
32+
#include <limits>
3133
#include <vector>
3234

33-
#include "tensorflow_text/core/kernels/sentencepiece/optimized_encoder.h"
34-
#include "tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer.h"
3535
#include "tensorflow/core/framework/op.h"
3636
#include "tensorflow/core/framework/op_kernel.h"
3737
#include "tensorflow/core/framework/shape_inference.h"
3838
#include "tensorflow/core/framework/tensor.h"
3939
#include "tensorflow/core/platform/errors.h"
40+
#include "tensorflow_text/core/kernels/sentencepiece/optimized_encoder.h"
41+
#include "tensorflow_text/core/kernels/sentencepiece/sentencepiece_tokenizer.h"
4042

4143
namespace tensorflow {
4244
namespace text{
@@ -50,7 +52,7 @@ class TFSentencepieceOp : public tensorflow::OpKernel {
5052
const auto& input_values_tensor = ctx->input(kInputIndex);
5153
const auto input_values_flat =
5254
input_values_tensor.flat<tensorflow::tstring>();
53-
const int num_of_input_values = input_values_flat.size();
55+
const int64_t num_of_input_values = input_values_flat.size();
5456

5557
const auto& add_bos_tensor = ctx->input(kAddBOSInput);
5658
const bool add_bos = add_bos_tensor.scalar<bool>()();
@@ -74,20 +76,26 @@ class TFSentencepieceOp : public tensorflow::OpKernel {
7476
}
7577
tensorflow::Tensor* output_values_tensor = nullptr;
7678
tensorflow::Tensor* output_splits_tensor = nullptr;
77-
79+
OP_REQUIRES(ctx, encoded.size() < std::numeric_limits<int32_t>::max(),
80+
errors::InvalidArgument(
81+
"Encoded input must contain less than 2^31 characters."));
82+
OP_REQUIRES(
83+
ctx, splits.size() + 1 < std::numeric_limits<int32_t>::max(),
84+
errors::InvalidArgument("Splits tensor is limited to 2^31-1 values."));
7885
OP_REQUIRES_OK(
79-
ctx, ctx->allocate_output(0, {(int16_t)encoded.size()},
86+
ctx, ctx->allocate_output(0, {static_cast<int32_t>(encoded.size())},
8087
&output_values_tensor));
81-
OP_REQUIRES_OK(ctx, ctx->allocate_output(1, {(int16_t)splits.size() + 1},
82-
&output_splits_tensor));
88+
OP_REQUIRES_OK(
89+
ctx, ctx->allocate_output(1, {static_cast<int32_t>(splits.size()) + 1},
90+
&output_splits_tensor));
8391

8492
auto values_tensor_flat = output_values_tensor->vec<int32>();
8593
auto splits_tensor_flat = output_splits_tensor->vec<int32>();
86-
for (int i = 0; i < encoded.size(); ++i) {
94+
for (int32_t i = 0; i < encoded.size(); ++i) {
8795
values_tensor_flat(i) = encoded[i];
8896
}
8997
splits_tensor_flat(0) = 0;
90-
for (int i = 0; i < splits.size(); ++i) {
98+
for (int32_t i = 0; i < splits.size(); ++i) {
9199
splits_tensor_flat(i + 1) = splits[i];
92100
}
93101
}

0 commit comments

Comments
 (0)