blob: 396c75c22e59d9253f2c28cd0fe0e8b4089128d6 [file] [log] [blame]
Matthew Sloyanebe392d2023-03-30 10:12:08 +01001//
2// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include <armnn/Exceptions.hpp>

#include <tensorflow/lite/core/c/c_api.h>
#include <tensorflow/lite/kernels/custom_ops_register.h>
#include <tensorflow/lite/kernels/register.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <string>
#include <type_traits>
#include <vector>
15
16namespace delegateTestInterpreter
17{
18
19inline TfLiteTensor* GetInputTensorFromInterpreter(TfLiteInterpreter* interpreter, int index)
20{
21 TfLiteTensor* inputTensor = TfLiteInterpreterGetInputTensor(interpreter, index);
22 if(inputTensor == nullptr)
23 {
24 throw armnn::Exception("Input tensor was not found at the given index: " + std::to_string(index));
25 }
26 return inputTensor;
27}
28
29inline const TfLiteTensor* GetOutputTensorFromInterpreter(TfLiteInterpreter* interpreter, int index)
30{
31 const TfLiteTensor* outputTensor = TfLiteInterpreterGetOutputTensor(interpreter, index);
32 if(outputTensor == nullptr)
33 {
34 throw armnn::Exception("Output tensor was not found at the given index: " + std::to_string(index));
35 }
36 return outputTensor;
37}
38
39inline TfLiteModel* CreateTfLiteModel(std::vector<char>& data)
40{
41 TfLiteModel* tfLiteModel = TfLiteModelCreate(data.data(), data.size());
42 if(tfLiteModel == nullptr)
43 {
44 throw armnn::Exception("An error has occurred when creating the TfLiteModel.");
45 }
46 return tfLiteModel;
47}
48
49inline TfLiteInterpreterOptions* CreateTfLiteInterpreterOptions()
50{
51 TfLiteInterpreterOptions* options = TfLiteInterpreterOptionsCreate();
52 if(options == nullptr)
53 {
54 throw armnn::Exception("An error has occurred when creating the TfLiteInterpreterOptions.");
55 }
56 return options;
57}
58
59inline tflite::ops::builtin::BuiltinOpResolver GenerateCustomOpResolver(const std::string& opName)
60{
61 tflite::ops::builtin::BuiltinOpResolver opResolver;
62 if (opName == "MaxPool3D")
63 {
64 opResolver.AddCustom("MaxPool3D", tflite::ops::custom::Register_MAX_POOL_3D());
65 }
66 else if (opName == "AveragePool3D")
67 {
68 opResolver.AddCustom("AveragePool3D", tflite::ops::custom::Register_AVG_POOL_3D());
69 }
70 else
71 {
72 throw armnn::Exception("The custom op isn't supported by the DelegateTestInterpreter.");
73 }
74 return opResolver;
75}
76
77template<typename T>
78inline TfLiteStatus CopyFromBufferToTensor(TfLiteTensor* tensor, std::vector<T>& values)
79{
80 // Make sure there is enough bytes allocated to copy into for uint8_t and int16_t case.
81 if(tensor->bytes < values.size() * sizeof(T))
82 {
83 throw armnn::Exception("Tensor has not been allocated to match number of values.");
84 }
85
86 // Requires uint8_t and int16_t specific case as the number of bytes is larger than values passed when creating
87 // TFLite tensors of these types. Otherwise, use generic TfLiteTensorCopyFromBuffer function.
88 TfLiteStatus status = kTfLiteOk;
89 if (std::is_same<T, uint8_t>::value)
90 {
91 for (unsigned int i = 0; i < values.size(); ++i)
92 {
93 tensor->data.uint8[i] = values[i];
94 }
95 }
96 else if (std::is_same<T, int16_t>::value)
97 {
98 for (unsigned int i = 0; i < values.size(); ++i)
99 {
100 tensor->data.i16[i] = values[i];
101 }
102 }
103 else
104 {
105 status = TfLiteTensorCopyFromBuffer(tensor, values.data(), values.size() * sizeof(T));
106 }
107 return status;
108}
109
110} // anonymous namespace