//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/interpreter.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/model.h>
#include <tensorflow/lite/schema/schema_generated.h>
#include <tensorflow/lite/version.h>

#include <doctest/doctest.h>

#include <array>
#include <string>
#include <vector>

namespace
{

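// Builds, in memory, a FlatBuffer TfLite model containing a single reduce
// operator with two inputs (data and axis) and one output. The axis values
// are stored as a constant buffer. When kTfLiteNoQuantizationForQuantized is
// set and the scale/offset pair is trivial, the data tensors are given empty
// quantization parameters (kTfLiteNoQuantization) instead of explicit ones.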
std::vector<char> CreateReduceTfLiteModel(tflite::BuiltinOperator reduceOperatorCode,
                                          tflite::TensorType tensorType,
                                          std::vector<int32_t>& input0TensorShape,
                                          std::vector<int32_t>& input1TensorShape,
                                          const std::vector<int32_t>& outputTensorShape,
                                          std::vector<int32_t>& axisData,
                                          const bool keepDims,
                                          float quantScale = 1.0f,
                                          int quantOffset = 0,
                                          bool kTfLiteNoQuantizationForQuantized = false)
{
    using namespace tflite;
    flatbuffers::FlatBufferBuilder flatBufferBuilder;

    // Buffer 0 is the empty buffer that the TfLite schema reserves for tensors
    // without constant data; buffer 1 holds the constant axis values.
    std::array<flatbuffers::Offset<tflite::Buffer>, 2> buffers;
    buffers[0] = CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({}));
    buffers[1] = CreateBuffer(flatBufferBuilder,
                              flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(axisData.data()),
                                                             sizeof(int32_t) * axisData.size()));

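    // Quantization parameters: the axis tensor always gets empty (default)
    // parameters, while the data tensors get either explicit scale/offset
    // values or, when requested and the pair is trivial, empty parameters so
    // that the runtime sees kTfLiteNoQuantization.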
    flatbuffers::Offset<tflite::QuantizationParameters> quantizationParametersAxis
            = CreateQuantizationParameters(flatBufferBuilder);

    flatbuffers::Offset<tflite::QuantizationParameters> quantizationParameters;

    if (kTfLiteNoQuantizationForQuantized)
    {
        if ((quantScale == 1 || quantScale == 0) && quantOffset == 0)
        {
            // Creates quantization parameters with quantization.type = kTfLiteNoQuantization
            quantizationParameters = CreateQuantizationParameters(flatBufferBuilder);
        }
        else
        {
            // Creates quantization parameters with quantization.type != kTfLiteNoQuantization
            quantizationParameters = CreateQuantizationParameters(
                    flatBufferBuilder,
                    0,
                    0,
                    flatBufferBuilder.CreateVector<float>({quantScale}),
                    flatBufferBuilder.CreateVector<int64_t>({quantOffset}));
        }
    }
    else
    {
        quantizationParameters = CreateQuantizationParameters(
                flatBufferBuilder,
                0,
                0,
                flatBufferBuilder.CreateVector<float>({quantScale}),
                flatBufferBuilder.CreateVector<int64_t>({quantOffset}));
    }

    std::array<flatbuffers::Offset<Tensor>, 3> tensors;

    // Create input tensor
    tensors[0] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(input0TensorShape.data(),
                                                                      input0TensorShape.size()),
                              tensorType,
                              0,
                              flatBufferBuilder.CreateString("input"),
                              quantizationParameters);

    // Create axis tensor, backed by the constant axis buffer (buffer 1)
    tensors[1] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(input1TensorShape.data(),
                                                                      input1TensorShape.size()),
                              ::tflite::TensorType_INT32,
                              1,
                              flatBufferBuilder.CreateString("axis"),
                              quantizationParametersAxis);

    // Create output tensor
    tensors[2] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
                                                                      outputTensorShape.size()),
                              tensorType,
                              0,
                              flatBufferBuilder.CreateString("output"),
                              quantizationParameters);

    // Create operator. The reduce operations MIN, MAX, SUM, MEAN and PROD all use ReducerOptions.
    tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_ReducerOptions;
    flatbuffers::Offset<void> operatorBuiltinOptions = CreateReducerOptions(flatBufferBuilder, keepDims).Union();

    const std::vector<int> operatorInputs{ 0, 1 };
    const std::vector<int> operatorOutputs{ 2 };
    flatbuffers::Offset<Operator> reduceOperator =
        CreateOperator(flatBufferBuilder,
                       0,
                       flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                       operatorBuiltinOptionsType,
                       operatorBuiltinOptions);

    const std::vector<int> subgraphInputs{ 0, 1 };
    const std::vector<int> subgraphOutputs{ 2 };
    flatbuffers::Offset<SubGraph> subgraph =
        CreateSubGraph(flatBufferBuilder,
                       flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
                       flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
                       flatBufferBuilder.CreateVector(&reduceOperator, 1));

    flatbuffers::Offset<flatbuffers::String> modelDescription =
        flatBufferBuilder.CreateString("ArmnnDelegate: Reduce Operator Model");
    flatbuffers::Offset<OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, reduceOperatorCode);

    flatbuffers::Offset<Model> flatbufferModel =
        CreateModel(flatBufferBuilder,
                    TFLITE_SCHEMA_VERSION,
                    flatBufferBuilder.CreateVector(&operatorCode, 1),
                    flatBufferBuilder.CreateVector(&subgraph, 1),
                    modelDescription,
                    flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));

    flatBufferBuilder.Finish(flatbufferModel);

    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
}

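// Builds the same reduce model twice - with explicit quantization parameters
// for the Arm NN delegate run, and with kTfLiteNoQuantizationForQuantized set
// for the reference TfLite run - executes both interpreters on identical
// input data, and compares their outputs against each other and against the
// expected values.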
template <typename T>
void ReduceTest(tflite::BuiltinOperator reduceOperatorCode,
                tflite::TensorType tensorType,
                std::vector<armnn::BackendId>& backends,
                std::vector<int32_t>& input0Shape,
                std::vector<int32_t>& input1Shape,
                std::vector<int32_t>& expectedOutputShape,
                std::vector<T>& input0Values,
                std::vector<int32_t>& input1Values,
                std::vector<T>& expectedOutputValues,
                const bool keepDims,
                float quantScale = 1.0f,
                int quantOffset = 0)
{
    using namespace tflite;
    std::vector<char> modelBufferArmNN = CreateReduceTfLiteModel(reduceOperatorCode,
                                                                 tensorType,
                                                                 input0Shape,
                                                                 input1Shape,
                                                                 expectedOutputShape,
                                                                 input1Values,
                                                                 keepDims,
                                                                 quantScale,
                                                                 quantOffset,
                                                                 false);
    std::vector<char> modelBufferTFLite = CreateReduceTfLiteModel(reduceOperatorCode,
                                                                  tensorType,
                                                                  input0Shape,
                                                                  input1Shape,
                                                                  expectedOutputShape,
                                                                  input1Values,
                                                                  keepDims,
                                                                  quantScale,
                                                                  quantOffset,
                                                                  true);

    const Model* tfLiteModelArmNN = GetModel(modelBufferArmNN.data());
    const Model* tfLiteModelTFLite = GetModel(modelBufferTFLite.data());

    // Create TfLite Interpreters
    std::unique_ptr<Interpreter> armnnDelegateInterpreter;
    CHECK(InterpreterBuilder(tfLiteModelArmNN, ::tflite::ops::builtin::BuiltinOpResolver())
              (&armnnDelegateInterpreter) == kTfLiteOk);
    CHECK(armnnDelegateInterpreter != nullptr);
    CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);

    std::unique_ptr<Interpreter> tfLiteInterpreter;
    CHECK(InterpreterBuilder(tfLiteModelTFLite, ::tflite::ops::builtin::BuiltinOpResolver())
              (&tfLiteInterpreter) == kTfLiteOk);
    CHECK(tfLiteInterpreter != nullptr);
    CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);

    // Create the ArmNN Delegate
    armnnDelegate::DelegateOptions delegateOptions(backends);
    std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
        theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
                         armnnDelegate::TfLiteArmnnDelegateDelete);
    CHECK(theArmnnDelegate != nullptr);

    // Modify armnnDelegateInterpreter to use armnnDelegate
    CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);

    // Set input data
    armnnDelegate::FillInput<T>(tfLiteInterpreter, 0, input0Values);
    armnnDelegate::FillInput<T>(armnnDelegateInterpreter, 0, input0Values);

    // Run EnqueueWorkload
    CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
    CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);

    // Compare output data
    armnnDelegate::CompareOutputData<T>(tfLiteInterpreter,
                                        armnnDelegateInterpreter,
                                        expectedOutputShape,
                                        expectedOutputValues);

    armnnDelegateInterpreter.reset(nullptr);
}

} // anonymous namespace
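
// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original helper): shows how a
// test case might drive ReduceTest for a float32 MEAN reduction on the CpuRef
// backend. The test name, shapes, and values below are assumptions chosen for
// illustration, not an existing Arm NN test case.
// ---------------------------------------------------------------------------
TEST_CASE ("Mean_Fp32_KeepDims_CpuRef_UsageSketch")
{
    std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };

    std::vector<int32_t> input0Shape { 1, 1, 2, 3 };
    std::vector<int32_t> input1Shape { 1 };          // one reduction axis
    std::vector<int32_t> expectedOutputShape { 1, 1, 1, 3 };

    std::vector<float>   input0Values { 1.0f, 2.0f, 3.0f,
                                        4.0f, 5.0f, 6.0f };
    std::vector<int32_t> input1Values { 2 };         // reduce over axis 2
    // Mean over axis 2: (1+4)/2, (2+5)/2, (3+6)/2
    std::vector<float>   expectedOutputValues { 2.5f, 3.5f, 4.5f };

    ReduceTest<float>(tflite::BuiltinOperator_MEAN,
                      ::tflite::TensorType_FLOAT32,
                      backends,
                      input0Shape,
                      input1Shape,
                      expectedOutputShape,
                      input0Values,
                      input1Values,
                      expectedOutputValues,
                      true);                         // keepDims
}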