blob: fab6e00bad2f01e667125835e7c644909832123c [file] [log] [blame]
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01006#include <reference/workloads/DetectionPostProcess.hpp>
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00007
8#include <armnn/Descriptors.hpp>
9#include <armnn/Types.hpp>
10
11#include <boost/test/unit_test.hpp>
12
13BOOST_AUTO_TEST_SUITE(RefDetectionPostProcess)
14
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000015BOOST_AUTO_TEST_CASE(TopKSortTest)
16{
17 unsigned int k = 3;
18 unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
19 float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 };
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010020 armnn::TopKSort(k, indices, values, 8);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000021 BOOST_TEST(indices[0] == 7);
22 BOOST_TEST(indices[1] == 1);
23 BOOST_TEST(indices[2] == 2);
24}
25
26BOOST_AUTO_TEST_CASE(FullTopKSortTest)
27{
28 unsigned int k = 8;
29 unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
30 float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 };
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010031 armnn::TopKSort(k, indices, values, 8);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000032 BOOST_TEST(indices[0] == 7);
33 BOOST_TEST(indices[1] == 1);
34 BOOST_TEST(indices[2] == 2);
35 BOOST_TEST(indices[3] == 3);
36 BOOST_TEST(indices[4] == 4);
37 BOOST_TEST(indices[5] == 5);
38 BOOST_TEST(indices[6] == 6);
39 BOOST_TEST(indices[7] == 0);
40}
41
42BOOST_AUTO_TEST_CASE(IouTest)
43{
44 float boxI[4] = { 0.0f, 0.0f, 10.0f, 10.0f };
45 float boxJ[4] = { 1.0f, 1.0f, 11.0f, 11.0f };
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010046 float iou = armnn::IntersectionOverUnion(boxI, boxJ);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000047 BOOST_TEST(iou == 0.68, boost::test_tools::tolerance(0.001));
48}
49
50BOOST_AUTO_TEST_CASE(NmsFunction)
51{
52 std::vector<float> boxCorners({
53 0.0f, 0.0f, 1.0f, 1.0f,
54 0.0f, 0.1f, 1.0f, 1.1f,
55 0.0f, -0.1f, 1.0f, 0.9f,
56 0.0f, 10.0f, 1.0f, 11.0f,
57 0.0f, 10.1f, 1.0f, 11.1f,
58 0.0f, 100.0f, 1.0f, 101.0f
59 });
60
61 std::vector<float> scores({ 0.9f, 0.75f, 0.6f, 0.93f, 0.5f, 0.3f });
62
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010063 std::vector<unsigned int> result =
64 armnn::NonMaxSuppression(6, boxCorners, scores, 0.0, 3, 0.5);
65
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000066 BOOST_TEST(result.size() == 3);
67 BOOST_TEST(result[0] == 3);
68 BOOST_TEST(result[1] == 0);
69 BOOST_TEST(result[2] == 5);
70}
71
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010072void DetectionPostProcessTestImpl(bool useRegularNms,
73 const std::vector<float>& expectedDetectionBoxes,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000074 const std::vector<float>& expectedDetectionClasses,
75 const std::vector<float>& expectedDetectionScores,
76 const std::vector<float>& expectedNumDetections)
77{
78 armnn::TensorInfo boxEncodingsInfo({ 1, 6, 4 }, armnn::DataType::Float32);
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +000079 armnn::TensorInfo scoresInfo({ 1, 6, 3 }, armnn::DataType::Float32);
80 armnn::TensorInfo anchorsInfo({ 6, 4 }, armnn::DataType::Float32);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000081
82 armnn::TensorInfo detectionBoxesInfo({ 1, 3, 4 }, armnn::DataType::Float32);
83 armnn::TensorInfo detectionScoresInfo({ 1, 3 }, armnn::DataType::Float32);
84 armnn::TensorInfo detectionClassesInfo({ 1, 3 }, armnn::DataType::Float32);
85 armnn::TensorInfo numDetectionInfo({ 1 }, armnn::DataType::Float32);
86
87 armnn::DetectionPostProcessDescriptor desc;
88 desc.m_UseRegularNms = useRegularNms;
89 desc.m_MaxDetections = 3;
90 desc.m_MaxClassesPerDetection = 1;
91 desc.m_DetectionsPerClass =1;
92 desc.m_NmsScoreThreshold = 0.0;
93 desc.m_NmsIouThreshold = 0.5;
94 desc.m_NumClasses = 2;
95 desc.m_ScaleY = 10.0;
96 desc.m_ScaleX = 10.0;
97 desc.m_ScaleH = 5.0;
98 desc.m_ScaleW = 5.0;
99
100 std::vector<float> boxEncodings({
101 0.0f, 0.0f, 0.0f, 0.0f,
102 0.0f, 1.0f, 0.0f, 0.0f,
103 0.0f, -1.0f, 0.0f, 0.0f,
104 0.0f, 0.0f, 0.0f, 0.0f,
105 0.0f, 1.0f, 0.0f, 0.0f,
106 0.0f, 0.0f, 0.0f, 0.0f
107 });
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100108
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000109 std::vector<float> scores({
110 0.0f, 0.9f, 0.8f,
111 0.0f, 0.75f, 0.72f,
112 0.0f, 0.6f, 0.5f,
113 0.0f, 0.93f, 0.95f,
114 0.0f, 0.5f, 0.4f,
115 0.0f, 0.3f, 0.2f
116 });
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100117
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000118 std::vector<float> anchors({
119 0.5f, 0.5f, 1.0f, 1.0f,
120 0.5f, 0.5f, 1.0f, 1.0f,
121 0.5f, 0.5f, 1.0f, 1.0f,
122 0.5f, 10.5f, 1.0f, 1.0f,
123 0.5f, 10.5f, 1.0f, 1.0f,
124 0.5f, 100.5f, 1.0f, 1.0f
125 });
126
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100127 auto boxEncodingsDecoder = armnn::MakeDecoder<float>(boxEncodingsInfo, boxEncodings.data());
128 auto scoresDecoder = armnn::MakeDecoder<float>(scoresInfo, scores.data());
129 auto anchorsDecoder = armnn::MakeDecoder<float>(anchorsInfo, anchors.data());
130
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000131 std::vector<float> detectionBoxes(detectionBoxesInfo.GetNumElements());
132 std::vector<float> detectionScores(detectionScoresInfo.GetNumElements());
133 std::vector<float> detectionClasses(detectionClassesInfo.GetNumElements());
134 std::vector<float> numDetections(1);
135
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100136 armnn::DetectionPostProcess(boxEncodingsInfo,
137 scoresInfo,
138 anchorsInfo,
139 detectionBoxesInfo,
140 detectionClassesInfo,
141 detectionScoresInfo,
142 numDetectionInfo,
143 desc,
144 *boxEncodingsDecoder,
145 *scoresDecoder,
146 *anchorsDecoder,
147 detectionBoxes.data(),
148 detectionClasses.data(),
149 detectionScores.data(),
150 numDetections.data());
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000151
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100152 BOOST_CHECK_EQUAL_COLLECTIONS(detectionBoxes.begin(),
153 detectionBoxes.end(),
154 expectedDetectionBoxes.begin(),
155 expectedDetectionBoxes.end());
156
157 BOOST_CHECK_EQUAL_COLLECTIONS(detectionScores.begin(),
158 detectionScores.end(),
159 expectedDetectionScores.begin(),
160 expectedDetectionScores.end());
161
162 BOOST_CHECK_EQUAL_COLLECTIONS(detectionClasses.begin(),
163 detectionClasses.end(),
164 expectedDetectionClasses.begin(),
165 expectedDetectionClasses.end());
166
167 BOOST_CHECK_EQUAL_COLLECTIONS(numDetections.begin(),
168 numDetections.end(),
169 expectedNumDetections.begin(),
170 expectedNumDetections.end());
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000171}
172
173BOOST_AUTO_TEST_CASE(RegularNmsDetectionPostProcess)
174{
175 std::vector<float> expectedDetectionBoxes({
176 0.0f, 10.0f, 1.0f, 11.0f,
177 0.0f, 10.0f, 1.0f, 11.0f,
178 0.0f, 0.0f, 0.0f, 0.0f
179 });
180
181 std::vector<float> expectedDetectionScores({ 0.95f, 0.93f, 0.0f });
182 std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
183 std::vector<float> expectedNumDetections({ 2.0f });
184
185 DetectionPostProcessTestImpl(true, expectedDetectionBoxes, expectedDetectionClasses,
186 expectedDetectionScores, expectedNumDetections);
187}
188
189BOOST_AUTO_TEST_CASE(FastNmsDetectionPostProcess)
190{
191 std::vector<float> expectedDetectionBoxes({
192 0.0f, 10.0f, 1.0f, 11.0f,
193 0.0f, 0.0f, 1.0f, 1.0f,
194 0.0f, 100.0f, 1.0f, 101.0f
195 });
196 std::vector<float> expectedDetectionScores({ 0.95f, 0.9f, 0.3f });
197 std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
198 std::vector<float> expectedNumDetections({ 3.0f });
199
200 DetectionPostProcessTestImpl(false, expectedDetectionBoxes, expectedDetectionClasses,
201 expectedDetectionScores, expectedNumDetections);
202}
203
204BOOST_AUTO_TEST_SUITE_END()