blob: 763578be3c8a5da3f283697eb711c3b1fab16d7b [file] [log] [blame]
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
Aron Virginas-Tar6331f912019-06-03 17:10:02 +01006#include <reference/workloads/DetectionPostProcess.hpp>
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00007
8#include <armnn/Descriptors.hpp>
9#include <armnn/Types.hpp>
10
Sadik Armagan1625efc2021-06-10 18:24:34 +010011#include <doctest/doctest.h>
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000012
Sadik Armagan1625efc2021-06-10 18:24:34 +010013TEST_SUITE("RefDetectionPostProcess")
14{
15TEST_CASE("TopKSortTest")
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000016{
17 unsigned int k = 3;
18 unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
19 float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 };
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010020 armnn::TopKSort(k, indices, values, 8);
Sadik Armagan1625efc2021-06-10 18:24:34 +010021 CHECK(indices[0] == 7);
22 CHECK(indices[1] == 1);
23 CHECK(indices[2] == 2);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000024}
25
Sadik Armagan1625efc2021-06-10 18:24:34 +010026TEST_CASE("FullTopKSortTest")
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000027{
28 unsigned int k = 8;
29 unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
30 float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 };
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010031 armnn::TopKSort(k, indices, values, 8);
Sadik Armagan1625efc2021-06-10 18:24:34 +010032 CHECK(indices[0] == 7);
33 CHECK(indices[1] == 1);
34 CHECK(indices[2] == 2);
35 CHECK(indices[3] == 3);
36 CHECK(indices[4] == 4);
37 CHECK(indices[5] == 5);
38 CHECK(indices[6] == 6);
39 CHECK(indices[7] == 0);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000040}
41
Sadik Armagan1625efc2021-06-10 18:24:34 +010042TEST_CASE("IouTest")
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000043{
44 float boxI[4] = { 0.0f, 0.0f, 10.0f, 10.0f };
45 float boxJ[4] = { 1.0f, 1.0f, 11.0f, 11.0f };
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010046 float iou = armnn::IntersectionOverUnion(boxI, boxJ);
Sadik Armagan1625efc2021-06-10 18:24:34 +010047 CHECK(iou == doctest::Approx(0.68).epsilon(0.001f));
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000048}
49
Sadik Armagan1625efc2021-06-10 18:24:34 +010050TEST_CASE("NmsFunction")
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000051{
52 std::vector<float> boxCorners({
53 0.0f, 0.0f, 1.0f, 1.0f,
54 0.0f, 0.1f, 1.0f, 1.1f,
55 0.0f, -0.1f, 1.0f, 0.9f,
56 0.0f, 10.0f, 1.0f, 11.0f,
57 0.0f, 10.1f, 1.0f, 11.1f,
58 0.0f, 100.0f, 1.0f, 101.0f
59 });
60
61 std::vector<float> scores({ 0.9f, 0.75f, 0.6f, 0.93f, 0.5f, 0.3f });
62
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010063 std::vector<unsigned int> result =
64 armnn::NonMaxSuppression(6, boxCorners, scores, 0.0, 3, 0.5);
65
Sadik Armagan1625efc2021-06-10 18:24:34 +010066 CHECK(result.size() == 3);
67 CHECK(result[0] == 3);
68 CHECK(result[1] == 0);
69 CHECK(result[2] == 5);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000070}
71
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010072void DetectionPostProcessTestImpl(bool useRegularNms,
73 const std::vector<float>& expectedDetectionBoxes,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000074 const std::vector<float>& expectedDetectionClasses,
75 const std::vector<float>& expectedDetectionScores,
76 const std::vector<float>& expectedNumDetections)
77{
78 armnn::TensorInfo boxEncodingsInfo({ 1, 6, 4 }, armnn::DataType::Float32);
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +000079 armnn::TensorInfo scoresInfo({ 1, 6, 3 }, armnn::DataType::Float32);
80 armnn::TensorInfo anchorsInfo({ 6, 4 }, armnn::DataType::Float32);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000081
82 armnn::TensorInfo detectionBoxesInfo({ 1, 3, 4 }, armnn::DataType::Float32);
83 armnn::TensorInfo detectionScoresInfo({ 1, 3 }, armnn::DataType::Float32);
84 armnn::TensorInfo detectionClassesInfo({ 1, 3 }, armnn::DataType::Float32);
85 armnn::TensorInfo numDetectionInfo({ 1 }, armnn::DataType::Float32);
86
87 armnn::DetectionPostProcessDescriptor desc;
88 desc.m_UseRegularNms = useRegularNms;
89 desc.m_MaxDetections = 3;
90 desc.m_MaxClassesPerDetection = 1;
91 desc.m_DetectionsPerClass =1;
92 desc.m_NmsScoreThreshold = 0.0;
93 desc.m_NmsIouThreshold = 0.5;
94 desc.m_NumClasses = 2;
95 desc.m_ScaleY = 10.0;
96 desc.m_ScaleX = 10.0;
97 desc.m_ScaleH = 5.0;
98 desc.m_ScaleW = 5.0;
99
100 std::vector<float> boxEncodings({
101 0.0f, 0.0f, 0.0f, 0.0f,
102 0.0f, 1.0f, 0.0f, 0.0f,
103 0.0f, -1.0f, 0.0f, 0.0f,
104 0.0f, 0.0f, 0.0f, 0.0f,
105 0.0f, 1.0f, 0.0f, 0.0f,
106 0.0f, 0.0f, 0.0f, 0.0f
107 });
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100108
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000109 std::vector<float> scores({
110 0.0f, 0.9f, 0.8f,
111 0.0f, 0.75f, 0.72f,
112 0.0f, 0.6f, 0.5f,
113 0.0f, 0.93f, 0.95f,
114 0.0f, 0.5f, 0.4f,
115 0.0f, 0.3f, 0.2f
116 });
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100117
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000118 std::vector<float> anchors({
119 0.5f, 0.5f, 1.0f, 1.0f,
120 0.5f, 0.5f, 1.0f, 1.0f,
121 0.5f, 0.5f, 1.0f, 1.0f,
122 0.5f, 10.5f, 1.0f, 1.0f,
123 0.5f, 10.5f, 1.0f, 1.0f,
124 0.5f, 100.5f, 1.0f, 1.0f
125 });
126
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100127 auto boxEncodingsDecoder = armnn::MakeDecoder<float>(boxEncodingsInfo, boxEncodings.data());
128 auto scoresDecoder = armnn::MakeDecoder<float>(scoresInfo, scores.data());
129 auto anchorsDecoder = armnn::MakeDecoder<float>(anchorsInfo, anchors.data());
130
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000131 std::vector<float> detectionBoxes(detectionBoxesInfo.GetNumElements());
132 std::vector<float> detectionScores(detectionScoresInfo.GetNumElements());
133 std::vector<float> detectionClasses(detectionClassesInfo.GetNumElements());
134 std::vector<float> numDetections(1);
135
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100136 armnn::DetectionPostProcess(boxEncodingsInfo,
137 scoresInfo,
138 anchorsInfo,
139 detectionBoxesInfo,
140 detectionClassesInfo,
141 detectionScoresInfo,
142 numDetectionInfo,
143 desc,
144 *boxEncodingsDecoder,
145 *scoresDecoder,
146 *anchorsDecoder,
147 detectionBoxes.data(),
148 detectionClasses.data(),
149 detectionScores.data(),
150 numDetections.data());
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000151
Sadik Armagan1625efc2021-06-10 18:24:34 +0100152 CHECK(std::equal(detectionBoxes.begin(),
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100153 detectionBoxes.end(),
154 expectedDetectionBoxes.begin(),
Sadik Armagan1625efc2021-06-10 18:24:34 +0100155 expectedDetectionBoxes.end()));
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100156
Sadik Armagan1625efc2021-06-10 18:24:34 +0100157 CHECK(std::equal(detectionScores.begin(), detectionScores.end(),
158 expectedDetectionScores.begin(), expectedDetectionScores.end()));
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100159
Sadik Armagan1625efc2021-06-10 18:24:34 +0100160 CHECK(std::equal(detectionClasses.begin(), detectionClasses.end(),
161 expectedDetectionClasses.begin(), expectedDetectionClasses.end()));
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100162
Sadik Armagan1625efc2021-06-10 18:24:34 +0100163 CHECK(std::equal(numDetections.begin(), numDetections.end(),
164 expectedNumDetections.begin(), expectedNumDetections.end()));
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000165}
166
Sadik Armagan1625efc2021-06-10 18:24:34 +0100167TEST_CASE("RegularNmsDetectionPostProcess")
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000168{
169 std::vector<float> expectedDetectionBoxes({
170 0.0f, 10.0f, 1.0f, 11.0f,
171 0.0f, 10.0f, 1.0f, 11.0f,
172 0.0f, 0.0f, 0.0f, 0.0f
173 });
174
175 std::vector<float> expectedDetectionScores({ 0.95f, 0.93f, 0.0f });
176 std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
177 std::vector<float> expectedNumDetections({ 2.0f });
178
179 DetectionPostProcessTestImpl(true, expectedDetectionBoxes, expectedDetectionClasses,
180 expectedDetectionScores, expectedNumDetections);
181}
182
Sadik Armagan1625efc2021-06-10 18:24:34 +0100183TEST_CASE("FastNmsDetectionPostProcess")
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000184{
185 std::vector<float> expectedDetectionBoxes({
186 0.0f, 10.0f, 1.0f, 11.0f,
187 0.0f, 0.0f, 1.0f, 1.0f,
188 0.0f, 100.0f, 1.0f, 101.0f
189 });
190 std::vector<float> expectedDetectionScores({ 0.95f, 0.9f, 0.3f });
191 std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
192 std::vector<float> expectedNumDetections({ 3.0f });
193
194 DetectionPostProcessTestImpl(false, expectedDetectionBoxes, expectedDetectionClasses,
195 expectedDetectionScores, expectedNumDetections);
196}
197
Sadik Armagan1625efc2021-06-10 18:24:34 +0100198}