diff --git a/tensorflow/lite/kernels/internal/reference/concatenation.h b/tensorflow/lite/kernels/internal/reference/concatenation.h index 4a82d7c502d..915492b1e92 100644 --- a/tensorflow/lite/kernels/internal/reference/concatenation.h +++ b/tensorflow/lite/kernels/internal/reference/concatenation.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONCATENATION_H_ #include +#include #include "tensorflow/lite/kernels/internal/common.h" #include "tensorflow/lite/kernels/internal/compatibility.h" @@ -109,7 +110,7 @@ inline void Concatenation(const ConcatenationParams& params, // not garbage. // Note: output_shape.FlatSize() gives number of elements (nibbles). // Bytes needed: (elements + 1) / 2. - memset(output_ptr, 0, (output_shape.FlatSize() + 1) / 2); + memset(output_ptr, 0, (static_cast(output_shape.FlatSize()) + 1) / 2); int64_t output_offset = 0; for (int k = 0; k < outer_size; k++) { diff --git a/tensorflow/lite/kernels/kernel_util.cc b/tensorflow/lite/kernels/kernel_util.cc index 62feffc1c0a..1554f2f2e5a 100644 --- a/tensorflow/lite/kernels/kernel_util.cc +++ b/tensorflow/lite/kernels/kernel_util.cc @@ -25,6 +25,7 @@ limitations under the License. #ifndef TF_LITE_STATIC_MEMORY #include +#include "absl/types/span.h" #include "tensorflow/lite/array.h" #endif // TF_LITE_STATIC_MEMORY @@ -33,6 +34,7 @@ limitations under the License. #include "tensorflow/lite/core/c/common.h" #include "tensorflow/lite/kernels/internal/cppmath.h" #include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/util.h" #if defined(__APPLE__) #include "TargetConditionals.h" @@ -595,4 +597,25 @@ bool HasUnspecifiedDimension(const TfLiteTensor* tensor) { return false; } +TfLiteStatus CheckedShapeProduct(TfLiteContext* context, + absl::Span dims, + const char* error_message, size_t& product) { + // The CheckedNumElements function already checks for negative dimensions, so + // we don't do it here. + TF_LITE_ENSURE_MSG(context, CheckedNumElements(dims, product) == kTfLiteOk, + "%s", error_message); + return kTfLiteOk; +} + +TfLiteStatus CheckedShapeProductToInt(TfLiteContext* context, + absl::Span dims, + const char* error_message, int& product) { + for (const int dim : dims) { + TF_LITE_ENSURE_MSG(context, dim >= 0, "Encountered a negative dimension."); + } + TF_LITE_ENSURE_MSG(context, CheckedNumElements(dims, product) == kTfLiteOk, + "%s", error_message); + return kTfLiteOk; +} + } // namespace tflite diff --git a/tensorflow/lite/kernels/kernel_util.h b/tensorflow/lite/kernels/kernel_util.h index 25e5386ccb6..6b649cc8e9b 100644 --- a/tensorflow/lite/kernels/kernel_util.h +++ b/tensorflow/lite/kernels/kernel_util.h @@ -17,11 +17,13 @@ limitations under the License. #include +#include #include #ifndef TF_LITE_STATIC_MEMORY #include #endif // TF_LITE_STATIC_MEMORY +#include "absl/types/span.h" #include "tensorflow/lite/core/c/builtin_op_data.h" #include "tensorflow/lite/core/c/common.h" #ifndef NDEBUG @@ -341,6 +343,30 @@ bool IsMobilePlatform(); // Returns whether there is unspecified dimension in the tensor's dim signature. bool HasUnspecifiedDimension(const TfLiteTensor* tensor); +/** + * Calculates the product of the given dimensions. Returns an error if any of + * the dimensions is negative or if the product overflows. + * @param context The context to use for error reporting. + * @param dims The dimensions to multiply. + * @param error_message The error message to use if an error is encountered. + * @param product The output parameter to store the product. + */ +TfLiteStatus CheckedShapeProduct(TfLiteContext* context, + absl::Span dims, + const char* error_message, size_t& product); + +/** + * Calculates the product of the given dimensions. Returns an error if any of + * the dimensions is negative or if the product overflows. + * @param context The context to use for error reporting. + * @param dims The dimensions to multiply. + * @param error_message The error message to use if an error is encountered. + * @param product The output parameter to store the product. + */ +TfLiteStatus CheckedShapeProductToInt(TfLiteContext* context, + absl::Span dims, + const char* error_message, int& product); + } // namespace tflite #endif // TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_ diff --git a/tensorflow/lite/tools/flatbuffer_utils_test.py b/tensorflow/lite/tools/flatbuffer_utils_test.py index 13074aaca5e..e8a2e46b9be 100644 --- a/tensorflow/lite/tools/flatbuffer_utils_test.py +++ b/tensorflow/lite/tools/flatbuffer_utils_test.py @@ -18,9 +18,9 @@ import subprocess import sys -from tflite_micro.tensorflow.lite.python import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import -from tflite_micro.tensorflow.lite.tools import flatbuffer_utils -from tflite_micro.tensorflow.lite.tools import test_utils +from tflite_micro.tensorflow.lite_micro.tensorflow.lite.python import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import +from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import flatbuffer_utils +from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import test_utils from tensorflow.python.framework import test_util from tensorflow.python.platform import test diff --git a/tensorflow/lite/tools/test_utils.py b/tensorflow/lite/tools/test_utils.py index 44157143d5d..582fbd2879b 100644 --- a/tensorflow/lite/tools/test_utils.py +++ b/tensorflow/lite/tools/test_utils.py @@ -18,7 +18,7 @@ """ import flatbuffers -from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb +from tflite_micro.tensorflow.lite_micro.tensorflow.lite.python import schema_py_generated as schema_fb TFLITE_SCHEMA_VERSION = 3 diff --git a/tensorflow/lite/tools/visualize_test.py b/tensorflow/lite/tools/visualize_test.py index 68de38cc9d7..4cbb01f2b58 100644 --- a/tensorflow/lite/tools/visualize_test.py +++ b/tensorflow/lite/tools/visualize_test.py @@ -16,8 +16,8 @@ import os import re -from tflite_micro.tensorflow.lite.tools import test_utils -from tflite_micro.tensorflow.lite.tools import visualize +from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import test_utils +from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import visualize from tensorflow.python.framework import test_util from tensorflow.python.platform import test