diff --git a/tensorflow_text/core/kernels/BUILD b/tensorflow_text/core/kernels/BUILD index 57527e275..e1dd0ec82 100644 --- a/tensorflow_text/core/kernels/BUILD +++ b/tensorflow_text/core/kernels/BUILD @@ -832,6 +832,7 @@ tf_cc_library( # tf:lib tensorflow dep, ], deps = [ + "@com_google_absl//absl/status", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_kernel.cc b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_kernel.cc index 7e35041f6..dc100f7a5 100644 --- a/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_kernel.cc +++ b/tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer_kernel.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/framework/op_requires.h" +#include "absl/status/status.h" #include "tensorflow_text/core/kernels/sentencepiece/optimized_decoder.h" #include "tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer.h" @@ -51,6 +52,9 @@ class TFSentencepieceDetokenizerOp : public tensorflow::OpKernel { input_values_tensor.flat(); const auto& input_splits_tensor = ctx->input(kInputSplits); const auto input_splits_flat = input_splits_tensor.flat(); + OP_REQUIRES(ctx, input_splits_flat.size() > 0, + absl::InvalidArgumentError( + "input_splits must have at least 1 element.")); const int num_of_sentences = input_splits_flat.size() - 1; Tensor* output_tensor = nullptr; OP_REQUIRES_OK(ctx, @@ -65,8 +69,8 @@ class TFSentencepieceDetokenizerOp : public tensorflow::OpKernel { ctx, split_size >= 0 && (input_offset + split_size) <= input_values_flat.size(), - errors::InvalidArgument("input_splits must be monotonically " - "non-decreasing and within bounds.")); + absl::InvalidArgumentError("input_splits must be monotonically " + "non-decreasing and within bounds.")); codes_for_split.clear(); codes_for_split.reserve(split_size); for (int j = 0; j < split_size; ++j) { diff --git a/tensorflow_text/core/kernels/sentencepiece_kernels.cc b/tensorflow_text/core/kernels/sentencepiece_kernels.cc index a1f57bc19..a002f8fa4 100644 --- a/tensorflow_text/core/kernels/sentencepiece_kernels.cc +++ b/tensorflow_text/core/kernels/sentencepiece_kernels.cc @@ -576,6 +576,9 @@ class SentencepieceDetokenizeOp : public OpKernel { const auto input_values_flat = input_values_tensor.flat(); const Tensor& input_splits_tensor = ctx->input(2); const auto input_splits_flat = input_splits_tensor.flat(); + OP_REQUIRES(ctx, input_splits_flat.size() > 0, + absl::InvalidArgumentError( + "input_splits must have at least 1 element.")); const int64 num_of_sentences = input_splits_flat.size() - 1; OP_REQUIRES_OK(ctx, HandleExtraOptions(ctx, sp)); diff --git a/tensorflow_text/core/ops/fast_sentencepiece_ops.cc b/tensorflow_text/core/ops/fast_sentencepiece_ops.cc index ac32b2e31..697233d32 100644 --- a/tensorflow_text/core/ops/fast_sentencepiece_ops.cc +++ b/tensorflow_text/core/ops/fast_sentencepiece_ops.cc @@ -15,6 +15,7 @@ #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "absl/status/status.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -62,8 +63,14 @@ REGISTER_OP("TFText>FastSentencepieceDetokenize") TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); + shape_inference::DimensionHandle num_splits = c->NumElements(c->input(2)); shape_inference::DimensionHandle dim; - TF_RETURN_IF_ERROR(c->Subtract(c->NumElements(c->input(2)), 1, &dim)); + if (c->ValueKnown(num_splits) && c->Value(num_splits) == 0) { + return absl::InvalidArgumentError( + "input_splits must have at least 1 element."); + } else { + TF_RETURN_IF_ERROR(c->Subtract(num_splits, 1, &dim)); + } c->set_output(0, c->Vector(dim)); return absl::OkStatus(); }); diff --git a/tensorflow_text/core/ops/rouge_l_op.cc b/tensorflow_text/core/ops/rouge_l_op.cc index ae63b4f5b..9510bddab 100644 --- a/tensorflow_text/core/ops/rouge_l_op.cc +++ b/tensorflow_text/core/ops/rouge_l_op.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "absl/status/status.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" @@ -88,8 +89,13 @@ absl::Status RougeLShapeFn(InferenceContext* c) { &output_nrows_plus_one)); // Output shape is a 1-D tensor with size equal to number of splits minus 1. + DimensionHandle num_splits = c->Dim(output_nrows_plus_one, 0); DimensionHandle dim; - TF_RETURN_IF_ERROR(c->Subtract(c->Dim(output_nrows_plus_one, 0), 1, &dim)); + if (c->ValueKnown(num_splits) && c->Value(num_splits) == 0) { + return absl::InvalidArgumentError("splits must have at least 1 element."); + } else { + TF_RETURN_IF_ERROR(c->Subtract(num_splits, 1, &dim)); + } // All outputs have the same shape. c->set_output(0, c->Vector(dim)); diff --git a/tensorflow_text/core/ops/sentencepiece_ops.cc b/tensorflow_text/core/ops/sentencepiece_ops.cc index 1f557862a..d3e72c014 100644 --- a/tensorflow_text/core/ops/sentencepiece_ops.cc +++ b/tensorflow_text/core/ops/sentencepiece_ops.cc @@ -15,6 +15,7 @@ #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "absl/status/status.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -129,8 +130,14 @@ REGISTER_OP("SentencepieceDetokenizeOp") TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + shape_inference::DimensionHandle num_splits = c->NumElements(c->input(2)); shape_inference::DimensionHandle dim; - TF_RETURN_IF_ERROR(c->Subtract(c->NumElements(c->input(2)), 1, &dim)); + if (c->ValueKnown(num_splits) && c->Value(num_splits) == 0) { + return absl::InvalidArgumentError( + "input_splits must have at least 1 element."); + } else { + TF_RETURN_IF_ERROR(c->Subtract(num_splits, 1, &dim)); + } c->set_output(0, c->Vector(dim)); return absl::OkStatus(); }); diff --git a/tensorflow_text/python/ops/sentencepiece_tokenizer_test.py b/tensorflow_text/python/ops/sentencepiece_tokenizer_test.py index 6f42ff712..2eabb2875 100644 --- a/tensorflow_text/python/ops/sentencepiece_tokenizer_test.py +++ b/tensorflow_text/python/ops/sentencepiece_tokenizer_test.py @@ -37,6 +37,9 @@ from tensorflow.python.platform import test from tensorflow.python.saved_model import load from tensorflow.python.saved_model import save +from tensorflow.python.framework import load_library +from tensorflow.python.platform import resource_loader +gen_sentencepiece_tokenizer = load_library.load_op_library(resource_loader.get_path_to_datafile('_sentencepiece_tokenizer.so')) from tensorflow_text.python.ops.sentencepiece_tokenizer import SentencepieceTokenizer @@ -451,6 +454,17 @@ def testEmptyInputDetokenize(self): detokenized = sp.detokenize(constant_op.constant([], dtypes.int32)) self.assertAllEqual('', detokenized) + def testEmptyInputSplitsDetokenizeOpZeroShape(self): + sp = SentencepieceTokenizer(self.model) + # Providing an empty tensor (length 0) as input_splits directly to the op + # should safely raise InvalidArgumentError instead of crashing. + with self.assertRaises((errors.InvalidArgumentError, ValueError)): + _ = gen_sentencepiece_tokenizer.sentencepiece_detokenize_op( + sp._model_resource.resource_handle, + constant_op.constant([], dtypes.int32), + constant_op.constant([], dtypes.int64), + add_bos=False, add_eos=False, reverse=False) + def testReturnNbestAndDetokenize(self): sp = SentencepieceTokenizer( self.model, nbest_size=2, out_type=dtypes.int32, return_nbest=True)