Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorflow_text/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,7 @@ tf_cc_library(
# tf:lib tensorflow dep,
],
deps = [
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/framework/op_requires.h"
#include "absl/status/status.h"
#include "tensorflow_text/core/kernels/sentencepiece/optimized_decoder.h"
#include "tensorflow_text/core/kernels/sentencepiece/sentencepiece_detokenizer.h"

Expand All @@ -51,6 +52,9 @@ class TFSentencepieceDetokenizerOp : public tensorflow::OpKernel {
input_values_tensor.flat<tensorflow::int32>();
const auto& input_splits_tensor = ctx->input(kInputSplits);
const auto input_splits_flat = input_splits_tensor.flat<Tsplits>();
OP_REQUIRES(ctx, input_splits_flat.size() > 0,
absl::InvalidArgumentError(
"input_splits must have at least 1 element."));
const int num_of_sentences = input_splits_flat.size() - 1;
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(ctx,
Expand All @@ -65,8 +69,8 @@ class TFSentencepieceDetokenizerOp : public tensorflow::OpKernel {
ctx,
split_size >= 0 &&
(input_offset + split_size) <= input_values_flat.size(),
errors::InvalidArgument("input_splits must be monotonically "
"non-decreasing and within bounds."));
absl::InvalidArgumentError("input_splits must be monotonically "
"non-decreasing and within bounds."));
codes_for_split.clear();
codes_for_split.reserve(split_size);
for (int j = 0; j < split_size; ++j) {
Expand Down
3 changes: 3 additions & 0 deletions tensorflow_text/core/kernels/sentencepiece_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,9 @@ class SentencepieceDetokenizeOp : public OpKernel {
const auto input_values_flat = input_values_tensor.flat<T>();
const Tensor& input_splits_tensor = ctx->input(2);
const auto input_splits_flat = input_splits_tensor.flat<Tsplits>();
OP_REQUIRES(ctx, input_splits_flat.size() > 0,
absl::InvalidArgumentError(
"input_splits must have at least 1 element."));
const int64 num_of_sentences = input_splits_flat.size() - 1;

OP_REQUIRES_OK(ctx, HandleExtraOptions(ctx, sp));
Expand Down
9 changes: 8 additions & 1 deletion tensorflow_text/core/ops/fast_sentencepiece_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "absl/status/status.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
Expand Down Expand Up @@ -62,8 +63,14 @@ REGISTER_OP("TFText>FastSentencepieceDetokenize")
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));

shape_inference::DimensionHandle num_splits = c->NumElements(c->input(2));
shape_inference::DimensionHandle dim;
TF_RETURN_IF_ERROR(c->Subtract(c->NumElements(c->input(2)), 1, &dim));
if (c->ValueKnown(num_splits) && c->Value(num_splits) == 0) {
return absl::InvalidArgumentError(
"input_splits must have at least 1 element.");
} else {
TF_RETURN_IF_ERROR(c->Subtract(num_splits, 1, &dim));
}
c->set_output(0, c->Vector(dim));
return absl::OkStatus();
});
Expand Down
8 changes: 7 additions & 1 deletion tensorflow_text/core/ops/rouge_l_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "absl/status/status.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

Expand Down Expand Up @@ -88,8 +89,13 @@ absl::Status RougeLShapeFn(InferenceContext* c) {
&output_nrows_plus_one));

// Output shape is a 1-D tensor with size equal to number of splits minus 1.
DimensionHandle num_splits = c->Dim(output_nrows_plus_one, 0);
DimensionHandle dim;
TF_RETURN_IF_ERROR(c->Subtract(c->Dim(output_nrows_plus_one, 0), 1, &dim));
if (c->ValueKnown(num_splits) && c->Value(num_splits) == 0) {
return absl::InvalidArgumentError("splits must have at least 1 element.");
} else {
TF_RETURN_IF_ERROR(c->Subtract(num_splits, 1, &dim));
}

// All outputs have the same shape.
c->set_output(0, c->Vector(dim));
Expand Down
9 changes: 8 additions & 1 deletion tensorflow_text/core/ops/sentencepiece_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "absl/status/status.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
Expand Down Expand Up @@ -129,8 +130,14 @@ REGISTER_OP("SentencepieceDetokenizeOp")
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));

shape_inference::DimensionHandle num_splits = c->NumElements(c->input(2));
shape_inference::DimensionHandle dim;
TF_RETURN_IF_ERROR(c->Subtract(c->NumElements(c->input(2)), 1, &dim));
if (c->ValueKnown(num_splits) && c->Value(num_splits) == 0) {
return absl::InvalidArgumentError(
"input_splits must have at least 1 element.");
} else {
TF_RETURN_IF_ERROR(c->Subtract(num_splits, 1, &dim));
}
c->set_output(0, c->Vector(dim));
return absl::OkStatus();
});
Expand Down
14 changes: 14 additions & 0 deletions tensorflow_text/python/ops/sentencepiece_tokenizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
from tensorflow.python.platform import test
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader
gen_sentencepiece_tokenizer = load_library.load_op_library(resource_loader.get_path_to_datafile('_sentencepiece_tokenizer.so'))
from tensorflow_text.python.ops.sentencepiece_tokenizer import SentencepieceTokenizer


Expand Down Expand Up @@ -451,6 +454,17 @@ def testEmptyInputDetokenize(self):
detokenized = sp.detokenize(constant_op.constant([], dtypes.int32))
self.assertAllEqual('', detokenized)

def testEmptyInputSplitsDetokenizeOpZeroShape(self):
  """Calls the raw detokenize op with a zero-length `input_splits` tensor.

  The generated op is invoked directly with an empty int32 values tensor and
  an empty int64 splits tensor. The call is expected to fail with a regular
  Python-visible error rather than crash the process; both
  `errors.InvalidArgumentError` and `ValueError` are accepted.
  """
  tokenizer = SentencepieceTokenizer(self.model)
  accepted_errors = (errors.InvalidArgumentError, ValueError)
  with self.assertRaises(accepted_errors):
    # Everything that touches the op stays inside the context manager so the
    # failure is caught here regardless of which step raises it.
    detokenize_op = gen_sentencepiece_tokenizer.sentencepiece_detokenize_op
    _ = detokenize_op(
        tokenizer._model_resource.resource_handle,
        constant_op.constant([], dtypes.int32),
        constant_op.constant([], dtypes.int64),
        add_bos=False,
        add_eos=False,
        reverse=False)

def testReturnNbestAndDetokenize(self):
sp = SentencepieceTokenizer(
self.model, nbest_size=2, out_type=dtypes.int32, return_nbest=True)
Expand Down