blob: 13b336e91ea4b4ee2afc82fb6451dadae26092b3 [file] [log] [blame]
Sadik Armagan67e95f22020-10-29 16:14:54 +00001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
Sadik Armaganf7ac72c2021-05-05 15:03:50 +01008#include "TestUtils.hpp"
9
Sadik Armagan67e95f22020-10-29 16:14:54 +000010#include <armnn_delegate.hpp>
11
12#include <flatbuffers/flatbuffers.h>
13#include <tensorflow/lite/interpreter.h>
14#include <tensorflow/lite/kernels/register.h>
15#include <tensorflow/lite/model.h>
16#include <tensorflow/lite/schema/schema_generated.h>
17#include <tensorflow/lite/version.h>
18
19#include <doctest/doctest.h>
20
21namespace
22{
23
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010024template <typename T>
Sadik Armagan67e95f22020-10-29 16:14:54 +000025std::vector<char> CreateElementwiseBinaryTfLiteModel(tflite::BuiltinOperator binaryOperatorCode,
26 tflite::ActivationFunctionType activationType,
27 tflite::TensorType tensorType,
28 const std::vector <int32_t>& input0TensorShape,
29 const std::vector <int32_t>& input1TensorShape,
Sadik Armagan21a94ff2020-11-09 08:38:30 +000030 const std::vector <int32_t>& outputTensorShape,
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010031 std::vector<T>& input1Values,
32 bool constantInput = false,
Sadik Armagan21a94ff2020-11-09 08:38:30 +000033 float quantScale = 1.0f,
34 int quantOffset = 0)
Sadik Armagan67e95f22020-10-29 16:14:54 +000035{
36 using namespace tflite;
37 flatbuffers::FlatBufferBuilder flatBufferBuilder;
38
39 std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
40 buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010041 if (constantInput)
42 {
43 buffers.push_back(
44 CreateBuffer(flatBufferBuilder,
45 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(input1Values.data()),
46 sizeof(T) * input1Values.size())));
47 }
48 else
49 {
50 buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
51 }
52 buffers.push_back(CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({})));
Sadik Armagan67e95f22020-10-29 16:14:54 +000053
Sadik Armagan21a94ff2020-11-09 08:38:30 +000054 auto quantizationParameters =
55 CreateQuantizationParameters(flatBufferBuilder,
56 0,
57 0,
58 flatBufferBuilder.CreateVector<float>({ quantScale }),
59 flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
60
61
Sadik Armagan67e95f22020-10-29 16:14:54 +000062 std::array<flatbuffers::Offset<Tensor>, 3> tensors;
63 tensors[0] = CreateTensor(flatBufferBuilder,
64 flatBufferBuilder.CreateVector<int32_t>(input0TensorShape.data(),
65 input0TensorShape.size()),
Sadik Armagan21a94ff2020-11-09 08:38:30 +000066 tensorType,
67 0,
68 flatBufferBuilder.CreateString("input_0"),
69 quantizationParameters);
Sadik Armagan67e95f22020-10-29 16:14:54 +000070 tensors[1] = CreateTensor(flatBufferBuilder,
71 flatBufferBuilder.CreateVector<int32_t>(input1TensorShape.data(),
72 input1TensorShape.size()),
Sadik Armagan21a94ff2020-11-09 08:38:30 +000073 tensorType,
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010074 1,
Sadik Armagan21a94ff2020-11-09 08:38:30 +000075 flatBufferBuilder.CreateString("input_1"),
76 quantizationParameters);
Sadik Armagan67e95f22020-10-29 16:14:54 +000077 tensors[2] = CreateTensor(flatBufferBuilder,
78 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
79 outputTensorShape.size()),
Sadik Armagan21a94ff2020-11-09 08:38:30 +000080 tensorType,
Sadik Armaganf7ac72c2021-05-05 15:03:50 +010081 2,
Sadik Armagan21a94ff2020-11-09 08:38:30 +000082 flatBufferBuilder.CreateString("output"),
83 quantizationParameters);
Sadik Armagan67e95f22020-10-29 16:14:54 +000084
85 // create operator
86 tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_NONE;
87 flatbuffers::Offset<void> operatorBuiltinOptions = 0;
88 switch (binaryOperatorCode)
89 {
90 case BuiltinOperator_ADD:
91 {
92 operatorBuiltinOptionsType = BuiltinOptions_AddOptions;
93 operatorBuiltinOptions = CreateAddOptions(flatBufferBuilder, activationType).Union();
94 break;
95 }
96 case BuiltinOperator_DIV:
97 {
98 operatorBuiltinOptionsType = BuiltinOptions_DivOptions;
99 operatorBuiltinOptions = CreateDivOptions(flatBufferBuilder, activationType).Union();
100 break;
101 }
Sadik Armagan21a94ff2020-11-09 08:38:30 +0000102 case BuiltinOperator_MAXIMUM:
103 {
104 operatorBuiltinOptionsType = BuiltinOptions_MaximumMinimumOptions;
105 operatorBuiltinOptions = CreateMaximumMinimumOptions(flatBufferBuilder).Union();
106 break;
107 }
108 case BuiltinOperator_MINIMUM:
109 {
110 operatorBuiltinOptionsType = BuiltinOptions_MaximumMinimumOptions;
111 operatorBuiltinOptions = CreateMaximumMinimumOptions(flatBufferBuilder).Union();
112 break;
113 }
Sadik Armagan67e95f22020-10-29 16:14:54 +0000114 case BuiltinOperator_MUL:
115 {
116 operatorBuiltinOptionsType = BuiltinOptions_MulOptions;
117 operatorBuiltinOptions = CreateMulOptions(flatBufferBuilder, activationType).Union();
118 break;
119 }
120 case BuiltinOperator_SUB:
121 {
122 operatorBuiltinOptionsType = BuiltinOptions_SubOptions;
123 operatorBuiltinOptions = CreateSubOptions(flatBufferBuilder, activationType).Union();
124 break;
125 }
126 default:
127 break;
128 }
Keith Davis892fafe2020-11-26 17:40:35 +0000129 const std::vector<int32_t> operatorInputs{0, 1};
130 const std::vector<int32_t> operatorOutputs{2};
Sadik Armagan67e95f22020-10-29 16:14:54 +0000131 flatbuffers::Offset <Operator> elementwiseBinaryOperator =
132 CreateOperator(flatBufferBuilder,
133 0,
134 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
135 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
136 operatorBuiltinOptionsType,
137 operatorBuiltinOptions);
138
Keith Davis892fafe2020-11-26 17:40:35 +0000139 const std::vector<int> subgraphInputs{0, 1};
140 const std::vector<int> subgraphOutputs{2};
Sadik Armagan67e95f22020-10-29 16:14:54 +0000141 flatbuffers::Offset <SubGraph> subgraph =
142 CreateSubGraph(flatBufferBuilder,
143 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
144 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
145 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
146 flatBufferBuilder.CreateVector(&elementwiseBinaryOperator, 1));
147
148 flatbuffers::Offset <flatbuffers::String> modelDescription =
149 flatBufferBuilder.CreateString("ArmnnDelegate: Elementwise Binary Operator Model");
150 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, binaryOperatorCode);
151
152 flatbuffers::Offset <Model> flatbufferModel =
153 CreateModel(flatBufferBuilder,
154 TFLITE_SCHEMA_VERSION,
155 flatBufferBuilder.CreateVector(&operatorCode, 1),
156 flatBufferBuilder.CreateVector(&subgraph, 1),
157 modelDescription,
158 flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
159
160 flatBufferBuilder.Finish(flatbufferModel);
161
162 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
163 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
164}
165
Sadik Armagan21a94ff2020-11-09 08:38:30 +0000166template <typename T>
167void ElementwiseBinaryTest(tflite::BuiltinOperator binaryOperatorCode,
168 tflite::ActivationFunctionType activationType,
169 tflite::TensorType tensorType,
170 std::vector<armnn::BackendId>& backends,
171 std::vector<int32_t>& input0Shape,
172 std::vector<int32_t>& input1Shape,
173 std::vector<int32_t>& outputShape,
174 std::vector<T>& input0Values,
175 std::vector<T>& input1Values,
176 std::vector<T>& expectedOutputValues,
177 float quantScale = 1.0f,
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100178 int quantOffset = 0,
179 bool constantInput = false)
Sadik Armagan67e95f22020-10-29 16:14:54 +0000180{
181 using namespace tflite;
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100182 std::vector<char> modelBuffer = CreateElementwiseBinaryTfLiteModel<T>(binaryOperatorCode,
183 activationType,
184 tensorType,
185 input0Shape,
186 input1Shape,
187 outputShape,
188 input1Values,
189 constantInput,
190 quantScale,
191 quantOffset);
Sadik Armagan67e95f22020-10-29 16:14:54 +0000192
193 const Model* tfLiteModel = GetModel(modelBuffer.data());
194 // Create TfLite Interpreters
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100195 std::unique_ptr <Interpreter> armnnDelegateInterpreter;
Sadik Armagan67e95f22020-10-29 16:14:54 +0000196 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
197 (&armnnDelegateInterpreter) == kTfLiteOk);
198 CHECK(armnnDelegateInterpreter != nullptr);
199 CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);
200
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100201 std::unique_ptr <Interpreter> tfLiteInterpreter;
Sadik Armagan67e95f22020-10-29 16:14:54 +0000202 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
203 (&tfLiteInterpreter) == kTfLiteOk);
204 CHECK(tfLiteInterpreter != nullptr);
205 CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);
206
207 // Create the ArmNN Delegate
208 armnnDelegate::DelegateOptions delegateOptions(backends);
209 std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100210 theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
211 armnnDelegate::TfLiteArmnnDelegateDelete);
Sadik Armagan67e95f22020-10-29 16:14:54 +0000212 CHECK(theArmnnDelegate != nullptr);
213 // Modify armnnDelegateInterpreter to use armnnDelegate
214 CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
215
216 // Set input data
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100217 armnnDelegate::FillInput<T>(tfLiteInterpreter, 0, input0Values);
218 armnnDelegate::FillInput<T>(armnnDelegateInterpreter, 0, input0Values);
219 if (!constantInput)
Sadik Armagan67e95f22020-10-29 16:14:54 +0000220 {
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100221 armnnDelegate::FillInput<T>(tfLiteInterpreter, 1, input1Values);
222 armnnDelegate::FillInput<T>(armnnDelegateInterpreter, 1, input1Values);
Sadik Armagan67e95f22020-10-29 16:14:54 +0000223 }
Sadik Armagan67e95f22020-10-29 16:14:54 +0000224 // Run EnqueWorkload
225 CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
226 CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);
227
228 // Compare output data
Sadik Armaganf7ac72c2021-05-05 15:03:50 +0100229 armnnDelegate::CompareOutputData<T>(tfLiteInterpreter,
230 armnnDelegateInterpreter,
231 outputShape,
232 expectedOutputValues);
Sadik Armagan67e95f22020-10-29 16:14:54 +0000233 armnnDelegateInterpreter.reset(nullptr);
234}
235
Sadik Armagan21a94ff2020-11-09 08:38:30 +0000236} // anonymous namespace