blob: b41fcfa39b2fd1244d49baba928730668f375a44 [file] [log] [blame]
Sadik Armagana2747482021-02-09 10:28:54 +00001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "TestUtils.hpp"
9
10#include <armnn_delegate.hpp>
11
12#include <flatbuffers/flatbuffers.h>
13#include <tensorflow/lite/interpreter.h>
14#include <tensorflow/lite/kernels/register.h>
15#include <tensorflow/lite/model.h>
16#include <tensorflow/lite/schema/schema_generated.h>
17#include <tensorflow/lite/version.h>
18
19#include <doctest/doctest.h>
20
21#include <string>
22
23namespace
24{
25
26std::vector<char> CreateReduceTfLiteModel(tflite::BuiltinOperator reduceOperatorCode,
27 tflite::TensorType tensorType,
28 std::vector<int32_t>& input0TensorShape,
29 std::vector<int32_t>& input1TensorShape,
30 const std::vector <int32_t>& outputTensorShape,
31 std::vector<int32_t>& axisData,
32 const bool keepDims,
33 float quantScale = 1.0f,
34 int quantOffset = 0)
35{
36 using namespace tflite;
37 flatbuffers::FlatBufferBuilder flatBufferBuilder;
38
39 std::array<flatbuffers::Offset<tflite::Buffer>, 2> buffers;
40 buffers[0] = CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({}));
41 buffers[1] = CreateBuffer(flatBufferBuilder,
42 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(axisData.data()),
43 sizeof(int32_t) * axisData.size()));
44
45 auto quantizationParameters =
46 CreateQuantizationParameters(flatBufferBuilder,
47 0,
48 0,
49 flatBufferBuilder.CreateVector<float>({ quantScale }),
50 flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));
51
52 std::array<flatbuffers::Offset<Tensor>, 3> tensors;
53 tensors[0] = CreateTensor(flatBufferBuilder,
54 flatBufferBuilder.CreateVector<int32_t>(input0TensorShape.data(),
55 input0TensorShape.size()),
56 tensorType,
57 0,
58 flatBufferBuilder.CreateString("input"),
59 quantizationParameters);
60
61 tensors[1] = CreateTensor(flatBufferBuilder,
62 flatBufferBuilder.CreateVector<int32_t>(input1TensorShape.data(),
63 input1TensorShape.size()),
64 ::tflite::TensorType_INT32,
65 1,
66 flatBufferBuilder.CreateString("axis"),
67 quantizationParameters);
68
69 // Create output tensor
70 tensors[2] = CreateTensor(flatBufferBuilder,
71 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
72 outputTensorShape.size()),
73 tensorType,
74 0,
75 flatBufferBuilder.CreateString("output"),
76 quantizationParameters);
77
78 // Create operator. Reduce operations MIN, MAX, SUM, MEAN uses ReducerOptions.
79 tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_ReducerOptions;
80 flatbuffers::Offset<void> operatorBuiltinOptions = CreateReducerOptions(flatBufferBuilder, keepDims).Union();
81
82 const std::vector<int> operatorInputs{ {0, 1} };
83 const std::vector<int> operatorOutputs{ 2 };
84 flatbuffers::Offset <Operator> reduceOperator =
85 CreateOperator(flatBufferBuilder,
86 0,
87 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
88 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
89 operatorBuiltinOptionsType,
90 operatorBuiltinOptions);
91
92 const std::vector<int> subgraphInputs{ {0, 1} };
93 const std::vector<int> subgraphOutputs{ 2 };
94 flatbuffers::Offset <SubGraph> subgraph =
95 CreateSubGraph(flatBufferBuilder,
96 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
97 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
98 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
99 flatBufferBuilder.CreateVector(&reduceOperator, 1));
100
101 flatbuffers::Offset <flatbuffers::String> modelDescription =
102 flatBufferBuilder.CreateString("ArmnnDelegate: Reduce Operator Model");
103 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, reduceOperatorCode);
104
105 flatbuffers::Offset <Model> flatbufferModel =
106 CreateModel(flatBufferBuilder,
107 TFLITE_SCHEMA_VERSION,
108 flatBufferBuilder.CreateVector(&operatorCode, 1),
109 flatBufferBuilder.CreateVector(&subgraph, 1),
110 modelDescription,
111 flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
112
113 flatBufferBuilder.Finish(flatbufferModel);
114
115 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
116 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
117}
118
119template <typename T>
120void ReduceTest(tflite::BuiltinOperator reduceOperatorCode,
121 tflite::TensorType tensorType,
122 std::vector<armnn::BackendId>& backends,
123 std::vector<int32_t>& input0Shape,
124 std::vector<int32_t>& input1Shape,
125 std::vector<int32_t>& expectedOutputShape,
126 std::vector<T>& input0Values,
127 std::vector<int32_t>& input1Values,
128 std::vector<T>& expectedOutputValues,
129 const bool keepDims,
130 float quantScale = 1.0f,
131 int quantOffset = 0)
132{
133 using namespace tflite;
134 std::vector<char> modelBuffer = CreateReduceTfLiteModel(reduceOperatorCode,
135 tensorType,
136 input0Shape,
137 input1Shape,
138 expectedOutputShape,
139 input1Values,
140 keepDims,
141 quantScale,
142 quantOffset);
143
144 const Model* tfLiteModel = GetModel(modelBuffer.data());
145
146 // Create TfLite Interpreters
147 std::unique_ptr<Interpreter> armnnDelegateInterpreter;
148 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
149 (&armnnDelegateInterpreter) == kTfLiteOk);
150 CHECK(armnnDelegateInterpreter != nullptr);
151 CHECK(armnnDelegateInterpreter->AllocateTensors() == kTfLiteOk);
152
153 std::unique_ptr<Interpreter> tfLiteInterpreter;
154 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
155 (&tfLiteInterpreter) == kTfLiteOk);
156 CHECK(tfLiteInterpreter != nullptr);
157 CHECK(tfLiteInterpreter->AllocateTensors() == kTfLiteOk);
158
159 // Create the ArmNN Delegate
160 armnnDelegate::DelegateOptions delegateOptions(backends);
161 std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
162 theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
163 armnnDelegate::TfLiteArmnnDelegateDelete);
164 CHECK(theArmnnDelegate != nullptr);
165
166 // Modify armnnDelegateInterpreter to use armnnDelegate
167 CHECK(armnnDelegateInterpreter->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
168
169 // Set input data
170 armnnDelegate::FillInput<T>(tfLiteInterpreter, 0, input0Values);
171 armnnDelegate::FillInput<T>(armnnDelegateInterpreter, 0, input0Values);
172
173 // Run EnqueWorkload
174 CHECK(tfLiteInterpreter->Invoke() == kTfLiteOk);
175 CHECK(armnnDelegateInterpreter->Invoke() == kTfLiteOk);
176
177 // Compare output data
178 armnnDelegate::CompareOutputData<T>(tfLiteInterpreter,
179 armnnDelegateInterpreter,
180 expectedOutputShape,
181 expectedOutputValues);
182
183 armnnDelegateInterpreter.reset(nullptr);
184}
185
186} // anonymous namespace