blob: abaa807aed1b71905a597c6ecbabeb1380bff460 [file] [log] [blame]
Jan Eilers2ffddda2021-02-03 09:14:30 +00001//
2// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
8#include "TestUtils.hpp"
9
10#include <armnn_delegate.hpp>
11#include <armnn/DescriptorsFwd.hpp>
12
13#include <flatbuffers/flatbuffers.h>
14#include <tensorflow/lite/interpreter.h>
15#include <tensorflow/lite/kernels/register.h>
16#include <tensorflow/lite/model.h>
17#include <tensorflow/lite/schema/schema_generated.h>
18#include <tensorflow/lite/version.h>
19
20#include <doctest/doctest.h>
21
22#include <string>
23
24namespace
25{
26
27struct StridedSliceParams
28{
29 StridedSliceParams(std::vector<int32_t>& inputTensorShape,
30 std::vector<int32_t>& beginTensorData,
31 std::vector<int32_t>& endTensorData,
32 std::vector<int32_t>& strideTensorData,
33 std::vector<int32_t>& outputTensorShape,
34 armnn::StridedSliceDescriptor& descriptor)
35 : m_InputTensorShape(inputTensorShape),
36 m_BeginTensorData(beginTensorData),
37 m_EndTensorData(endTensorData),
38 m_StrideTensorData(strideTensorData),
39 m_OutputTensorShape(outputTensorShape),
40 m_Descriptor (descriptor) {}
41
42 std::vector<int32_t> m_InputTensorShape;
43 std::vector<int32_t> m_BeginTensorData;
44 std::vector<int32_t> m_EndTensorData;
45 std::vector<int32_t> m_StrideTensorData;
46 std::vector<int32_t> m_OutputTensorShape;
47 armnn::StridedSliceDescriptor m_Descriptor;
48};
49
50std::vector<char> CreateSliceTfLiteModel(tflite::TensorType tensorType,
51 const std::vector<int32_t>& inputTensorShape,
52 const std::vector<int32_t>& beginTensorData,
53 const std::vector<int32_t>& endTensorData,
54 const std::vector<int32_t>& strideTensorData,
55 const std::vector<int32_t>& beginTensorShape,
56 const std::vector<int32_t>& endTensorShape,
57 const std::vector<int32_t>& strideTensorShape,
58 const std::vector<int32_t>& outputTensorShape,
59 const int32_t beginMask,
60 const int32_t endMask,
61 const int32_t ellipsisMask,
62 const int32_t newAxisMask,
63 const int32_t ShrinkAxisMask,
64 const armnn::DataLayout& dataLayout)
65{
66 using namespace tflite;
67 flatbuffers::FlatBufferBuilder flatBufferBuilder;
68
69 std::array<flatbuffers::Offset<tflite::Buffer>, 4> buffers;
70 buffers[0] = CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({}));
71 buffers[1] = CreateBuffer(flatBufferBuilder,
72 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(beginTensorData.data()),
73 sizeof(int32_t) * beginTensorData.size()));
74 buffers[2] = CreateBuffer(flatBufferBuilder,
75 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(endTensorData.data()),
76 sizeof(int32_t) * endTensorData.size()));
77 buffers[3] = CreateBuffer(flatBufferBuilder,
78 flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(strideTensorData.data()),
79 sizeof(int32_t) * strideTensorData.size()));
80
81 std::array<flatbuffers::Offset<Tensor>, 5> tensors;
82 tensors[0] = CreateTensor(flatBufferBuilder,
83 flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
84 inputTensorShape.size()),
85 tensorType,
86 0,
87 flatBufferBuilder.CreateString("input"));
88 tensors[1] = CreateTensor(flatBufferBuilder,
89 flatBufferBuilder.CreateVector<int32_t>(beginTensorShape.data(),
90 beginTensorShape.size()),
91 ::tflite::TensorType_INT32,
92 1,
93 flatBufferBuilder.CreateString("begin_tensor"));
94 tensors[2] = CreateTensor(flatBufferBuilder,
95 flatBufferBuilder.CreateVector<int32_t>(endTensorShape.data(),
96 endTensorShape.size()),
97 ::tflite::TensorType_INT32,
98 2,
99 flatBufferBuilder.CreateString("end_tensor"));
100 tensors[3] = CreateTensor(flatBufferBuilder,
101 flatBufferBuilder.CreateVector<int32_t>(strideTensorShape.data(),
102 strideTensorShape.size()),
103 ::tflite::TensorType_INT32,
104 3,
105 flatBufferBuilder.CreateString("stride_tensor"));
106 tensors[4] = CreateTensor(flatBufferBuilder,
107 flatBufferBuilder.CreateVector<int32_t>(outputTensorShape.data(),
108 outputTensorShape.size()),
109 tensorType,
110 0,
111 flatBufferBuilder.CreateString("output"));
112
113
114 // create operator
115 tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_StridedSliceOptions;
116 flatbuffers::Offset<void> operatorBuiltinOptions = CreateStridedSliceOptions(flatBufferBuilder,
117 beginMask,
118 endMask,
119 ellipsisMask,
120 newAxisMask,
121 ShrinkAxisMask).Union();
122
123 const std::vector<int> operatorInputs{ 0, 1, 2, 3 };
124 const std::vector<int> operatorOutputs{ 4 };
125 flatbuffers::Offset <Operator> sliceOperator =
126 CreateOperator(flatBufferBuilder,
127 0,
128 flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
129 flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
130 operatorBuiltinOptionsType,
131 operatorBuiltinOptions);
132
133 const std::vector<int> subgraphInputs{ 0, 1, 2, 3 };
134 const std::vector<int> subgraphOutputs{ 4 };
135 flatbuffers::Offset <SubGraph> subgraph =
136 CreateSubGraph(flatBufferBuilder,
137 flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
138 flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
139 flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
140 flatBufferBuilder.CreateVector(&sliceOperator, 1));
141
142 flatbuffers::Offset <flatbuffers::String> modelDescription =
143 flatBufferBuilder.CreateString("ArmnnDelegate: StridedSlice Operator Model");
144 flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder,
145 BuiltinOperator_STRIDED_SLICE);
146
147 flatbuffers::Offset <Model> flatbufferModel =
148 CreateModel(flatBufferBuilder,
149 TFLITE_SCHEMA_VERSION,
150 flatBufferBuilder.CreateVector(&operatorCode, 1),
151 flatBufferBuilder.CreateVector(&subgraph, 1),
152 modelDescription,
153 flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));
154
155 flatBufferBuilder.Finish(flatbufferModel);
156
157 return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
158 flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
159}
160
161template <typename T>
162void StridedSliceTestImpl(std::vector<armnn::BackendId>& backends,
163 std::vector<T>& inputValues,
164 std::vector<T>& expectedOutputValues,
165 std::vector<int32_t>& beginTensorData,
166 std::vector<int32_t>& endTensorData,
167 std::vector<int32_t>& strideTensorData,
168 std::vector<int32_t>& inputTensorShape,
169 std::vector<int32_t>& beginTensorShape,
170 std::vector<int32_t>& endTensorShape,
171 std::vector<int32_t>& strideTensorShape,
172 std::vector<int32_t>& outputTensorShape,
173 const int32_t beginMask = 0,
174 const int32_t endMask = 0,
175 const int32_t ellipsisMask = 0,
176 const int32_t newAxisMask = 0,
177 const int32_t ShrinkAxisMask = 0,
178 const armnn::DataLayout& dataLayout = armnn::DataLayout::NHWC)
179{
180 using namespace tflite;
181 std::vector<char> modelBuffer = CreateSliceTfLiteModel(
182 ::tflite::TensorType_FLOAT32,
183 inputTensorShape,
184 beginTensorData,
185 endTensorData,
186 strideTensorData,
187 beginTensorShape,
188 endTensorShape,
189 strideTensorShape,
190 outputTensorShape,
191 beginMask,
192 endMask,
193 ellipsisMask,
194 newAxisMask,
195 ShrinkAxisMask,
196 dataLayout);
197
198 auto tfLiteModel = GetModel(modelBuffer.data());
199
200 // Create TfLite Interpreters
201 std::unique_ptr<Interpreter> armnnDelegate;
202 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
203 (&armnnDelegate) == kTfLiteOk);
204 CHECK(armnnDelegate != nullptr);
205 CHECK(armnnDelegate->AllocateTensors() == kTfLiteOk);
206
207 std::unique_ptr<Interpreter> tfLiteDelegate;
208 CHECK(InterpreterBuilder(tfLiteModel, ::tflite::ops::builtin::BuiltinOpResolver())
209 (&tfLiteDelegate) == kTfLiteOk);
210 CHECK(tfLiteDelegate != nullptr);
211 CHECK(tfLiteDelegate->AllocateTensors() == kTfLiteOk);
212
213 // Create the ArmNN Delegate
214 armnnDelegate::DelegateOptions delegateOptions(backends);
215 std::unique_ptr<TfLiteDelegate, decltype(&armnnDelegate::TfLiteArmnnDelegateDelete)>
216 theArmnnDelegate(armnnDelegate::TfLiteArmnnDelegateCreate(delegateOptions),
217 armnnDelegate::TfLiteArmnnDelegateDelete);
218 CHECK(theArmnnDelegate != nullptr);
219
220 // Modify armnnDelegateInterpreter to use armnnDelegate
221 CHECK(armnnDelegate->ModifyGraphWithDelegate(theArmnnDelegate.get()) == kTfLiteOk);
222
223 // Set input data
224 armnnDelegate::FillInput<T>(tfLiteDelegate, 0, inputValues);
225 armnnDelegate::FillInput<T>(armnnDelegate, 0, inputValues);
226
227 // Run EnqueWorkload
228 CHECK(tfLiteDelegate->Invoke() == kTfLiteOk);
229 CHECK(armnnDelegate->Invoke() == kTfLiteOk);
230
231 // Compare output data
232 armnnDelegate::CompareOutputData<T>(tfLiteDelegate,
233 armnnDelegate,
234 outputTensorShape,
235 expectedOutputValues);
236
237 tfLiteDelegate.reset(nullptr);
238 armnnDelegate.reset(nullptr);
239} // End of StridedSlice Test
240
241} // anonymous namespace