blob: 2726fdef4c4a29d2140b84132a2945cd10b5bb9a [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
7#include "TensorCopyUtils.hpp"
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +01008#include <ResolveType.hpp>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +00009#include "WorkloadTestUtils.hpp"
10
11#include <armnn/Types.hpp>
12#include <backendsCommon/CpuTensorHandle.hpp>
13#include <backendsCommon/IBackendInternal.hpp>
14#include <backendsCommon/WorkloadFactory.hpp>
15#include <backendsCommon/test/WorkloadFactoryHelper.hpp>
16#include <test/TensorHelpers.hpp>
17
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010018namespace
19{
20
21using FloatData = std::vector<float>;
22using QuantData = std::pair<float, int32_t>;
23
24struct TestData
25{
26 static const armnn::TensorShape s_BoxEncodingsShape;
27 static const armnn::TensorShape s_ScoresShape;
28 static const armnn::TensorShape s_AnchorsShape;
29
30 static const QuantData s_BoxEncodingsQuantData;
31 static const QuantData s_ScoresQuantData;
32 static const QuantData s_AnchorsQuantData;
33
34 static const FloatData s_BoxEncodings;
35 static const FloatData s_Scores;
36 static const FloatData s_Anchors;
37};
38
39struct RegularNmsExpectedResults
40{
41 static const FloatData s_DetectionBoxes;
42 static const FloatData s_DetectionScores;
43 static const FloatData s_DetectionClasses;
44 static const FloatData s_NumDetections;
45};
46
47struct FastNmsExpectedResults
48{
49 static const FloatData s_DetectionBoxes;
50 static const FloatData s_DetectionScores;
51 static const FloatData s_DetectionClasses;
52 static const FloatData s_NumDetections;
53};
54
55const armnn::TensorShape TestData::s_BoxEncodingsShape = { 1, 6, 4 };
56const armnn::TensorShape TestData::s_ScoresShape = { 1, 6, 3 };
57const armnn::TensorShape TestData::s_AnchorsShape = { 6, 4 };
58
59const QuantData TestData::s_BoxEncodingsQuantData = { 1.00f, 1 };
60const QuantData TestData::s_ScoresQuantData = { 0.01f, 0 };
61const QuantData TestData::s_AnchorsQuantData = { 0.50f, 0 };
62
63const FloatData TestData::s_BoxEncodings =
64{
65 0.0f, 0.0f, 0.0f, 0.0f,
66 0.0f, 1.0f, 0.0f, 0.0f,
67 0.0f, -1.0f, 0.0f, 0.0f,
68 0.0f, 0.0f, 0.0f, 0.0f,
69 0.0f, 1.0f, 0.0f, 0.0f,
70 0.0f, 0.0f, 0.0f, 0.0f
71};
72
73const FloatData TestData::s_Scores =
74{
75 0.0f, 0.90f, 0.80f,
76 0.0f, 0.75f, 0.72f,
77 0.0f, 0.60f, 0.50f,
78 0.0f, 0.93f, 0.95f,
79 0.0f, 0.50f, 0.40f,
80 0.0f, 0.30f, 0.20f
81};
82
83const FloatData TestData::s_Anchors =
84{
85 0.5f, 0.5f, 1.0f, 1.0f,
86 0.5f, 0.5f, 1.0f, 1.0f,
87 0.5f, 0.5f, 1.0f, 1.0f,
88 0.5f, 10.5f, 1.0f, 1.0f,
89 0.5f, 10.5f, 1.0f, 1.0f,
90 0.5f, 100.5f, 1.0f, 1.0f
91};
92
93const FloatData RegularNmsExpectedResults::s_DetectionBoxes =
94{
95 0.0f, 10.0f, 1.0f, 11.0f,
96 0.0f, 10.0f, 1.0f, 11.0f,
97 0.0f, 0.0f, 0.0f, 0.0f
98};
99
100const FloatData RegularNmsExpectedResults::s_DetectionScores =
101{
102 0.95f, 0.93f, 0.0f
103};
104
105const FloatData RegularNmsExpectedResults::s_DetectionClasses =
106{
107 1.0f, 0.0f, 0.0f
108};
109
110const FloatData RegularNmsExpectedResults::s_NumDetections = { 2.0f };
111
112const FloatData FastNmsExpectedResults::s_DetectionBoxes =
113{
114 0.0f, 10.0f, 1.0f, 11.0f,
115 0.0f, 0.0f, 1.0f, 1.0f,
116 0.0f, 100.0f, 1.0f, 101.0f
117};
118
119const FloatData FastNmsExpectedResults::s_DetectionScores =
120{
121 0.95f, 0.9f, 0.3f
122};
123
124const FloatData FastNmsExpectedResults::s_DetectionClasses =
125{
126 1.0f, 0.0f, 0.0f
127};
128
129const FloatData FastNmsExpectedResults::s_NumDetections = { 3.0f };
130
131} // anonymous namespace
132
133template<typename FactoryType,
134 armnn::DataType ArmnnType,
135 typename T = armnn::ResolveType<ArmnnType>>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000136void DetectionPostProcessImpl(const armnn::TensorInfo& boxEncodingsInfo,
137 const armnn::TensorInfo& scoresInfo,
138 const armnn::TensorInfo& anchorsInfo,
139 const std::vector<T>& boxEncodingsData,
140 const std::vector<T>& scoresData,
141 const std::vector<T>& anchorsData,
142 const std::vector<float>& expectedDetectionBoxes,
143 const std::vector<float>& expectedDetectionClasses,
144 const std::vector<float>& expectedDetectionScores,
145 const std::vector<float>& expectedNumDetections,
146 bool useRegularNms)
147{
148 std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>();
149 armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
150
151 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
152 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
153
154 auto boxEncodings = MakeTensor<T, 3>(boxEncodingsInfo, boxEncodingsData);
155 auto scores = MakeTensor<T, 3>(scoresInfo, scoresData);
156 auto anchors = MakeTensor<T, 2>(anchorsInfo, anchorsData);
157
158 armnn::TensorInfo detectionBoxesInfo({ 1, 3, 4 }, armnn::DataType::Float32);
159 armnn::TensorInfo detectionScoresInfo({ 1, 3 }, armnn::DataType::Float32);
160 armnn::TensorInfo detectionClassesInfo({ 1, 3 }, armnn::DataType::Float32);
161 armnn::TensorInfo numDetectionInfo({ 1 }, armnn::DataType::Float32);
162
163 LayerTestResult<float, 3> detectionBoxesResult(detectionBoxesInfo);
164 detectionBoxesResult.outputExpected = MakeTensor<float, 3>(detectionBoxesInfo, expectedDetectionBoxes);
165 LayerTestResult<float, 2> detectionClassesResult(detectionClassesInfo);
166 detectionClassesResult.outputExpected = MakeTensor<float, 2>(detectionClassesInfo, expectedDetectionClasses);
167 LayerTestResult<float, 2> detectionScoresResult(detectionScoresInfo);
168 detectionScoresResult.outputExpected = MakeTensor<float, 2>(detectionScoresInfo, expectedDetectionScores);
169 LayerTestResult<float, 1> numDetectionsResult(numDetectionInfo);
170 numDetectionsResult.outputExpected = MakeTensor<float, 1>(numDetectionInfo, expectedNumDetections);
171
172 std::unique_ptr<armnn::ITensorHandle> boxedHandle = workloadFactory.CreateTensorHandle(boxEncodingsInfo);
173 std::unique_ptr<armnn::ITensorHandle> scoreshandle = workloadFactory.CreateTensorHandle(scoresInfo);
174 std::unique_ptr<armnn::ITensorHandle> anchorsHandle = workloadFactory.CreateTensorHandle(anchorsInfo);
175 std::unique_ptr<armnn::ITensorHandle> outputBoxesHandle = workloadFactory.CreateTensorHandle(detectionBoxesInfo);
176 std::unique_ptr<armnn::ITensorHandle> classesHandle = workloadFactory.CreateTensorHandle(detectionClassesInfo);
177 std::unique_ptr<armnn::ITensorHandle> outputScoresHandle = workloadFactory.CreateTensorHandle(detectionScoresInfo);
178 std::unique_ptr<armnn::ITensorHandle> numDetectionHandle = workloadFactory.CreateTensorHandle(numDetectionInfo);
179
180 armnn::ScopedCpuTensorHandle anchorsTensor(anchorsInfo);
181 AllocateAndCopyDataToITensorHandle(&anchorsTensor, &anchors[0][0]);
182
183 armnn::DetectionPostProcessQueueDescriptor data;
184 data.m_Parameters.m_UseRegularNms = useRegularNms;
185 data.m_Parameters.m_MaxDetections = 3;
186 data.m_Parameters.m_MaxClassesPerDetection = 1;
187 data.m_Parameters.m_DetectionsPerClass =1;
188 data.m_Parameters.m_NmsScoreThreshold = 0.0;
189 data.m_Parameters.m_NmsIouThreshold = 0.5;
190 data.m_Parameters.m_NumClasses = 2;
191 data.m_Parameters.m_ScaleY = 10.0;
192 data.m_Parameters.m_ScaleX = 10.0;
193 data.m_Parameters.m_ScaleH = 5.0;
194 data.m_Parameters.m_ScaleW = 5.0;
195 data.m_Anchors = &anchorsTensor;
196
197 armnn::WorkloadInfo info;
198 AddInputToWorkload(data, info, boxEncodingsInfo, boxedHandle.get());
199 AddInputToWorkload(data, info, scoresInfo, scoreshandle.get());
200 AddOutputToWorkload(data, info, detectionBoxesInfo, outputBoxesHandle.get());
201 AddOutputToWorkload(data, info, detectionClassesInfo, classesHandle.get());
202 AddOutputToWorkload(data, info, detectionScoresInfo, outputScoresHandle.get());
203 AddOutputToWorkload(data, info, numDetectionInfo, numDetectionHandle.get());
204
205 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateDetectionPostProcess(data, info);
206
207 boxedHandle->Allocate();
208 scoreshandle->Allocate();
209 outputBoxesHandle->Allocate();
210 classesHandle->Allocate();
211 outputScoresHandle->Allocate();
212 numDetectionHandle->Allocate();
213
214 CopyDataToITensorHandle(boxedHandle.get(), boxEncodings.origin());
215 CopyDataToITensorHandle(scoreshandle.get(), scores.origin());
216
217 workload->Execute();
218
219 CopyDataFromITensorHandle(detectionBoxesResult.output.origin(), outputBoxesHandle.get());
220 CopyDataFromITensorHandle(detectionClassesResult.output.origin(), classesHandle.get());
221 CopyDataFromITensorHandle(detectionScoresResult.output.origin(), outputScoresHandle.get());
222 CopyDataFromITensorHandle(numDetectionsResult.output.origin(), numDetectionHandle.get());
223
224 BOOST_TEST(CompareTensors(detectionBoxesResult.output, detectionBoxesResult.outputExpected));
225 BOOST_TEST(CompareTensors(detectionClassesResult.output, detectionClassesResult.outputExpected));
226 BOOST_TEST(CompareTensors(detectionScoresResult.output, detectionScoresResult.outputExpected));
227 BOOST_TEST(CompareTensors(numDetectionsResult.output, numDetectionsResult.outputExpected));
228}
229
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100230template<armnn::DataType QuantizedType, typename RawType = armnn::ResolveType<QuantizedType>>
231void QuantizeData(RawType* quant, const float* dequant, const armnn::TensorInfo& info)
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000232{
233 for (size_t i = 0; i < info.GetNumElements(); i++)
234 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100235 quant[i] = armnn::Quantize<RawType>(
236 dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000237 }
238}
239
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100240template<typename FactoryType>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000241void DetectionPostProcessRegularNmsFloatTest()
242{
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100243 return DetectionPostProcessImpl<FactoryType, armnn::DataType::Float32>(
244 armnn::TensorInfo(TestData::s_BoxEncodingsShape, armnn::DataType::Float32),
245 armnn::TensorInfo(TestData::s_ScoresShape, armnn::DataType::Float32),
246 armnn::TensorInfo(TestData::s_AnchorsShape, armnn::DataType::Float32),
247 TestData::s_BoxEncodings,
248 TestData::s_Scores,
249 TestData::s_Anchors,
250 RegularNmsExpectedResults::s_DetectionBoxes,
251 RegularNmsExpectedResults::s_DetectionClasses,
252 RegularNmsExpectedResults::s_DetectionScores,
253 RegularNmsExpectedResults::s_NumDetections,
254 true);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000255}
256
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100257template<typename FactoryType,
258 armnn::DataType QuantizedType,
259 typename RawType = armnn::ResolveType<QuantizedType>>
260void DetectionPostProcessRegularNmsQuantizedTest()
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000261{
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100262 armnn::TensorInfo boxEncodingsInfo(TestData::s_BoxEncodingsShape, QuantizedType);
263 armnn::TensorInfo scoresInfo(TestData::s_ScoresShape, QuantizedType);
264 armnn::TensorInfo anchorsInfo(TestData::s_AnchorsShape, QuantizedType);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000265
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100266 boxEncodingsInfo.SetQuantizationScale(TestData::s_BoxEncodingsQuantData.first);
267 boxEncodingsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000268
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100269 scoresInfo.SetQuantizationScale(TestData::s_ScoresQuantData.first);
270 scoresInfo.SetQuantizationOffset(TestData::s_ScoresQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000271
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100272 anchorsInfo.SetQuantizationScale(TestData::s_AnchorsQuantData.first);
273 anchorsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000274
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100275 std::vector<RawType> boxEncodingsData(TestData::s_BoxEncodingsShape.GetNumElements());
276 QuantizeData<QuantizedType>(boxEncodingsData.data(),
277 TestData::s_BoxEncodings.data(),
278 boxEncodingsInfo);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000279
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100280 std::vector<RawType> scoresData(TestData::s_ScoresShape.GetNumElements());
281 QuantizeData<QuantizedType>(scoresData.data(),
282 TestData::s_Scores.data(),
283 scoresInfo);
284
285 std::vector<RawType> anchorsData(TestData::s_AnchorsShape.GetNumElements());
286 QuantizeData<QuantizedType>(anchorsData.data(),
287 TestData::s_Anchors.data(),
288 anchorsInfo);
289
290 return DetectionPostProcessImpl<FactoryType, QuantizedType>(
291 boxEncodingsInfo,
292 scoresInfo,
293 anchorsInfo,
294 boxEncodingsData,
295 scoresData,
296 anchorsData,
297 RegularNmsExpectedResults::s_DetectionBoxes,
298 RegularNmsExpectedResults::s_DetectionClasses,
299 RegularNmsExpectedResults::s_DetectionScores,
300 RegularNmsExpectedResults::s_NumDetections,
301 true);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000302}
303
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100304template<typename FactoryType>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000305void DetectionPostProcessFastNmsFloatTest()
306{
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100307 return DetectionPostProcessImpl<FactoryType, armnn::DataType::Float32>(
308 armnn::TensorInfo(TestData::s_BoxEncodingsShape, armnn::DataType::Float32),
309 armnn::TensorInfo(TestData::s_ScoresShape, armnn::DataType::Float32),
310 armnn::TensorInfo(TestData::s_AnchorsShape, armnn::DataType::Float32),
311 TestData::s_BoxEncodings,
312 TestData::s_Scores,
313 TestData::s_Anchors,
314 FastNmsExpectedResults::s_DetectionBoxes,
315 FastNmsExpectedResults::s_DetectionClasses,
316 FastNmsExpectedResults::s_DetectionScores,
317 FastNmsExpectedResults::s_NumDetections,
318 false);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000319}
320
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100321template<typename FactoryType,
322 armnn::DataType QuantizedType,
323 typename RawType = armnn::ResolveType<QuantizedType>>
324void DetectionPostProcessFastNmsQuantizedTest()
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000325{
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100326 armnn::TensorInfo boxEncodingsInfo(TestData::s_BoxEncodingsShape, QuantizedType);
327 armnn::TensorInfo scoresInfo(TestData::s_ScoresShape, QuantizedType);
328 armnn::TensorInfo anchorsInfo(TestData::s_AnchorsShape, QuantizedType);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000329
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100330 boxEncodingsInfo.SetQuantizationScale(TestData::s_BoxEncodingsQuantData.first);
331 boxEncodingsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000332
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100333 scoresInfo.SetQuantizationScale(TestData::s_ScoresQuantData.first);
334 scoresInfo.SetQuantizationOffset(TestData::s_ScoresQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000335
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100336 anchorsInfo.SetQuantizationScale(TestData::s_AnchorsQuantData.first);
337 anchorsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000338
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100339 std::vector<RawType> boxEncodingsData(TestData::s_BoxEncodingsShape.GetNumElements());
340 QuantizeData<QuantizedType>(boxEncodingsData.data(),
341 TestData::s_BoxEncodings.data(),
342 boxEncodingsInfo);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000343
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100344 std::vector<RawType> scoresData(TestData::s_ScoresShape.GetNumElements());
345 QuantizeData<QuantizedType>(scoresData.data(),
346 TestData::s_Scores.data(),
347 scoresInfo);
348
349 std::vector<RawType> anchorsData(TestData::s_AnchorsShape.GetNumElements());
350 QuantizeData<QuantizedType>(anchorsData.data(),
351 TestData::s_Anchors.data(),
352 anchorsInfo);
353
354 return DetectionPostProcessImpl<FactoryType, QuantizedType>(
355 boxEncodingsInfo,
356 scoresInfo,
357 anchorsInfo,
358 boxEncodingsData,
359 scoresData,
360 anchorsData,
361 FastNmsExpectedResults::s_DetectionBoxes,
362 FastNmsExpectedResults::s_DetectionClasses,
363 FastNmsExpectedResults::s_DetectionScores,
364 FastNmsExpectedResults::s_NumDetections,
365 false);
366}