blob: fedf7ee150df5cdf1d260780d3546c9058084efc [file] [log] [blame]
//
// Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include "TestUtils.hpp"
9
10#include <armnn_delegate.hpp>
11
12#include <flatbuffers/flatbuffers.h>
13#include <tensorflow/lite/interpreter.h>
14#include <tensorflow/lite/kernels/register.h>
15#include <tensorflow/lite/model.h>
Teresa Charlinad1b3d72023-03-14 12:10:28 +000016#include <schema_generated.h>
Sadik Armagana2747482021-02-09 10:28:54 +000017#include <tensorflow/lite/version.h>
18
19#include <doctest/doctest.h>
20
21#include <string>
22
23namespace
24{
25
26std::vector<char> CreateReduceTfLiteModel(tflite::BuiltinOperator reduceOperatorCode,
Teresa Charlin4d85adf2022-10-27 11:37:29 +010027 tflite::TensorType tensorType,
28 std::vector<int32_t>& input0TensorShape,
29 std::vector<int32_t>& input1TensorShape,
30 const std::vector <int32_t>& outputTensorShape,
31 std::vector<int32_t>& axisData,
32 const bool keepDims,
33 float quantScale = 1.0f,
34 int quantOffset = 0,
35 bool kTfLiteNoQuantizationForQuantized = false)
Sadik Armagana2747482021-02-09 10:28:54 +000036{
37 using namespace tflite;
38 flatbuffers::FlatBufferBuilder flatBufferBuilder;
39
Ryan OShea238ecd92023-03-07 11:44:23 +000040 flatbuffers::Offset<tflite::Buffer> buffers[4] = {
41 CreateBuffer(flatBufferBuilder),
42 CreateBuffer(flatBufferBuilder),
43 CreateBuffer(flatBufferBuilder,
44 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(axisData.data()),
45 sizeof(int32_t) * axisData.size())),
46 CreateBuffer(flatBufferBuilder)
47 };
Sadik Armagana2747482021-02-09 10:28:54 +000048
Teresa Charlin4d85adf2022-10-27 11:37:29 +010049 flatbuffers::Offset<tflite::QuantizationParameters> quantizationParametersAxis
Ryan OShea238ecd92023-03-07 11:44:23 +000050 = CreateQuantizationParameters(flatBufferBuilder);
Teresa Charlin4d85adf2022-10-27 11:37:29 +010051
52 flatbuffers::Offset<tflite::QuantizationParameters> quantizationParameters;
53
54 if (kTfLiteNoQuantizationForQuantized)
55 {
56 if ((quantScale == 1 || quantScale == 0) && quantOffset == 0)
57 {
58 // Creates quantization parameter with quantization.type = kTfLiteNoQuantization
59 quantizationParameters = CreateQuantizationParameters(flatBufferBuilder);
60 }
61 else
62 {
63 // Creates quantization parameter with quantization.type != kTfLiteNoQuantization
64 quantizationParameters = CreateQuantizationParameters(
65 flatBufferBuilder,
66 0,
67 0,
68 flatBufferBuilder.CreateVector<float>({quantScale}),
69 flatBufferBuilder.CreateVector<int64_t>({quantOffset}));
70 }
71 }
72 else
73 {
74 quantizationParameters = CreateQuantizationParameters(
75 flatBufferBuilder,
76 0,
77 0,
78 flatBufferBuilder.CreateVector<float>({quantScale}),
79 flatBufferBuilder.CreateVector<int64_t>({quantOffset}));
80 }
Sadik Armagana2747482021-02-09 10:28:54 +000081
82 std::array<flatbuffers::Offset<Tensor>, 3> tensors;
83 tensors[0] = CreateTensor(flatBufferBuilder,
84 flatBufferBuilder.CreateVector<int32_t>(input0TensorShape.data(),
85 input0TensorShape.size()),
86 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +000087 1,
Sadik Armagana2747482021-02-09 10:28:54 +000088 flatBufferBuilder.CreateString("input"),
89 quantizationParameters);
90
91 tensors[1] = CreateTensor(flatBufferBuilder,
92 flatBufferBuilder.CreateVector<int32_t>(input1TensorShape.data(),
93 input1TensorShape.size()),
94 ::tflite::TensorType_INT32,
Ryan OShea238ecd92023-03-07 11:44:23 +000095 2,
Sadik Armagana2747482021-02-09 10:28:54 +000096 flatBufferBuilder.CreateString("axis"),
Teresa Charlin4d85adf2022-10-27 11:37:29 +010097 quantizationParametersAxis);
Sadik Armagana2747482021-02-09 10:28:54 +000098
99 // Create output tensor
100 tensors[2] = CreateTensor(flatBufferBuilder,
101 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
102 outputTensorShape.size()),
103 tensorType,
Ryan OShea238ecd92023-03-07 11:44:23 +0000104 3,
Sadik Armagana2747482021-02-09 10:28:54 +0000105 flatBufferBuilder.CreateString("output"),
106 quantizationParameters);
107
Teresa Charlin4d85adf2022-10-27 11:37:29 +0100108 // Create operator. Reduce operations MIN, MAX, SUM, MEAN, PROD uses ReducerOptions.
Sadik Armagana2747482021-02-09 10:28:54 +0000109 tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_ReducerOptions;
110 flatbuffers::Offset<void> operatorBuiltinOptions = CreateReducerOptions(flatBufferBuilder, keepDims).Union();
111
112 const std::vector<int> operatorInputs{ {0, 1} };
113 const std::vector<int> operatorOutputs{ 2 };
114 flatbuffers::Offset <Operator> reduceOperator =
115 CreateOperator(flatBufferBuilder,
116 0,
117 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
118 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
119 operatorBuiltinOptionsType,
120 operatorBuiltinOptions);
121
122 const std::vector<int> subgraphInputs{ {0, 1} };
123 const std::vector<int> subgraphOutputs{ 2 };
124 flatbuffers::Offset <SubGraph> subgraph =
125 CreateSubGraph(flatBufferBuilder,
126 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
127 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
128 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
129 flatBufferBuilder.CreateVector(&reduceOperator, 1));
130
131 flatbuffers::Offset <flatbuffers::String> modelDescription =
132 flatBufferBuilder.CreateString("ArmnnDelegate: Reduce Operator Model");
133 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, reduceOperatorCode);
134
135 flatbuffers::Offset <Model> flatbufferModel =
136 CreateModel(flatBufferBuilder,
137 TFLITE_SCHEMA_VERSION,
138 flatBufferBuilder.CreateVector(&operatorCode, 1),
139 flatBufferBuilder.CreateVector(&subgraph, 1),
140 modelDescription,
Ryan OShea238ecd92023-03-07 11:44:23 +0000141 flatBufferBuilder.CreateVector(buffers, 4));
Sadik Armagana2747482021-02-09 10:28:54 +0000142
143 flatBufferBuilder.Finish(flatbufferModel);
144
145 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
146 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
147}
148
149template <typename T>
150void ReduceTest(tflite::BuiltinOperator reduceOperatorCode,
151 tflite::TensorType tensorType,
152 std::vector<armnn::BackendId>& backends,
153 std::vector<int32_t>& input0Shape,
154 std::vector<int32_t>& input1Shape,
155 std::vector<int32_t>& expectedOutputShape,
156 std::vector<T>& input0Values,
157 std::vector<int32_t>& input1Values,
158 std::vector<T>& expectedOutputValues,
159 const bool keepDims,
160 float quantScale = 1.0f,
161 int quantOffset = 0)
162{
163 using namespace tflite;
Teresa Charlin4d85adf2022-10-27 11:37:29 +0100164 std::vector<char> modelBufferArmNN = CreateReduceTfLiteModel(reduceOperatorCode,
165 tensorType,
166 input0Shape,
167 input1Shape,
168 expectedOutputShape,
169 input1Values,
170 keepDims,
171 quantScale,
172 quantOffset,
173 false);
174 std::vector<char> modelBufferTFLite = CreateReduceTfLiteModel(reduceOperatorCode,
175 tensorType,
176 input0Shape,
177 input1Shape,
178 expectedOutputShape,
179 input1Values,
180 keepDims,
181 quantScale,
182 quantOffset,
183 true);
Sadik Armagana2747482021-02-09 10:28:54 +0000184
Teresa Charlin4d85adf2022-10-27 11:37:29 +0100185 const Model* tfLiteModelArmNN = GetModel(modelBufferArmNN.data());
186 const Model* tfLiteModelTFLite = GetModel(modelBufferTFLite.data());
Sadik Armagana2747482021-02-09 10:28:54 +0000187
188 // Create TfLite Interpreters
189 std::unique_ptr<Interpreter> armnnDelegateInterpreter;
Teresa Charlin4d85adf2022-10-27 11:37:29 +0100190 CHECK(InterpreterBuilder(tfLiteModelArmNN, ::tflite::ops::builtin::BuiltinOpResolver())
Sadik Armagana2747482021-02-09 10:28:54 +0000191 (&armnnDelegateInterpreter) == kTfLiteOk);
192 CHECK(armnnDelegateInterpreter != nullptr);
193 CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);
194
195 std::unique_ptr<Interpreter> tfLiteInterpreter;
Teresa Charlin4d85adf2022-10-27 11:37:29 +0100196 CHECK(InterpreterBuilder(tfLiteModelTFLite, ::tflite::ops::builtin::BuiltinOpResolver())
Sadik Armagana2747482021-02-09 10:28:54 +0000197 (&tfLiteInterpreter) == kTfLiteOk);
198 CHECK(tfLiteInterpreter != nullptr);
199 CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);
200
201 // Create the ArmNN Delegate
202 armnnDelegate::DelegateOptions delegateOptions(backends);
203 std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
204 theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
205 armnnDelegate::TfLiteArmnnDelegateDelete);
206 CHECK(theArmnnDelegate != nullptr);
207
208 // Modify armnnDelegateInterpreter to use armnnDelegate
209 CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
210
211 // Set input data
212 armnnDelegate::FillInput<T>(tfLiteInterpreter, 0, input0Values);
213 armnnDelegate::FillInput<T>(armnnDelegateInterpreter, 0, input0Values);
214
215 // Run EnqueWorkload
216 CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
217 CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);
218
219 // Compare output data
220 armnnDelegate::CompareOutputData<T>(tfLiteInterpreter,
221 armnnDelegateInterpreter,
222 expectedOutputShape,
223 expectedOutputValues);
224
225 armnnDelegateInterpreter.reset(nullptr);
226}
227
228} // anonymous namespace