blob: c5ab327f900d5d34c1edf4d18783abc77bf521e1 [file] [log] [blame]
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "DetectionPostProcess.hpp"
7
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +01008#include <armnn/utility/Assert.hpp>
Teresa Charlin5306dc82023-10-30 22:29:58 +00009#include <armnn/utility/IgnoreUnused.hpp>
Matthew Sloyan171214c2020-09-09 09:07:37 +010010#include <armnn/utility/NumericCast.hpp>
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000011
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000012#include <algorithm>
13#include <numeric>
14
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010015namespace armnn
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000016{
17
/// Returns the identity permutation {0, 1, ..., k - 1}.
/// Callers use it as the initial index order before partially sorting by score.
std::vector<unsigned int> GenerateRangeK(unsigned int k)
{
    std::vector<unsigned int> sequence;
    sequence.reserve(k);
    for (unsigned int value = 0; value < k; ++value)
    {
        sequence.push_back(value);
    }
    return sequence;
}
24
/// Partially sorts indices[0, numElement) so that the first k entries are the indices of the
/// k largest values, in descending order of value. `values` is indexed by the entries of `indices`.
/// The relative order of equal values, and of the entries past position k, is unspecified
/// (std::partial_sort is not stable).
/// k is clamped to numElement: std::partial_sort requires middle <= last, so an oversized k would
/// otherwise be undefined behaviour.
void TopKSort(unsigned int k, unsigned int* indices, const float* values, unsigned int numElement)
{
    const unsigned int topK = std::min(k, numElement);
    std::partial_sort(indices, indices + topK, indices + numElement,
                      [values](unsigned int i, unsigned int j) { return values[i] > values[j]; });
}
30
/// Computes the intersection-over-union of two boxes in box-corner format (ymin, xmin, ymax, xmax).
/// Each argument must point at four consecutive floats.
/// Returns 0 when either box has a non-positive area. Without this guard, two degenerate boxes would
/// divide 0 by 0 and produce NaN, and inverted boxes would yield meaningless ratios.
float IntersectionOverUnion(const float* boxI, const float* boxJ)
{
    constexpr int yMin = 0;
    constexpr int xMin = 1;
    constexpr int yMax = 2;
    constexpr int xMax = 3;

    const float areaI = (boxI[yMax] - boxI[yMin]) * (boxI[xMax] - boxI[xMin]);
    const float areaJ = (boxJ[yMax] - boxJ[yMin]) * (boxJ[xMax] - boxJ[xMin]);
    if (areaI <= 0.0f || areaJ <= 0.0f)
    {
        return 0.0f;
    }

    const float yMinIntersection = std::max(boxI[yMin], boxJ[yMin]);
    const float xMinIntersection = std::max(boxI[xMin], boxJ[xMin]);
    const float yMaxIntersection = std::min(boxI[yMax], boxJ[yMax]);
    const float xMaxIntersection = std::min(boxI[xMax], boxJ[xMax]);
    // Non-overlapping boxes produce a negative extent on at least one axis; clamp it to zero.
    const float areaIntersection = std::max(yMaxIntersection - yMinIntersection, 0.0f) *
                                   std::max(xMaxIntersection - xMinIntersection, 0.0f);
    // Both areas are positive and the intersection is at most the smaller of them, so the union is > 0.
    const float areaUnion = areaI + areaJ - areaIntersection;
    return areaIntersection / areaUnion;
}
49
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010050std::vector<unsigned int> NonMaxSuppression(unsigned int numBoxes,
51 const std::vector<float>& boxCorners,
52 const std::vector<float>& scores,
53 float nmsScoreThreshold,
54 unsigned int maxDetection,
55 float nmsIouThreshold)
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000056{
57 // Select boxes that have scores above a given threshold.
58 std::vector<float> scoresAboveThreshold;
59 std::vector<unsigned int> indicesAboveThreshold;
60 for (unsigned int i = 0; i < numBoxes; ++i)
61 {
62 if (scores[i] >= nmsScoreThreshold)
63 {
64 scoresAboveThreshold.push_back(scores[i]);
65 indicesAboveThreshold.push_back(i);
66 }
67 }
68
69 // Sort the indices based on scores.
Matthew Sloyan171214c2020-09-09 09:07:37 +010070 unsigned int numAboveThreshold = armnn::numeric_cast<unsigned int>(scoresAboveThreshold.size());
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000071 std::vector<unsigned int> sortedIndices = GenerateRangeK(numAboveThreshold);
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010072 TopKSort(numAboveThreshold, sortedIndices.data(), scoresAboveThreshold.data(), numAboveThreshold);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000073
74 // Number of output cannot be more than max detections specified in the option.
75 unsigned int numOutput = std::min(maxDetection, numAboveThreshold);
76 std::vector<unsigned int> outputIndices;
77 std::vector<bool> visited(numAboveThreshold, false);
78
79 // Prune out the boxes with high intersection over union by keeping the box with higher score.
80 for (unsigned int i = 0; i < numAboveThreshold; ++i)
81 {
82 if (outputIndices.size() >= numOutput)
83 {
84 break;
85 }
86 if (!visited[sortedIndices[i]])
87 {
88 outputIndices.push_back(indicesAboveThreshold[sortedIndices[i]]);
antkillerfarmdb6e8a92020-10-15 11:02:07 +080089 for (unsigned int j = i + 1; j < numAboveThreshold; ++j)
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000090 {
antkillerfarmdb6e8a92020-10-15 11:02:07 +080091 unsigned int iIndex = indicesAboveThreshold[sortedIndices[i]] * 4;
92 unsigned int jIndex = indicesAboveThreshold[sortedIndices[j]] * 4;
93 if (IntersectionOverUnion(&boxCorners[iIndex], &boxCorners[jIndex]) > nmsIouThreshold)
94 {
95 visited[sortedIndices[j]] = true;
96 }
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000097 }
98 }
99 }
100 return outputIndices;
101}
102
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100103void AllocateOutputData(unsigned int numOutput,
104 unsigned int numSelected,
105 const std::vector<float>& boxCorners,
106 const std::vector<unsigned int>& outputIndices,
107 const std::vector<unsigned int>& selectedBoxes,
108 const std::vector<unsigned int>& selectedClasses,
109 const std::vector<float>& selectedScores,
110 float* detectionBoxes,
111 float* detectionScores,
112 float* detectionClasses,
113 float* numDetections)
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000114{
115 for (unsigned int i = 0; i < numOutput; ++i)
116 {
117 unsigned int boxIndex = i * 4;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000118 if (i < numSelected)
119 {
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +0000120 unsigned int boxCornorIndex = selectedBoxes[outputIndices[i]] * 4;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000121 detectionScores[i] = selectedScores[outputIndices[i]];
Matthew Sloyan24ac8592020-09-23 16:57:23 +0100122 detectionClasses[i] = armnn::numeric_cast<float>(selectedClasses[outputIndices[i]]);
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +0000123 detectionBoxes[boxIndex] = boxCorners[boxCornorIndex];
124 detectionBoxes[boxIndex + 1] = boxCorners[boxCornorIndex + 1];
125 detectionBoxes[boxIndex + 2] = boxCorners[boxCornorIndex + 2];
126 detectionBoxes[boxIndex + 3] = boxCorners[boxCornorIndex + 3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000127 }
128 else
129 {
130 detectionScores[i] = 0.0f;
131 detectionClasses[i] = 0.0f;
132 detectionBoxes[boxIndex] = 0.0f;
133 detectionBoxes[boxIndex + 1] = 0.0f;
134 detectionBoxes[boxIndex + 2] = 0.0f;
135 detectionBoxes[boxIndex + 3] = 0.0f;
136 }
137 }
Matthew Sloyan24ac8592020-09-23 16:57:23 +0100138 numDetections[0] = armnn::numeric_cast<float>(numSelected);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000139}
140
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000141void DetectionPostProcess(const TensorInfo& boxEncodingsInfo,
142 const TensorInfo& scoresInfo,
143 const TensorInfo& anchorsInfo,
144 const TensorInfo& detectionBoxesInfo,
145 const TensorInfo& detectionClassesInfo,
146 const TensorInfo& detectionScoresInfo,
147 const TensorInfo& numDetectionsInfo,
148 const DetectionPostProcessDescriptor& desc,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100149 Decoder<float>& boxEncodings,
150 Decoder<float>& scores,
151 Decoder<float>& anchors,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000152 float* detectionBoxes,
153 float* detectionClasses,
154 float* detectionScores,
155 float* numDetections)
156{
Jan Eilers8eb25602020-03-09 12:13:48 +0000157 IgnoreUnused(anchorsInfo, detectionClassesInfo, detectionScoresInfo, numDetectionsInfo);
Derek Lamberti901ea112019-12-10 22:07:09 +0000158
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000159 // Transform center-size format which is (ycenter, xcenter, height, width) to box-corner format,
160 // which represents the lower left corner and the upper right corner (ymin, xmin, ymax, xmax)
161 std::vector<float> boxCorners(boxEncodingsInfo.GetNumElements());
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100162
163 const unsigned int numBoxes = boxEncodingsInfo.GetShape()[1];
164 const unsigned int numScores = scoresInfo.GetNumElements();
165
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000166 for (unsigned int i = 0; i < numBoxes; ++i)
167 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100168 // Y
169 float boxEncodingY = boxEncodings.Get();
170 float anchorY = anchors.Get();
171
172 ++boxEncodings;
173 ++anchors;
174
175 // X
176 float boxEncodingX = boxEncodings.Get();
177 float anchorX = anchors.Get();
178
179 ++boxEncodings;
180 ++anchors;
181
182 // H
183 float boxEncodingH = boxEncodings.Get();
184 float anchorH = anchors.Get();
185
186 ++boxEncodings;
187 ++anchors;
188
189 // W
190 float boxEncodingW = boxEncodings.Get();
191 float anchorW = anchors.Get();
192
193 ++boxEncodings;
194 ++anchors;
195
196 float yCentre = boxEncodingY / desc.m_ScaleY * anchorH + anchorY;
197 float xCentre = boxEncodingX / desc.m_ScaleX * anchorW + anchorX;
198
199 float halfH = 0.5f * expf(boxEncodingH / desc.m_ScaleH) * anchorH;
200 float halfW = 0.5f * expf(boxEncodingW / desc.m_ScaleW) * anchorW;
201
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000202 unsigned int indexY = i * 4;
203 unsigned int indexX = indexY + 1;
204 unsigned int indexH = indexX + 1;
205 unsigned int indexW = indexH + 1;
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100206
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000207 // ymin
208 boxCorners[indexY] = yCentre - halfH;
209 // xmin
210 boxCorners[indexX] = xCentre - halfW;
211 // ymax
212 boxCorners[indexH] = yCentre + halfH;
213 // xmax
214 boxCorners[indexW] = xCentre + halfW;
215
Narumol Prangnawaratac2770a2020-04-01 16:51:23 +0100216 ARMNN_ASSERT(boxCorners[indexY] < boxCorners[indexH]);
217 ARMNN_ASSERT(boxCorners[indexX] < boxCorners[indexW]);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000218 }
219
220 unsigned int numClassesWithBg = desc.m_NumClasses + 1;
221
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100222 // Decode scores
223 std::vector<float> decodedScores;
224 decodedScores.reserve(numScores);
225
226 for (unsigned int i = 0u; i < numScores; ++i)
227 {
228 decodedScores.emplace_back(scores.Get());
229 ++scores;
230 }
231
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000232 // Perform Non Max Suppression.
233 if (desc.m_UseRegularNms)
234 {
235 // Perform Regular NMS.
236 // For each class, perform NMS and select max detection numbers of the highest score across all classes.
237 std::vector<float> classScores(numBoxes);
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100238
239 std::vector<unsigned int> selectedBoxesAfterNms;
240 selectedBoxesAfterNms.reserve(numBoxes);
241
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000242 std::vector<float> selectedScoresAfterNms;
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100243 selectedBoxesAfterNms.reserve(numScores);
244
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000245 std::vector<unsigned int> selectedClasses;
246
247 for (unsigned int c = 0; c < desc.m_NumClasses; ++c)
248 {
249 // For each boxes, get scores of the boxes for the class c.
250 for (unsigned int i = 0; i < numBoxes; ++i)
251 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100252 classScores[i] = decodedScores[i * numClassesWithBg + c + 1];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000253 }
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100254 std::vector<unsigned int> selectedIndices = NonMaxSuppression(numBoxes,
255 boxCorners,
256 classScores,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000257 desc.m_NmsScoreThreshold,
Narumol Prangnawarat4628d052019-02-25 17:26:05 +0000258 desc.m_DetectionsPerClass,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000259 desc.m_NmsIouThreshold);
260
261 for (unsigned int i = 0; i < selectedIndices.size(); ++i)
262 {
263 selectedBoxesAfterNms.push_back(selectedIndices[i]);
264 selectedScoresAfterNms.push_back(classScores[selectedIndices[i]]);
265 selectedClasses.push_back(c);
266 }
267 }
268
269 // Select max detection numbers of the highest score across all classes
Matthew Sloyan171214c2020-09-09 09:07:37 +0100270 unsigned int numSelected = armnn::numeric_cast<unsigned int>(selectedBoxesAfterNms.size());
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000271 unsigned int numOutput = std::min(desc.m_MaxDetections, numSelected);
272
273 // Sort the max scores among the selected indices.
274 std::vector<unsigned int> outputIndices = GenerateRangeK(numSelected);
275 TopKSort(numOutput, outputIndices.data(), selectedScoresAfterNms.data(), numSelected);
276
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +0000277 AllocateOutputData(detectionBoxesInfo.GetShape()[1], numOutput, boxCorners, outputIndices,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000278 selectedBoxesAfterNms, selectedClasses, selectedScoresAfterNms,
279 detectionBoxes, detectionScores, detectionClasses, numDetections);
280 }
281 else
282 {
283 // Perform Fast NMS.
284 // Select max scores of boxes and perform NMS on max scores,
285 // select max detection numbers of the highest score
286 unsigned int numClassesPerBox = std::min(desc.m_MaxClassesPerDetection, desc.m_NumClasses);
287 std::vector<float> maxScores;
288 std::vector<unsigned int>boxIndices;
289 std::vector<unsigned int>maxScoreClasses;
290
291 for (unsigned int box = 0; box < numBoxes; ++box)
292 {
293 unsigned int scoreIndex = box * numClassesWithBg + 1;
294
295 // Get the max scores of the box.
296 std::vector<unsigned int> maxScoreIndices = GenerateRangeK(desc.m_NumClasses);
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100297 TopKSort(numClassesPerBox, maxScoreIndices.data(),
298 decodedScores.data() + scoreIndex, desc.m_NumClasses);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000299
300 for (unsigned int i = 0; i < numClassesPerBox; ++i)
301 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100302 maxScores.push_back(decodedScores[scoreIndex + maxScoreIndices[i]]);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000303 maxScoreClasses.push_back(maxScoreIndices[i]);
304 boxIndices.push_back(box);
305 }
306 }
307
308 // Perform NMS on max scores
309 std::vector<unsigned int> selectedIndices = NonMaxSuppression(numBoxes, boxCorners, maxScores,
310 desc.m_NmsScoreThreshold,
311 desc.m_MaxDetections,
312 desc.m_NmsIouThreshold);
313
Matthew Sloyan171214c2020-09-09 09:07:37 +0100314 unsigned int numSelected = armnn::numeric_cast<unsigned int>(selectedIndices.size());
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000315 unsigned int numOutput = std::min(desc.m_MaxDetections, numSelected);
316
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +0000317 AllocateOutputData(detectionBoxesInfo.GetShape()[1], numOutput, boxCorners, selectedIndices,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000318 boxIndices, maxScoreClasses, maxScores,
319 detectionBoxes, detectionScores, detectionClasses, numDetections);
320 }
321}
322
323} // namespace armnn