diff --git a/tensorflow/compiler/mlir/lite/schema/schema_utils.cc b/tensorflow/compiler/mlir/lite/schema/schema_utils.cc
index a173380940d..cb61ce6243f 100644
--- a/tensorflow/compiler/mlir/lite/schema/schema_utils.cc
+++ b/tensorflow/compiler/mlir/lite/schema/schema_utils.cc
@@ -15,8 +15,12 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/lite/schema/schema_utils.h"
 
 #include <algorithm>
+#include <complex>
+#include <cstddef>
+#include <cstdint>
 
 #include "tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h"
+#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
 
 namespace tflite {
 
@@ -59,4 +63,51 @@ BuiltinOperator GetBuiltinCode(const OperatorCodeT* op_code) {
       op_code->deprecated_builtin_code));
 }
 
+// Returns the size of the given TensorType in bytes, or 0 if the TensorType is
+// not supported. Kept in sync with TfLiteTypeGetSize in
+// lite/kernels/kernel_util.h; the static_asserts pin the byte widths the
+// returned constants rely on.
+size_t TensorTypeGetSize(::tflite::TensorType data_type) {
+  switch (data_type) {
+    case ::tflite::TensorType_FLOAT32:
+      static_assert(sizeof(float) == 4, "");
+      return 4;
+    case ::tflite::TensorType_FLOAT16:
+      static_assert(sizeof(int16_t) == 2, "");
+      return 2;
+    case ::tflite::TensorType_INT32:
+      static_assert(sizeof(int32_t) == 4, "");
+      return 4;
+    case ::tflite::TensorType_UINT8:
+      static_assert(sizeof(uint8_t) == 1, "");
+      return 1;
+    case ::tflite::TensorType_INT64:
+      static_assert(sizeof(int64_t) == 8, "");
+      return 8;
+    case ::tflite::TensorType_BOOL:
+      return sizeof(bool);
+    case ::tflite::TensorType_INT16:
+      static_assert(sizeof(int16_t) == 2, "");
+      return 2;
+    case ::tflite::TensorType_COMPLEX64:
+      static_assert(sizeof(std::complex<float>) == 8, "");
+      return 8;
+    case ::tflite::TensorType_INT8:
+      static_assert(sizeof(int8_t) == 1, "");
+      return 1;
+    case ::tflite::TensorType_FLOAT64:
+      static_assert(sizeof(double) == 8, "");
+      return 8;
+    case ::tflite::TensorType_COMPLEX128:
+      static_assert(sizeof(std::complex<double>) == 16, "");
+      return 16;
+    case ::tflite::TensorType_UINT64:
+      static_assert(sizeof(uint64_t) == 8, "");
+      return 8;
+    case ::tflite::TensorType_UINT32:
+      static_assert(sizeof(uint32_t) == 4, "");
+      return 4;
+    case ::tflite::TensorType_UINT16:
+      static_assert(sizeof(uint16_t) == 2, "");
+      return 2;
+    default:
+      return 0;
+  }
+}
 }  // namespace tflite
diff --git a/tensorflow/compiler/mlir/lite/schema/schema_utils.h b/tensorflow/compiler/mlir/lite/schema/schema_utils.h
index 7498aa02ebe..9c32680b851 100644
--- a/tensorflow/compiler/mlir/lite/schema/schema_utils.h
+++ b/tensorflow/compiler/mlir/lite/schema/schema_utils.h
@@ -15,6 +15,8 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_
 
+#include <cstddef>
+
 #include "flatbuffers/flatbuffers.h"
 #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
 
@@ -28,6 +30,11 @@ BuiltinOperator GetBuiltinCode(const OperatorCode *op_code);
 
 BuiltinOperator GetBuiltinCode(const OperatorCodeT *op_code);
 
+// Returns the size of the given TensorType in bytes, or 0 if the TensorType is
+// not supported, this function should be aligned with TfLiteTypeGetSize in
+// lite/kernels/kernel_util.h.
+size_t TensorTypeGetSize(::tflite::TensorType data_type);
+
 }  // namespace tflite
 
 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_
diff --git a/tensorflow/lite/kernels/internal/reference/broadcast_to.h b/tensorflow/lite/kernels/internal/reference/broadcast_to.h
index f106b2b52f6..0cd03db926d 100644
--- a/tensorflow/lite/kernels/internal/reference/broadcast_to.h
+++ b/tensorflow/lite/kernels/internal/reference/broadcast_to.h
@@ -15,6 +15,8 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
 
+#include <cstring>
+
 #include "tensorflow/lite/kernels/internal/common.h"
 #include "tensorflow/lite/kernels/kernel_util.h"
 
@@ -83,7 +85,8 @@ inline void BroadcastTo(const RuntimeShape& unextended_input_shape,
   // If non-broadcasting, just copy data from input to output tensor.
   if (last_broadcast_dim == -1) {
     memcpy(output_data, input_data,
-           unextended_input_shape.FlatSize() * TfLiteTypeGetSize(data_type));
+           static_cast<size_t>(unextended_input_shape.FlatSize()) *
+               static_cast<size_t>(TfLiteTypeGetSize(data_type)));
     return;
   }
 
diff --git a/tensorflow/lite/kernels/internal/reference/slice.h b/tensorflow/lite/kernels/internal/reference/slice.h
index cb73ea0d0c4..feddd639584 100644
--- a/tensorflow/lite/kernels/internal/reference/slice.h
+++ b/tensorflow/lite/kernels/internal/reference/slice.h
@@ -15,7 +15,14 @@ limitations under the License.
 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SLICE_H_
 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SLICE_H_
 
+#include <cstdint>
+#include <vector>
+
+#include "tensorflow/lite/core/c/common.h"
 #include "tensorflow/lite/kernels/internal/portable_tensor.h"
+#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
+#include "tensorflow/lite/kernels/internal/runtime_shape.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
 #include "tensorflow/lite/kernels/internal/types.h"
 
 namespace tflite {
@@ -74,6 +81,27 @@ inline void Slice(const tflite::SliceParams& op_params,
   return Slice(op_params, input_shape, output_shape, &writer);
 }
 
+// Slices an int4-packed tensor: unpacks to int8, runs the regular int8 Slice,
+// then repacks the result into the dense int4 output buffer.
+inline void SliceInt4(const tflite::SliceParams& op_params,
+                      const RuntimeShape& input_shape,
+                      const TfLiteTensor* input,
+                      const RuntimeShape& output_shape, TfLiteTensor* output) {
+  const int num_input_elements = input_shape.FlatSize();
+  std::vector<int8_t> unpacked_input(num_input_elements);
+  tensor_utils::UnpackPackedIntToInt8(GetTensorData<int8_t>(input),
+                                      num_input_elements, 4,
+                                      unpacked_input.data());
+
+  const int num_output_elements = output_shape.FlatSize();
+  std::vector<int8_t> unpacked_output(num_output_elements);
+
+  reference_ops::Slice<int8_t>(op_params, input_shape, unpacked_input.data(),
+                               output_shape, unpacked_output.data());
+
+  tensor_utils::PackInt8IntoDenseInt(unpacked_output.data(),
+                                     num_output_elements, 4,
+                                     GetTensorData<int8_t>(output));
+}
+
 }  // namespace reference_ops
 }  // namespace tflite
 