blob: a9faff70b15ed897f51d7381f984515ee532b799 [file] [log] [blame]
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "reference/workloads/DetectionPostProcess.cpp"
7
8#include <armnn/Descriptors.hpp>
9#include <armnn/Types.hpp>
10
11#include <boost/test/unit_test.hpp>
12
13BOOST_AUTO_TEST_SUITE(RefDetectionPostProcess)
14
15
16BOOST_AUTO_TEST_CASE(TopKSortTest)
17{
18 unsigned int k = 3;
19 unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
20 float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 };
21 TopKSort(k, indices, values, 8);
22 BOOST_TEST(indices[0] == 7);
23 BOOST_TEST(indices[1] == 1);
24 BOOST_TEST(indices[2] == 2);
25}
26
27BOOST_AUTO_TEST_CASE(FullTopKSortTest)
28{
29 unsigned int k = 8;
30 unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
31 float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 };
32 TopKSort(k, indices, values, 8);
33 BOOST_TEST(indices[0] == 7);
34 BOOST_TEST(indices[1] == 1);
35 BOOST_TEST(indices[2] == 2);
36 BOOST_TEST(indices[3] == 3);
37 BOOST_TEST(indices[4] == 4);
38 BOOST_TEST(indices[5] == 5);
39 BOOST_TEST(indices[6] == 6);
40 BOOST_TEST(indices[7] == 0);
41}
42
43BOOST_AUTO_TEST_CASE(IouTest)
44{
45 float boxI[4] = { 0.0f, 0.0f, 10.0f, 10.0f };
46 float boxJ[4] = { 1.0f, 1.0f, 11.0f, 11.0f };
47 float iou = IntersectionOverUnion(boxI, boxJ);
48 BOOST_TEST(iou == 0.68, boost::test_tools::tolerance(0.001));
49}
50
51BOOST_AUTO_TEST_CASE(NmsFunction)
52{
53 std::vector<float> boxCorners({
54 0.0f, 0.0f, 1.0f, 1.0f,
55 0.0f, 0.1f, 1.0f, 1.1f,
56 0.0f, -0.1f, 1.0f, 0.9f,
57 0.0f, 10.0f, 1.0f, 11.0f,
58 0.0f, 10.1f, 1.0f, 11.1f,
59 0.0f, 100.0f, 1.0f, 101.0f
60 });
61
62 std::vector<float> scores({ 0.9f, 0.75f, 0.6f, 0.93f, 0.5f, 0.3f });
63
64 std::vector<unsigned int> result = NonMaxSuppression(6, boxCorners, scores, 0.0, 3, 0.5);
65 BOOST_TEST(result.size() == 3);
66 BOOST_TEST(result[0] == 3);
67 BOOST_TEST(result[1] == 0);
68 BOOST_TEST(result[2] == 5);
69}
70
71void DetectionPostProcessTestImpl(bool useRegularNms, const std::vector<float>& expectedDetectionBoxes,
72 const std::vector<float>& expectedDetectionClasses,
73 const std::vector<float>& expectedDetectionScores,
74 const std::vector<float>& expectedNumDetections)
75{
76 armnn::TensorInfo boxEncodingsInfo({ 1, 6, 4 }, armnn::DataType::Float32);
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +000077 armnn::TensorInfo scoresInfo({ 1, 6, 3 }, armnn::DataType::Float32);
78 armnn::TensorInfo anchorsInfo({ 6, 4 }, armnn::DataType::Float32);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000079
80 armnn::TensorInfo detectionBoxesInfo({ 1, 3, 4 }, armnn::DataType::Float32);
81 armnn::TensorInfo detectionScoresInfo({ 1, 3 }, armnn::DataType::Float32);
82 armnn::TensorInfo detectionClassesInfo({ 1, 3 }, armnn::DataType::Float32);
83 armnn::TensorInfo numDetectionInfo({ 1 }, armnn::DataType::Float32);
84
85 armnn::DetectionPostProcessDescriptor desc;
86 desc.m_UseRegularNms = useRegularNms;
87 desc.m_MaxDetections = 3;
88 desc.m_MaxClassesPerDetection = 1;
89 desc.m_DetectionsPerClass =1;
90 desc.m_NmsScoreThreshold = 0.0;
91 desc.m_NmsIouThreshold = 0.5;
92 desc.m_NumClasses = 2;
93 desc.m_ScaleY = 10.0;
94 desc.m_ScaleX = 10.0;
95 desc.m_ScaleH = 5.0;
96 desc.m_ScaleW = 5.0;
97
98 std::vector<float> boxEncodings({
99 0.0f, 0.0f, 0.0f, 0.0f,
100 0.0f, 1.0f, 0.0f, 0.0f,
101 0.0f, -1.0f, 0.0f, 0.0f,
102 0.0f, 0.0f, 0.0f, 0.0f,
103 0.0f, 1.0f, 0.0f, 0.0f,
104 0.0f, 0.0f, 0.0f, 0.0f
105 });
106 std::vector<float> scores({
107 0.0f, 0.9f, 0.8f,
108 0.0f, 0.75f, 0.72f,
109 0.0f, 0.6f, 0.5f,
110 0.0f, 0.93f, 0.95f,
111 0.0f, 0.5f, 0.4f,
112 0.0f, 0.3f, 0.2f
113 });
114 std::vector<float> anchors({
115 0.5f, 0.5f, 1.0f, 1.0f,
116 0.5f, 0.5f, 1.0f, 1.0f,
117 0.5f, 0.5f, 1.0f, 1.0f,
118 0.5f, 10.5f, 1.0f, 1.0f,
119 0.5f, 10.5f, 1.0f, 1.0f,
120 0.5f, 100.5f, 1.0f, 1.0f
121 });
122
123 std::vector<float> detectionBoxes(detectionBoxesInfo.GetNumElements());
124 std::vector<float> detectionScores(detectionScoresInfo.GetNumElements());
125 std::vector<float> detectionClasses(detectionClassesInfo.GetNumElements());
126 std::vector<float> numDetections(1);
127
128 armnn::DetectionPostProcess(boxEncodingsInfo, scoresInfo, anchorsInfo,
129 detectionBoxesInfo, detectionClassesInfo,
130 detectionScoresInfo, numDetectionInfo, desc,
131 boxEncodings.data(), scores.data(), anchors.data(),
132 detectionBoxes.data(), detectionClasses.data(),
133 detectionScores.data(), numDetections.data());
134
135 BOOST_TEST(detectionBoxes == expectedDetectionBoxes);
136 BOOST_TEST(detectionScores == expectedDetectionScores);
137 BOOST_TEST(detectionClasses == expectedDetectionClasses);
138 BOOST_TEST(numDetections == expectedNumDetections);
139}
140
141BOOST_AUTO_TEST_CASE(RegularNmsDetectionPostProcess)
142{
143 std::vector<float> expectedDetectionBoxes({
144 0.0f, 10.0f, 1.0f, 11.0f,
145 0.0f, 10.0f, 1.0f, 11.0f,
146 0.0f, 0.0f, 0.0f, 0.0f
147 });
148
149 std::vector<float> expectedDetectionScores({ 0.95f, 0.93f, 0.0f });
150 std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
151 std::vector<float> expectedNumDetections({ 2.0f });
152
153 DetectionPostProcessTestImpl(true, expectedDetectionBoxes, expectedDetectionClasses,
154 expectedDetectionScores, expectedNumDetections);
155}
156
157BOOST_AUTO_TEST_CASE(FastNmsDetectionPostProcess)
158{
159 std::vector<float> expectedDetectionBoxes({
160 0.0f, 10.0f, 1.0f, 11.0f,
161 0.0f, 0.0f, 1.0f, 1.0f,
162 0.0f, 100.0f, 1.0f, 101.0f
163 });
164 std::vector<float> expectedDetectionScores({ 0.95f, 0.9f, 0.3f });
165 std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
166 std::vector<float> expectedNumDetections({ 3.0f });
167
168 DetectionPostProcessTestImpl(false, expectedDetectionBoxes, expectedDetectionClasses,
169 expectedDetectionScores, expectedNumDetections);
170}
171
172BOOST_AUTO_TEST_SUITE_END()