Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_CONCATENATION_H_

#include <algorithm>
#include <cstddef>

#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
Expand Down Expand Up @@ -109,7 +110,7 @@ inline void Concatenation<Int4>(const ConcatenationParams& params,
// not garbage.
// Note: output_shape.FlatSize() gives number of elements (nibbles).
// Bytes needed: (elements + 1) / 2.
memset(output_ptr, 0, (output_shape.FlatSize() + 1) / 2);
memset(output_ptr, 0, (static_cast<size_t>(output_shape.FlatSize()) + 1) / 2);

int64_t output_offset = 0;
for (int k = 0; k < outer_size; k++) {
Expand Down
23 changes: 23 additions & 0 deletions tensorflow/lite/kernels/kernel_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ limitations under the License.
#ifndef TF_LITE_STATIC_MEMORY
#include <string>

#include "absl/types/span.h"
#include "tensorflow/lite/array.h"
#endif // TF_LITE_STATIC_MEMORY

Expand All @@ -33,6 +34,7 @@ limitations under the License.
#include "tensorflow/lite/core/c/common.h"
#include "tensorflow/lite/kernels/internal/cppmath.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/util.h"

#if defined(__APPLE__)
#include "TargetConditionals.h"
Expand Down Expand Up @@ -595,4 +597,25 @@ bool HasUnspecifiedDimension(const TfLiteTensor* tensor) {
return false;
}

TfLiteStatus CheckedShapeProduct(TfLiteContext* context,
absl::Span<const int> dims,
const char* error_message, size_t& product) {
// The CheckedNumElements function already checks for negative dimensions, so
// we don't do it here.
TF_LITE_ENSURE_MSG(context, CheckedNumElements(dims, product) == kTfLiteOk,
"%s", error_message);
return kTfLiteOk;
}

TfLiteStatus CheckedShapeProductToInt(TfLiteContext* context,
absl::Span<const int> dims,
const char* error_message, int& product) {
for (const int dim : dims) {
TF_LITE_ENSURE_MSG(context, dim >= 0, "Encountered a negative dimension.");
}
TF_LITE_ENSURE_MSG(context, CheckedNumElements(dims, product) == kTfLiteOk,
"%s", error_message);
return kTfLiteOk;
}

} // namespace tflite
26 changes: 26 additions & 0 deletions tensorflow/lite/kernels/kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ limitations under the License.

#include <stdint.h>

#include <cstddef>
#include <limits>
#ifndef TF_LITE_STATIC_MEMORY
#include <string>
#endif // TF_LITE_STATIC_MEMORY

#include "absl/types/span.h"
#include "tensorflow/lite/core/c/builtin_op_data.h"
#include "tensorflow/lite/core/c/common.h"
#ifndef NDEBUG
Expand Down Expand Up @@ -341,6 +343,30 @@ bool IsMobilePlatform();
// Returns whether there is unspecified dimension in the tensor's dim signature.
bool HasUnspecifiedDimension(const TfLiteTensor* tensor);

/**
* Calculates the product of the given dimensions. Returns an error if any of
* the dimensions is negative or if the product overflows.
* @param context The context to use for error reporting.
* @param dims The dimensions to multiply.
* @param error_message The error message to use if an error is encountered.
* @param product The output parameter to store the product.
*/
TfLiteStatus CheckedShapeProduct(TfLiteContext* context,
absl::Span<const int> dims,
const char* error_message, size_t& product);

/**
* Calculates the product of the given dimensions. Returns an error if any of
* the dimensions is negative or if the product overflows.
* @param context The context to use for error reporting.
* @param dims The dimensions to multiply.
* @param error_message The error message to use if an error is encountered.
* @param product The output parameter to store the product.
*/
TfLiteStatus CheckedShapeProductToInt(TfLiteContext* context,
absl::Span<const int> dims,
const char* error_message, int& product);

} // namespace tflite

#endif // TENSORFLOW_LITE_KERNELS_KERNEL_UTIL_H_
6 changes: 3 additions & 3 deletions tensorflow/lite/tools/flatbuffer_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import subprocess
import sys

from tflite_micro.tensorflow.lite.python import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import
from tflite_micro.tensorflow.lite.tools import flatbuffer_utils
from tflite_micro.tensorflow.lite.tools import test_utils
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.python import schema_py_generated as schema # pylint:disable=g-direct-tensorflow-import
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import flatbuffer_utils
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import test_utils
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test

Expand Down
2 changes: 1 addition & 1 deletion tensorflow/lite/tools/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""

import flatbuffers
from tflite_micro.tensorflow.lite.python import schema_py_generated as schema_fb
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.python import schema_py_generated as schema_fb

TFLITE_SCHEMA_VERSION = 3

Expand Down
4 changes: 2 additions & 2 deletions tensorflow/lite/tools/visualize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import os
import re

from tflite_micro.tensorflow.lite.tools import test_utils
from tflite_micro.tensorflow.lite.tools import visualize
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import test_utils
from tflite_micro.tensorflow.lite_micro.tensorflow.lite.tools import visualize
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test

Expand Down
Loading