blob: 2472c342ea7f70dd29e67054709f15d7c0995918 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5#pragma once
6
Aron Virginas-Tard4f0fea2019-04-09 14:08:06 +01007#include <ResolveType.hpp>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +00008
9#include <armnn/Types.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010010
James Conroy1f58f032021-04-27 17:13:27 +010011#include <backendsCommon/TensorHandle.hpp>
Matteo Martincighe5b8eb92019-11-28 15:45:42 +000012#include <armnn/backends/IBackendInternal.hpp>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +000013#include <backendsCommon/WorkloadFactory.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010014
15#include <backendsCommon/test/TensorCopyUtils.hpp>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +000016#include <backendsCommon/test/WorkloadFactoryHelper.hpp>
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +010017#include <backendsCommon/test/WorkloadTestUtils.hpp>
18
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +000019#include <test/TensorHelpers.hpp>
20
Sadik Armagan1625efc2021-06-10 18:24:34 +010021#include <doctest/doctest.h>
22
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010023namespace
24{
25
26using FloatData = std::vector<float>;
27using QuantData = std::pair<float, int32_t>;
28
29struct TestData
30{
31 static const armnn::TensorShape s_BoxEncodingsShape;
32 static const armnn::TensorShape s_ScoresShape;
33 static const armnn::TensorShape s_AnchorsShape;
34
35 static const QuantData s_BoxEncodingsQuantData;
36 static const QuantData s_ScoresQuantData;
37 static const QuantData s_AnchorsQuantData;
38
39 static const FloatData s_BoxEncodings;
40 static const FloatData s_Scores;
41 static const FloatData s_Anchors;
42};
43
44struct RegularNmsExpectedResults
45{
46 static const FloatData s_DetectionBoxes;
47 static const FloatData s_DetectionScores;
48 static const FloatData s_DetectionClasses;
49 static const FloatData s_NumDetections;
50};
51
52struct FastNmsExpectedResults
53{
54 static const FloatData s_DetectionBoxes;
55 static const FloatData s_DetectionScores;
56 static const FloatData s_DetectionClasses;
57 static const FloatData s_NumDetections;
58};
59
60const armnn::TensorShape TestData::s_BoxEncodingsShape = { 1, 6, 4 };
61const armnn::TensorShape TestData::s_ScoresShape = { 1, 6, 3 };
62const armnn::TensorShape TestData::s_AnchorsShape = { 6, 4 };
63
64const QuantData TestData::s_BoxEncodingsQuantData = { 1.00f, 1 };
65const QuantData TestData::s_ScoresQuantData = { 0.01f, 0 };
66const QuantData TestData::s_AnchorsQuantData = { 0.50f, 0 };
67
68const FloatData TestData::s_BoxEncodings =
69{
70 0.0f, 0.0f, 0.0f, 0.0f,
71 0.0f, 1.0f, 0.0f, 0.0f,
72 0.0f, -1.0f, 0.0f, 0.0f,
73 0.0f, 0.0f, 0.0f, 0.0f,
74 0.0f, 1.0f, 0.0f, 0.0f,
75 0.0f, 0.0f, 0.0f, 0.0f
76};
77
78const FloatData TestData::s_Scores =
79{
80 0.0f, 0.90f, 0.80f,
81 0.0f, 0.75f, 0.72f,
82 0.0f, 0.60f, 0.50f,
83 0.0f, 0.93f, 0.95f,
84 0.0f, 0.50f, 0.40f,
85 0.0f, 0.30f, 0.20f
86};
87
88const FloatData TestData::s_Anchors =
89{
90 0.5f, 0.5f, 1.0f, 1.0f,
91 0.5f, 0.5f, 1.0f, 1.0f,
92 0.5f, 0.5f, 1.0f, 1.0f,
93 0.5f, 10.5f, 1.0f, 1.0f,
94 0.5f, 10.5f, 1.0f, 1.0f,
95 0.5f, 100.5f, 1.0f, 1.0f
96};
97
98const FloatData RegularNmsExpectedResults::s_DetectionBoxes =
99{
100 0.0f, 10.0f, 1.0f, 11.0f,
101 0.0f, 10.0f, 1.0f, 11.0f,
102 0.0f, 0.0f, 0.0f, 0.0f
103};
104
105const FloatData RegularNmsExpectedResults::s_DetectionScores =
106{
107 0.95f, 0.93f, 0.0f
108};
109
110const FloatData RegularNmsExpectedResults::s_DetectionClasses =
111{
112 1.0f, 0.0f, 0.0f
113};
114
115const FloatData RegularNmsExpectedResults::s_NumDetections = { 2.0f };
116
117const FloatData FastNmsExpectedResults::s_DetectionBoxes =
118{
119 0.0f, 10.0f, 1.0f, 11.0f,
120 0.0f, 0.0f, 1.0f, 1.0f,
121 0.0f, 100.0f, 1.0f, 101.0f
122};
123
124const FloatData FastNmsExpectedResults::s_DetectionScores =
125{
126 0.95f, 0.9f, 0.3f
127};
128
129const FloatData FastNmsExpectedResults::s_DetectionClasses =
130{
131 1.0f, 0.0f, 0.0f
132};
133
134const FloatData FastNmsExpectedResults::s_NumDetections = { 3.0f };
135
136} // anonymous namespace
137
138template<typename FactoryType,
139 armnn::DataType ArmnnType,
140 typename T = armnn::ResolveType<ArmnnType>>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000141void DetectionPostProcessImpl(const armnn::TensorInfo& boxEncodingsInfo,
142 const armnn::TensorInfo& scoresInfo,
143 const armnn::TensorInfo& anchorsInfo,
144 const std::vector<T>& boxEncodingsData,
145 const std::vector<T>& scoresData,
146 const std::vector<T>& anchorsData,
147 const std::vector<float>& expectedDetectionBoxes,
148 const std::vector<float>& expectedDetectionClasses,
149 const std::vector<float>& expectedDetectionScores,
150 const std::vector<float>& expectedNumDetections,
151 bool useRegularNms)
152{
Francis Murtagh33199c22021-02-15 10:11:28 +0000153 std::unique_ptr<armnn::IProfiler> profiler = std::make_unique<armnn::IProfiler>();
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000154 armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
155
156 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
157 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
Francis Murtagh623069d2020-08-14 17:24:39 +0100158 auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000159
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000160 armnn::TensorInfo detectionBoxesInfo({ 1, 3, 4 }, armnn::DataType::Float32);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000161 armnn::TensorInfo detectionClassesInfo({ 1, 3 }, armnn::DataType::Float32);
Sadik Armagan483c8112021-06-01 09:24:52 +0100162 armnn::TensorInfo detectionScoresInfo({ 1, 3 }, armnn::DataType::Float32);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000163 armnn::TensorInfo numDetectionInfo({ 1 }, armnn::DataType::Float32);
164
Sadik Armagan483c8112021-06-01 09:24:52 +0100165 std::vector<float> actualDetectionBoxesOutput(detectionBoxesInfo.GetNumElements());
166 std::vector<float> actualDetectionClassesOutput(detectionClassesInfo.GetNumElements());
167 std::vector<float> actualDetectionScoresOutput(detectionScoresInfo.GetNumElements());
168 std::vector<float> actualNumDetectionOutput(numDetectionInfo.GetNumElements());
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000169
Francis Murtagh623069d2020-08-14 17:24:39 +0100170 auto boxedHandle = tensorHandleFactory.CreateTensorHandle(boxEncodingsInfo);
171 auto scoreshandle = tensorHandleFactory.CreateTensorHandle(scoresInfo);
172 auto anchorsHandle = tensorHandleFactory.CreateTensorHandle(anchorsInfo);
173 auto outputBoxesHandle = tensorHandleFactory.CreateTensorHandle(detectionBoxesInfo);
174 auto classesHandle = tensorHandleFactory.CreateTensorHandle(detectionClassesInfo);
175 auto outputScoresHandle = tensorHandleFactory.CreateTensorHandle(detectionScoresInfo);
176 auto numDetectionHandle = tensorHandleFactory.CreateTensorHandle(numDetectionInfo);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000177
James Conroy1f58f032021-04-27 17:13:27 +0100178 armnn::ScopedTensorHandle anchorsTensor(anchorsInfo);
Sadik Armagan483c8112021-06-01 09:24:52 +0100179 AllocateAndCopyDataToITensorHandle(&anchorsTensor, anchorsData.data());
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000180
181 armnn::DetectionPostProcessQueueDescriptor data;
182 data.m_Parameters.m_UseRegularNms = useRegularNms;
183 data.m_Parameters.m_MaxDetections = 3;
184 data.m_Parameters.m_MaxClassesPerDetection = 1;
185 data.m_Parameters.m_DetectionsPerClass =1;
186 data.m_Parameters.m_NmsScoreThreshold = 0.0;
187 data.m_Parameters.m_NmsIouThreshold = 0.5;
188 data.m_Parameters.m_NumClasses = 2;
189 data.m_Parameters.m_ScaleY = 10.0;
190 data.m_Parameters.m_ScaleX = 10.0;
191 data.m_Parameters.m_ScaleH = 5.0;
192 data.m_Parameters.m_ScaleW = 5.0;
193 data.m_Anchors = &anchorsTensor;
194
195 armnn::WorkloadInfo info;
196 AddInputToWorkload(data, info, boxEncodingsInfo, boxedHandle.get());
Sadik Armagan483c8112021-06-01 09:24:52 +0100197 AddInputToWorkload(data, info, scoresInfo, scoreshandle.get());
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000198 AddOutputToWorkload(data, info, detectionBoxesInfo, outputBoxesHandle.get());
199 AddOutputToWorkload(data, info, detectionClassesInfo, classesHandle.get());
200 AddOutputToWorkload(data, info, detectionScoresInfo, outputScoresHandle.get());
201 AddOutputToWorkload(data, info, numDetectionInfo, numDetectionHandle.get());
202
203 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateDetectionPostProcess(data, info);
204
205 boxedHandle->Allocate();
206 scoreshandle->Allocate();
207 outputBoxesHandle->Allocate();
208 classesHandle->Allocate();
209 outputScoresHandle->Allocate();
210 numDetectionHandle->Allocate();
211
Sadik Armagan483c8112021-06-01 09:24:52 +0100212 CopyDataToITensorHandle(boxedHandle.get(), boxEncodingsData.data());
213 CopyDataToITensorHandle(scoreshandle.get(), scoresData.data());
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000214
215 workload->Execute();
216
Sadik Armagan483c8112021-06-01 09:24:52 +0100217 CopyDataFromITensorHandle(actualDetectionBoxesOutput.data(), outputBoxesHandle.get());
218 CopyDataFromITensorHandle(actualDetectionClassesOutput.data(), classesHandle.get());
219 CopyDataFromITensorHandle(actualDetectionScoresOutput.data(), outputScoresHandle.get());
220 CopyDataFromITensorHandle(actualNumDetectionOutput.data(), numDetectionHandle.get());
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000221
Sadik Armagan483c8112021-06-01 09:24:52 +0100222 auto result = CompareTensors(actualDetectionBoxesOutput,
223 expectedDetectionBoxes,
224 outputBoxesHandle->GetShape(),
225 detectionBoxesInfo.GetShape());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100226 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
Sadik Armagan483c8112021-06-01 09:24:52 +0100227
228 result = CompareTensors(actualDetectionClassesOutput,
229 expectedDetectionClasses,
230 classesHandle->GetShape(),
231 detectionClassesInfo.GetShape());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100232 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
Sadik Armagan483c8112021-06-01 09:24:52 +0100233
234 result = CompareTensors(actualDetectionScoresOutput,
235 expectedDetectionScores,
236 outputScoresHandle->GetShape(),
237 detectionScoresInfo.GetShape());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100238 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
Sadik Armagan483c8112021-06-01 09:24:52 +0100239
240 result = CompareTensors(actualNumDetectionOutput,
241 expectedNumDetections,
242 numDetectionHandle->GetShape(),
243 numDetectionInfo.GetShape());
Sadik Armagan1625efc2021-06-10 18:24:34 +0100244 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000245}
246
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100247template<armnn::DataType QuantizedType, typename RawType = armnn::ResolveType<QuantizedType>>
248void QuantizeData(RawType* quant, const float* dequant, const armnn::TensorInfo& info)
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000249{
250 for (size_t i = 0; i < info.GetNumElements(); i++)
251 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100252 quant[i] = armnn::Quantize<RawType>(
253 dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000254 }
255}
256
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100257template<typename FactoryType>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000258void DetectionPostProcessRegularNmsFloatTest()
259{
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100260 return DetectionPostProcessImpl<FactoryType, armnn::DataType::Float32>(
261 armnn::TensorInfo(TestData::s_BoxEncodingsShape, armnn::DataType::Float32),
262 armnn::TensorInfo(TestData::s_ScoresShape, armnn::DataType::Float32),
263 armnn::TensorInfo(TestData::s_AnchorsShape, armnn::DataType::Float32),
264 TestData::s_BoxEncodings,
265 TestData::s_Scores,
266 TestData::s_Anchors,
267 RegularNmsExpectedResults::s_DetectionBoxes,
268 RegularNmsExpectedResults::s_DetectionClasses,
269 RegularNmsExpectedResults::s_DetectionScores,
270 RegularNmsExpectedResults::s_NumDetections,
271 true);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000272}
273
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100274template<typename FactoryType,
275 armnn::DataType QuantizedType,
276 typename RawType = armnn::ResolveType<QuantizedType>>
277void DetectionPostProcessRegularNmsQuantizedTest()
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000278{
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100279 armnn::TensorInfo boxEncodingsInfo(TestData::s_BoxEncodingsShape, QuantizedType);
280 armnn::TensorInfo scoresInfo(TestData::s_ScoresShape, QuantizedType);
281 armnn::TensorInfo anchorsInfo(TestData::s_AnchorsShape, QuantizedType);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000282
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100283 boxEncodingsInfo.SetQuantizationScale(TestData::s_BoxEncodingsQuantData.first);
284 boxEncodingsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000285
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100286 scoresInfo.SetQuantizationScale(TestData::s_ScoresQuantData.first);
287 scoresInfo.SetQuantizationOffset(TestData::s_ScoresQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000288
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100289 anchorsInfo.SetQuantizationScale(TestData::s_AnchorsQuantData.first);
290 anchorsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000291
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100292 std::vector<RawType> boxEncodingsData(TestData::s_BoxEncodingsShape.GetNumElements());
293 QuantizeData<QuantizedType>(boxEncodingsData.data(),
294 TestData::s_BoxEncodings.data(),
295 boxEncodingsInfo);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000296
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100297 std::vector<RawType> scoresData(TestData::s_ScoresShape.GetNumElements());
298 QuantizeData<QuantizedType>(scoresData.data(),
299 TestData::s_Scores.data(),
300 scoresInfo);
301
302 std::vector<RawType> anchorsData(TestData::s_AnchorsShape.GetNumElements());
303 QuantizeData<QuantizedType>(anchorsData.data(),
304 TestData::s_Anchors.data(),
305 anchorsInfo);
306
307 return DetectionPostProcessImpl<FactoryType, QuantizedType>(
308 boxEncodingsInfo,
309 scoresInfo,
310 anchorsInfo,
311 boxEncodingsData,
312 scoresData,
313 anchorsData,
314 RegularNmsExpectedResults::s_DetectionBoxes,
315 RegularNmsExpectedResults::s_DetectionClasses,
316 RegularNmsExpectedResults::s_DetectionScores,
317 RegularNmsExpectedResults::s_NumDetections,
318 true);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000319}
320
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100321template<typename FactoryType>
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000322void DetectionPostProcessFastNmsFloatTest()
323{
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100324 return DetectionPostProcessImpl<FactoryType, armnn::DataType::Float32>(
325 armnn::TensorInfo(TestData::s_BoxEncodingsShape, armnn::DataType::Float32),
326 armnn::TensorInfo(TestData::s_ScoresShape, armnn::DataType::Float32),
327 armnn::TensorInfo(TestData::s_AnchorsShape, armnn::DataType::Float32),
328 TestData::s_BoxEncodings,
329 TestData::s_Scores,
330 TestData::s_Anchors,
331 FastNmsExpectedResults::s_DetectionBoxes,
332 FastNmsExpectedResults::s_DetectionClasses,
333 FastNmsExpectedResults::s_DetectionScores,
334 FastNmsExpectedResults::s_NumDetections,
335 false);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000336}
337
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100338template<typename FactoryType,
339 armnn::DataType QuantizedType,
340 typename RawType = armnn::ResolveType<QuantizedType>>
341void DetectionPostProcessFastNmsQuantizedTest()
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000342{
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100343 armnn::TensorInfo boxEncodingsInfo(TestData::s_BoxEncodingsShape, QuantizedType);
344 armnn::TensorInfo scoresInfo(TestData::s_ScoresShape, QuantizedType);
345 armnn::TensorInfo anchorsInfo(TestData::s_AnchorsShape, QuantizedType);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000346
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100347 boxEncodingsInfo.SetQuantizationScale(TestData::s_BoxEncodingsQuantData.first);
348 boxEncodingsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000349
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100350 scoresInfo.SetQuantizationScale(TestData::s_ScoresQuantData.first);
351 scoresInfo.SetQuantizationOffset(TestData::s_ScoresQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000352
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100353 anchorsInfo.SetQuantizationScale(TestData::s_AnchorsQuantData.first);
354 anchorsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000355
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100356 std::vector<RawType> boxEncodingsData(TestData::s_BoxEncodingsShape.GetNumElements());
357 QuantizeData<QuantizedType>(boxEncodingsData.data(),
358 TestData::s_BoxEncodings.data(),
359 boxEncodingsInfo);
Narumol Prangnawarate0a4ad82019-02-04 19:05:27 +0000360
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100361 std::vector<RawType> scoresData(TestData::s_ScoresShape.GetNumElements());
362 QuantizeData<QuantizedType>(scoresData.data(),
363 TestData::s_Scores.data(),
364 scoresInfo);
365
366 std::vector<RawType> anchorsData(TestData::s_AnchorsShape.GetNumElements());
367 QuantizeData<QuantizedType>(anchorsData.data(),
368 TestData::s_Anchors.data(),
369 anchorsInfo);
370
371 return DetectionPostProcessImpl<FactoryType, QuantizedType>(
372 boxEncodingsInfo,
373 scoresInfo,
374 anchorsInfo,
375 boxEncodingsData,
376 scoresData,
377 anchorsData,
378 FastNmsExpectedResults::s_DetectionBoxes,
379 FastNmsExpectedResults::s_DetectionClasses,
380 FastNmsExpectedResults::s_DetectionScores,
381 FastNmsExpectedResults::s_NumDetections,
382 false);
Aron Virginas-Tar00d306e2019-08-28 18:08:46 +0100383}