blob: bcb5abfe726aca3a7d161f2a53d830b1f5573ba8 [file] [log] [blame]
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5#pragma once
6
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +01007#include <ResolveType.hpp>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +00008
9#include <armnn/Types.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010010
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +000011#include <backendsCommon/CpuTensorHandle.hpp>
12#include <backendsCommon/IBackendInternal.hpp>
13#include <backendsCommon/WorkloadFactory.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010014
15#include <backendsCommon/test/TensorCopyUtils.hpp>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +000016#include <backendsCommon/test/WorkloadFactoryHelper.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010017#include <backendsCommon/test/WorkloadTestUtils.hpp>
18
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +000019#include <test/TensorHelpers.hpp>
20
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010021namespace
22{
23
24using FloatData = std::vector<float>;
25using QuantData = std::pair<float, int32_t>;
26
27struct TestData
28{
29 static const armnn::TensorShape s_BoxEncodingsShape;
30 static const armnn::TensorShape s_ScoresShape;
31 static const armnn::TensorShape s_AnchorsShape;
32
33 static const QuantData s_BoxEncodingsQuantData;
34 static const QuantData s_ScoresQuantData;
35 static const QuantData s_AnchorsQuantData;
36
37 static const FloatData s_BoxEncodings;
38 static const FloatData s_Scores;
39 static const FloatData s_Anchors;
40};
41
42struct RegularNmsExpectedResults
43{
44 static const FloatData s_DetectionBoxes;
45 static const FloatData s_DetectionScores;
46 static const FloatData s_DetectionClasses;
47 static const FloatData s_NumDetections;
48};
49
50struct FastNmsExpectedResults
51{
52 static const FloatData s_DetectionBoxes;
53 static const FloatData s_DetectionScores;
54 static const FloatData s_DetectionClasses;
55 static const FloatData s_NumDetections;
56};
57
58const armnn::TensorShape TestData::s_BoxEncodingsShape = { 1, 6, 4 };
59const armnn::TensorShape TestData::s_ScoresShape = { 1, 6, 3 };
60const armnn::TensorShape TestData::s_AnchorsShape = { 6, 4 };
61
62const QuantData TestData::s_BoxEncodingsQuantData = { 1.00f, 1 };
63const QuantData TestData::s_ScoresQuantData = { 0.01f, 0 };
64const QuantData TestData::s_AnchorsQuantData = { 0.50f, 0 };
65
66const FloatData TestData::s_BoxEncodings =
67{
68 0.0f, 0.0f, 0.0f, 0.0f,
69 0.0f, 1.0f, 0.0f, 0.0f,
70 0.0f, -1.0f, 0.0f, 0.0f,
71 0.0f, 0.0f, 0.0f, 0.0f,
72 0.0f, 1.0f, 0.0f, 0.0f,
73 0.0f, 0.0f, 0.0f, 0.0f
74};
75
76const FloatData TestData::s_Scores =
77{
78 0.0f, 0.90f, 0.80f,
79 0.0f, 0.75f, 0.72f,
80 0.0f, 0.60f, 0.50f,
81 0.0f, 0.93f, 0.95f,
82 0.0f, 0.50f, 0.40f,
83 0.0f, 0.30f, 0.20f
84};
85
86const FloatData TestData::s_Anchors =
87{
88 0.5f, 0.5f, 1.0f, 1.0f,
89 0.5f, 0.5f, 1.0f, 1.0f,
90 0.5f, 0.5f, 1.0f, 1.0f,
91 0.5f, 10.5f, 1.0f, 1.0f,
92 0.5f, 10.5f, 1.0f, 1.0f,
93 0.5f, 100.5f, 1.0f, 1.0f
94};
95
96const FloatData RegularNmsExpectedResults::s_DetectionBoxes =
97{
98 0.0f, 10.0f, 1.0f, 11.0f,
99 0.0f, 10.0f, 1.0f, 11.0f,
100 0.0f, 0.0f, 0.0f, 0.0f
101};
102
103const FloatData RegularNmsExpectedResults::s_DetectionScores =
104{
105 0.95f, 0.93f, 0.0f
106};
107
108const FloatData RegularNmsExpectedResults::s_DetectionClasses =
109{
110 1.0f, 0.0f, 0.0f
111};
112
113const FloatData RegularNmsExpectedResults::s_NumDetections = { 2.0f };
114
115const FloatData FastNmsExpectedResults::s_DetectionBoxes =
116{
117 0.0f, 10.0f, 1.0f, 11.0f,
118 0.0f, 0.0f, 1.0f, 1.0f,
119 0.0f, 100.0f, 1.0f, 101.0f
120};
121
122const FloatData FastNmsExpectedResults::s_DetectionScores =
123{
124 0.95f, 0.9f, 0.3f
125};
126
127const FloatData FastNmsExpectedResults::s_DetectionClasses =
128{
129 1.0f, 0.0f, 0.0f
130};
131
132const FloatData FastNmsExpectedResults::s_NumDetections = { 3.0f };
133
134} // anonymous namespace
135
136template<typename FactoryType,
137 armnn::DataType ArmnnType,
138 typename T = armnn::ResolveType<ArmnnType>>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000139void DetectionPostProcessImpl(const armnn::TensorInfo& boxEncodingsInfo,
140 const armnn::TensorInfo& scoresInfo,
141 const armnn::TensorInfo& anchorsInfo,
142 const std::vector<T>& boxEncodingsData,
143 const std::vector<T>& scoresData,
144 const std::vector<T>& anchorsData,
145 const std::vector<float>& expectedDetectionBoxes,
146 const std::vector<float>& expectedDetectionClasses,
147 const std::vector<float>& expectedDetectionScores,
148 const std::vector<float>& expectedNumDetections,
149 bool useRegularNms)
150{
151 std::unique_ptr<armnn::Profiler> profiler = std::make_unique<armnn::Profiler>();
152 armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
153
154 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
155 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
156
157 auto boxEncodings = MakeTensor<T, 3>(boxEncodingsInfo, boxEncodingsData);
158 auto scores = MakeTensor<T, 3>(scoresInfo, scoresData);
159 auto anchors = MakeTensor<T, 2>(anchorsInfo, anchorsData);
160
161 armnn::TensorInfo detectionBoxesInfo({ 1, 3, 4 }, armnn::DataType::Float32);
162 armnn::TensorInfo detectionScoresInfo({ 1, 3 }, armnn::DataType::Float32);
163 armnn::TensorInfo detectionClassesInfo({ 1, 3 }, armnn::DataType::Float32);
164 armnn::TensorInfo numDetectionInfo({ 1 }, armnn::DataType::Float32);
165
166 LayerTestResult<float, 3> detectionBoxesResult(detectionBoxesInfo);
167 detectionBoxesResult.outputExpected = MakeTensor<float, 3>(detectionBoxesInfo, expectedDetectionBoxes);
168 LayerTestResult<float, 2> detectionClassesResult(detectionClassesInfo);
169 detectionClassesResult.outputExpected = MakeTensor<float, 2>(detectionClassesInfo, expectedDetectionClasses);
170 LayerTestResult<float, 2> detectionScoresResult(detectionScoresInfo);
171 detectionScoresResult.outputExpected = MakeTensor<float, 2>(detectionScoresInfo, expectedDetectionScores);
172 LayerTestResult<float, 1> numDetectionsResult(numDetectionInfo);
173 numDetectionsResult.outputExpected = MakeTensor<float, 1>(numDetectionInfo, expectedNumDetections);
174
175 std::unique_ptr<armnn::ITensorHandle> boxedHandle = workloadFactory.CreateTensorHandle(boxEncodingsInfo);
176 std::unique_ptr<armnn::ITensorHandle> scoreshandle = workloadFactory.CreateTensorHandle(scoresInfo);
177 std::unique_ptr<armnn::ITensorHandle> anchorsHandle = workloadFactory.CreateTensorHandle(anchorsInfo);
178 std::unique_ptr<armnn::ITensorHandle> outputBoxesHandle = workloadFactory.CreateTensorHandle(detectionBoxesInfo);
179 std::unique_ptr<armnn::ITensorHandle> classesHandle = workloadFactory.CreateTensorHandle(detectionClassesInfo);
180 std::unique_ptr<armnn::ITensorHandle> outputScoresHandle = workloadFactory.CreateTensorHandle(detectionScoresInfo);
181 std::unique_ptr<armnn::ITensorHandle> numDetectionHandle = workloadFactory.CreateTensorHandle(numDetectionInfo);
182
183 armnn::ScopedCpuTensorHandle anchorsTensor(anchorsInfo);
184 AllocateAndCopyDataToITensorHandle(&anchorsTensor, &anchors[0][0]);
185
186 armnn::DetectionPostProcessQueueDescriptor data;
187 data.m_Parameters.m_UseRegularNms = useRegularNms;
188 data.m_Parameters.m_MaxDetections = 3;
189 data.m_Parameters.m_MaxClassesPerDetection = 1;
190 data.m_Parameters.m_DetectionsPerClass =1;
191 data.m_Parameters.m_NmsScoreThreshold = 0.0;
192 data.m_Parameters.m_NmsIouThreshold = 0.5;
193 data.m_Parameters.m_NumClasses = 2;
194 data.m_Parameters.m_ScaleY = 10.0;
195 data.m_Parameters.m_ScaleX = 10.0;
196 data.m_Parameters.m_ScaleH = 5.0;
197 data.m_Parameters.m_ScaleW = 5.0;
198 data.m_Anchors = &anchorsTensor;
199
200 armnn::WorkloadInfo info;
201 AddInputToWorkload(data, info, boxEncodingsInfo, boxedHandle.get());
202 AddInputToWorkload(data, info, scoresInfo, scoreshandle.get());
203 AddOutputToWorkload(data, info, detectionBoxesInfo, outputBoxesHandle.get());
204 AddOutputToWorkload(data, info, detectionClassesInfo, classesHandle.get());
205 AddOutputToWorkload(data, info, detectionScoresInfo, outputScoresHandle.get());
206 AddOutputToWorkload(data, info, numDetectionInfo, numDetectionHandle.get());
207
208 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateDetectionPostProcess(data, info);
209
210 boxedHandle->Allocate();
211 scoreshandle->Allocate();
212 outputBoxesHandle->Allocate();
213 classesHandle->Allocate();
214 outputScoresHandle->Allocate();
215 numDetectionHandle->Allocate();
216
217 CopyDataToITensorHandle(boxedHandle.get(), boxEncodings.origin());
218 CopyDataToITensorHandle(scoreshandle.get(), scores.origin());
219
220 workload->Execute();
221
222 CopyDataFromITensorHandle(detectionBoxesResult.output.origin(), outputBoxesHandle.get());
223 CopyDataFromITensorHandle(detectionClassesResult.output.origin(), classesHandle.get());
224 CopyDataFromITensorHandle(detectionScoresResult.output.origin(), outputScoresHandle.get());
225 CopyDataFromITensorHandle(numDetectionsResult.output.origin(), numDetectionHandle.get());
226
227 BOOST_TEST(CompareTensors(detectionBoxesResult.output, detectionBoxesResult.outputExpected));
228 BOOST_TEST(CompareTensors(detectionClassesResult.output, detectionClassesResult.outputExpected));
229 BOOST_TEST(CompareTensors(detectionScoresResult.output, detectionScoresResult.outputExpected));
230 BOOST_TEST(CompareTensors(numDetectionsResult.output, numDetectionsResult.outputExpected));
231}
232
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100233template<armnn::DataType QuantizedType, typename RawType = armnn::ResolveType<QuantizedType>>
234void QuantizeData(RawType* quant, const float* dequant, const armnn::TensorInfo& info)
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000235{
236 for (size_t i = 0; i < info.GetNumElements(); i++)
237 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100238 quant[i] = armnn::Quantize<RawType>(
239 dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000240 }
241}
242
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100243template<typename FactoryType>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000244void DetectionPostProcessRegularNmsFloatTest()
245{
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100246 return DetectionPostProcessImpl<FactoryType, armnn::DataType::Float32>(
247 armnn::TensorInfo(TestData::s_BoxEncodingsShape, armnn::DataType::Float32),
248 armnn::TensorInfo(TestData::s_ScoresShape, armnn::DataType::Float32),
249 armnn::TensorInfo(TestData::s_AnchorsShape, armnn::DataType::Float32),
250 TestData::s_BoxEncodings,
251 TestData::s_Scores,
252 TestData::s_Anchors,
253 RegularNmsExpectedResults::s_DetectionBoxes,
254 RegularNmsExpectedResults::s_DetectionClasses,
255 RegularNmsExpectedResults::s_DetectionScores,
256 RegularNmsExpectedResults::s_NumDetections,
257 true);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000258}
259
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100260template<typename FactoryType,
261 armnn::DataType QuantizedType,
262 typename RawType = armnn::ResolveType<QuantizedType>>
263void DetectionPostProcessRegularNmsQuantizedTest()
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000264{
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100265 armnn::TensorInfo boxEncodingsInfo(TestData::s_BoxEncodingsShape, QuantizedType);
266 armnn::TensorInfo scoresInfo(TestData::s_ScoresShape, QuantizedType);
267 armnn::TensorInfo anchorsInfo(TestData::s_AnchorsShape, QuantizedType);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000268
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100269 boxEncodingsInfo.SetQuantizationScale(TestData::s_BoxEncodingsQuantData.first);
270 boxEncodingsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000271
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100272 scoresInfo.SetQuantizationScale(TestData::s_ScoresQuantData.first);
273 scoresInfo.SetQuantizationOffset(TestData::s_ScoresQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000274
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100275 anchorsInfo.SetQuantizationScale(TestData::s_AnchorsQuantData.first);
276 anchorsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000277
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100278 std::vector<RawType> boxEncodingsData(TestData::s_BoxEncodingsShape.GetNumElements());
279 QuantizeData<QuantizedType>(boxEncodingsData.data(),
280 TestData::s_BoxEncodings.data(),
281 boxEncodingsInfo);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000282
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100283 std::vector<RawType> scoresData(TestData::s_ScoresShape.GetNumElements());
284 QuantizeData<QuantizedType>(scoresData.data(),
285 TestData::s_Scores.data(),
286 scoresInfo);
287
288 std::vector<RawType> anchorsData(TestData::s_AnchorsShape.GetNumElements());
289 QuantizeData<QuantizedType>(anchorsData.data(),
290 TestData::s_Anchors.data(),
291 anchorsInfo);
292
293 return DetectionPostProcessImpl<FactoryType, QuantizedType>(
294 boxEncodingsInfo,
295 scoresInfo,
296 anchorsInfo,
297 boxEncodingsData,
298 scoresData,
299 anchorsData,
300 RegularNmsExpectedResults::s_DetectionBoxes,
301 RegularNmsExpectedResults::s_DetectionClasses,
302 RegularNmsExpectedResults::s_DetectionScores,
303 RegularNmsExpectedResults::s_NumDetections,
304 true);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000305}
306
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100307template<typename FactoryType>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000308void DetectionPostProcessFastNmsFloatTest()
309{
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100310 return DetectionPostProcessImpl<FactoryType, armnn::DataType::Float32>(
311 armnn::TensorInfo(TestData::s_BoxEncodingsShape, armnn::DataType::Float32),
312 armnn::TensorInfo(TestData::s_ScoresShape, armnn::DataType::Float32),
313 armnn::TensorInfo(TestData::s_AnchorsShape, armnn::DataType::Float32),
314 TestData::s_BoxEncodings,
315 TestData::s_Scores,
316 TestData::s_Anchors,
317 FastNmsExpectedResults::s_DetectionBoxes,
318 FastNmsExpectedResults::s_DetectionClasses,
319 FastNmsExpectedResults::s_DetectionScores,
320 FastNmsExpectedResults::s_NumDetections,
321 false);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000322}
323
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100324template<typename FactoryType,
325 armnn::DataType QuantizedType,
326 typename RawType = armnn::ResolveType<QuantizedType>>
327void DetectionPostProcessFastNmsQuantizedTest()
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000328{
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100329 armnn::TensorInfo boxEncodingsInfo(TestData::s_BoxEncodingsShape, QuantizedType);
330 armnn::TensorInfo scoresInfo(TestData::s_ScoresShape, QuantizedType);
331 armnn::TensorInfo anchorsInfo(TestData::s_AnchorsShape, QuantizedType);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000332
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100333 boxEncodingsInfo.SetQuantizationScale(TestData::s_BoxEncodingsQuantData.first);
334 boxEncodingsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000335
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100336 scoresInfo.SetQuantizationScale(TestData::s_ScoresQuantData.first);
337 scoresInfo.SetQuantizationOffset(TestData::s_ScoresQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000338
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100339 anchorsInfo.SetQuantizationScale(TestData::s_AnchorsQuantData.first);
340 anchorsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000341
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100342 std::vector<RawType> boxEncodingsData(TestData::s_BoxEncodingsShape.GetNumElements());
343 QuantizeData<QuantizedType>(boxEncodingsData.data(),
344 TestData::s_BoxEncodings.data(),
345 boxEncodingsInfo);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000346
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100347 std::vector<RawType> scoresData(TestData::s_ScoresShape.GetNumElements());
348 QuantizeData<QuantizedType>(scoresData.data(),
349 TestData::s_Scores.data(),
350 scoresInfo);
351
352 std::vector<RawType> anchorsData(TestData::s_AnchorsShape.GetNumElements());
353 QuantizeData<QuantizedType>(anchorsData.data(),
354 TestData::s_Anchors.data(),
355 anchorsInfo);
356
357 return DetectionPostProcessImpl<FactoryType, QuantizedType>(
358 boxEncodingsInfo,
359 scoresInfo,
360 anchorsInfo,
361 boxEncodingsData,
362 scoresData,
363 anchorsData,
364 FastNmsExpectedResults::s_DetectionBoxes,
365 FastNmsExpectedResults::s_DetectionClasses,
366 FastNmsExpectedResults::s_DetectionScores,
367 FastNmsExpectedResults::s_NumDetections,
368 false);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100369}