blob: 69b0c88dc875e25cc48400d3ee8293721c2fb99d [file] [log] [blame]
Sadik Armagan67e95f22020-10-29 16:14:54 +00001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Sadik Armaganf7ac72c2021-05-05 15:03:50 +01008#include "TestUtils.hpp"
9
Sadik Armagan67e95f22020-10-29 16:14:54 +000010#include <armnn_delegate.hpp>
11
12#include <flatbuffers/flatbuffers.h>
13#include <tensorflow/lite/interpreter.h>
14#include <tensorflow/lite/kernels/register.h>
15#include <tensorflow/lite/model.h>
16#include <tensorflow/lite/schema/schema_generated.h>
17#include <tensorflow/lite/version.h>
18
19#include <doctest/doctest.h>
20
21namespace
22{
23
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010024template <typename T>
Sadik Armagan67e95f22020-10-29 16:14:54 +000025std::vector<char> CreateElementwiseBinaryTfLiteModel(tflite::BuiltinOperator binaryOperatorCode,
26 tflite::ActivationFunctionType activationType,
27 tflite::TensorType tensorType,
28 const std::vector <int32_t>& input0TensorShape,
29 const std::vector <int32_t>& input1TensorShape,
Sadik Armagan21a94ff2020-11-09 08:38:30 +000030 const std::vector <int32_t>& outputTensorShape,
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010031 std::vector<T>& input1Values,
32 bool constantInput = false,
Sadik Armagan21a94ff2020-11-09 08:38:30 +000033 float quantScale = 1.0f,
34 int quantOffset = 0)
Sadik Armagan67e95f22020-10-29 16:14:54 +000035{
36 using namespace tflite;
37 flatbuffers::FlatBufferBuilder flatBufferBuilder;
38
39 std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
40 buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010041 if (constantInput)
42 {
43 buffers.push_back(
44 CreateBuffer(flatBufferBuilder,
45 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(input1Values.data()),
46 sizeof(T) * input1Values.size())));
47 }
48 else
49 {
50 buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
51 }
52 buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
Sadik Armagan67e95f22020-10-29 16:14:54 +000053
Sadik Armagan21a94ff2020-11-09 08:38:30 +000054 auto quantizationParameters =
55 CreateQuantizationParameters(flatBufferBuilder,
56 0,
57 0,
58 flatBufferBuilder.CreateVector<float>({ quantScale }),
59 flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
60
61
Sadik Armagan67e95f22020-10-29 16:14:54 +000062 std::array<flatbuffers::Offset<Tensor>, 3> tensors;
63 tensors[0] = CreateTensor(flatBufferBuilder,
64 flatBufferBuilder.CreateVector<int32_t>(input0TensorShape.data(),
65 input0TensorShape.size()),
Sadik Armagan21a94ff2020-11-09 08:38:30 +000066 tensorType,
67 0,
68 flatBufferBuilder.CreateString("input_0"),
69 quantizationParameters);
Sadik Armagan67e95f22020-10-29 16:14:54 +000070 tensors[1] = CreateTensor(flatBufferBuilder,
71 flatBufferBuilder.CreateVector<int32_t>(input1TensorShape.data(),
72 input1TensorShape.size()),
Sadik Armagan21a94ff2020-11-09 08:38:30 +000073 tensorType,
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010074 1,
Sadik Armagan21a94ff2020-11-09 08:38:30 +000075 flatBufferBuilder.CreateString("input_1"),
76 quantizationParameters);
Sadik Armagan67e95f22020-10-29 16:14:54 +000077 tensors[2] = CreateTensor(flatBufferBuilder,
78 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
79 outputTensorShape.size()),
Sadik Armagan21a94ff2020-11-09 08:38:30 +000080 tensorType,
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010081 2,
Sadik Armagan21a94ff2020-11-09 08:38:30 +000082 flatBufferBuilder.CreateString("output"),
83 quantizationParameters);
Sadik Armagan67e95f22020-10-29 16:14:54 +000084
85 // create operator
86 tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_NONE;
87 flatbuffers::Offset<void> operatorBuiltinOptions = 0;
88 switch (binaryOperatorCode)
89 {
90 case BuiltinOperator_ADD:
91 {
92 operatorBuiltinOptionsType = BuiltinOptions_AddOptions;
93 operatorBuiltinOptions = CreateAddOptions(flatBufferBuilder, activationType).Union();
94 break;
95 }
96 case BuiltinOperator_DIV:
97 {
98 operatorBuiltinOptionsType = BuiltinOptions_DivOptions;
99 operatorBuiltinOptions = CreateDivOptions(flatBufferBuilder, activationType).Union();
100 break;
101 }
Sadik Armagan21a94ff2020-11-09 08:38:30 +0000102 case BuiltinOperator_MAXIMUM:
103 {
104 operatorBuiltinOptionsType = BuiltinOptions_MaximumMinimumOptions;
105 operatorBuiltinOptions = CreateMaximumMinimumOptions(flatBufferBuilder).Union();
106 break;
107 }
108 case BuiltinOperator_MINIMUM:
109 {
110 operatorBuiltinOptionsType = BuiltinOptions_MaximumMinimumOptions;
111 operatorBuiltinOptions = CreateMaximumMinimumOptions(flatBufferBuilder).Union();
112 break;
113 }
Sadik Armagan67e95f22020-10-29 16:14:54 +0000114 case BuiltinOperator_MUL:
115 {
116 operatorBuiltinOptionsType = BuiltinOptions_MulOptions;
117 operatorBuiltinOptions = CreateMulOptions(flatBufferBuilder, activationType).Union();
118 break;
119 }
120 case BuiltinOperator_SUB:
121 {
122 operatorBuiltinOptionsType = BuiltinOptions_SubOptions;
123 operatorBuiltinOptions = CreateSubOptions(flatBufferBuilder, activationType).Union();
124 break;
125 }
Jim Flynn4b2f3472021-10-13 21:20:07 +0100126 case BuiltinOperator_FLOOR_DIV:
127 {
128 operatorBuiltinOptionsType = tflite::BuiltinOptions_FloorDivOptions;
129 operatorBuiltinOptions = CreateSubOptions(flatBufferBuilder, activationType).Union();
130 break;
131 }
Sadik Armagan67e95f22020-10-29 16:14:54 +0000132 default:
133 break;
134 }
Keith Davis892fafe2020-11-26 17:40:35 +0000135 const std::vector<int32_t> operatorInputs{0, 1};
136 const std::vector<int32_t> operatorOutputs{2};
Sadik Armagan67e95f22020-10-29 16:14:54 +0000137 flatbuffers::Offset <Operator> elementwiseBinaryOperator =
138 CreateOperator(flatBufferBuilder,
139 0,
140 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
141 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
142 operatorBuiltinOptionsType,
143 operatorBuiltinOptions);
144
Keith Davis892fafe2020-11-26 17:40:35 +0000145 const std::vector<int> subgraphInputs{0, 1};
146 const std::vector<int> subgraphOutputs{2};
Sadik Armagan67e95f22020-10-29 16:14:54 +0000147 flatbuffers::Offset <SubGraph> subgraph =
148 CreateSubGraph(flatBufferBuilder,
149 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
150 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
151 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
152 flatBufferBuilder.CreateVector(&elementwiseBinaryOperator, 1));
153
154 flatbuffers::Offset <flatbuffers::String> modelDescription =
155 flatBufferBuilder.CreateString("ArmnnDelegate: Elementwise Binary Operator Model");
156 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, binaryOperatorCode);
157
158 flatbuffers::Offset <Model> flatbufferModel =
159 CreateModel(flatBufferBuilder,
160 TFLITE_SCHEMA_VERSION,
161 flatBufferBuilder.CreateVector(&operatorCode, 1),
162 flatBufferBuilder.CreateVector(&subgraph, 1),
163 modelDescription,
164 flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
165
166 flatBufferBuilder.Finish(flatbufferModel);
167
168 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
169 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
170}
171
Sadik Armagan21a94ff2020-11-09 08:38:30 +0000172template <typename T>
173void ElementwiseBinaryTest(tflite::BuiltinOperator binaryOperatorCode,
174 tflite::ActivationFunctionType activationType,
175 tflite::TensorType tensorType,
176 std::vector<armnn::BackendId>& backends,
177 std::vector<int32_t>& input0Shape,
178 std::vector<int32_t>& input1Shape,
179 std::vector<int32_t>& outputShape,
180 std::vector<T>& input0Values,
181 std::vector<T>& input1Values,
182 std::vector<T>& expectedOutputValues,
183 float quantScale = 1.0f,
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100184 int quantOffset = 0,
185 bool constantInput = false)
Sadik Armagan67e95f22020-10-29 16:14:54 +0000186{
187 using namespace tflite;
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100188 std::vector<char> modelBuffer = CreateElementwiseBinaryTfLiteModel<T>(binaryOperatorCode,
189 activationType,
190 tensorType,
191 input0Shape,
192 input1Shape,
193 outputShape,
194 input1Values,
195 constantInput,
196 quantScale,
197 quantOffset);
Sadik Armagan67e95f22020-10-29 16:14:54 +0000198
199 const Model* tfLiteModel = GetModel(modelBuffer.data());
200 // Create TfLite Interpreters
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100201 std::unique_ptr <Interpreter> armnnDelegateInterpreter;
Sadik Armagan67e95f22020-10-29 16:14:54 +0000202 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
203 (&armnnDelegateInterpreter) == kTfLiteOk);
204 CHECK(armnnDelegateInterpreter != nullptr);
205 CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);
206
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100207 std::unique_ptr <Interpreter> tfLiteInterpreter;
Sadik Armagan67e95f22020-10-29 16:14:54 +0000208 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
209 (&tfLiteInterpreter) == kTfLiteOk);
210 CHECK(tfLiteInterpreter != nullptr);
211 CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);
212
213 // Create the ArmNN Delegate
214 armnnDelegate::DelegateOptions delegateOptions(backends);
215 std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100216 theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
217 armnnDelegate::TfLiteArmnnDelegateDelete);
Sadik Armagan67e95f22020-10-29 16:14:54 +0000218 CHECK(theArmnnDelegate != nullptr);
219 // Modify armnnDelegateInterpreter to use armnnDelegate
220 CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
221
222 // Set input data
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100223 armnnDelegate::FillInput<T>(tfLiteInterpreter, 0, input0Values);
224 armnnDelegate::FillInput<T>(armnnDelegateInterpreter, 0, input0Values);
225 if (!constantInput)
Sadik Armagan67e95f22020-10-29 16:14:54 +0000226 {
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100227 armnnDelegate::FillInput<T>(tfLiteInterpreter, 1, input1Values);
228 armnnDelegate::FillInput<T>(armnnDelegateInterpreter, 1, input1Values);
Sadik Armagan67e95f22020-10-29 16:14:54 +0000229 }
Sadik Armagan67e95f22020-10-29 16:14:54 +0000230 // Run EnqueWorkload
231 CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
232 CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);
233
234 // Compare output data
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100235 armnnDelegate::CompareOutputData<T>(tfLiteInterpreter,
236 armnnDelegateInterpreter,
237 outputShape,
238 expectedOutputValues);
Sadik Armagan67e95f22020-10-29 16:14:54 +0000239 armnnDelegateInterpreter.reset(nullptr);
240}
241
Sadik Armagan21a94ff2020-11-09 08:38:30 +0000242} // anonymous namespace