blob: 361f8865bee154eef0e0f101948de6738043cd05 [file] [log] [blame]
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00001//
Colm Donelanb4ef1632024-02-01 15:00:43 +00002// Copyright © 2017, 2024 Arm Ltd. All rights reserved.
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00003// SPDX-License-Identifier: MIT
4//
5
6#include "DetectionPostProcess.hpp"
7
Matthew Sloyan171214c2020-09-09 09:07:37 +01008#include <armnn/utility/NumericCast.hpp>
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +00009
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000010#include <algorithm>
11#include <numeric>
12
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010013namespace armnn
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000014{
15
// Builds the identity index sequence 0, 1, ..., k - 1.
//
// k: number of indices to generate.
// Returns a vector holding k consecutive unsigned integers starting at zero
// (empty when k is zero).
std::vector<unsigned int> GenerateRangeK(unsigned int k)
{
    std::vector<unsigned int> indices;
    indices.reserve(k);
    for (unsigned int value = 0; value < k; ++value)
    {
        indices.push_back(value);
    }
    return indices;
}
22
// Partially sorts an index array so that its first k entries refer to the k largest
// values, in descending order of value.
//
// k:          number of leading entries of 'indices' to put in sorted order.
// indices:    array of numElement indices into 'values'; reordered in place.
// values:     scores looked up through the indices; never modified.
// numElement: total number of entries in 'indices'.
//
// The order of the entries after the first k is whatever std::partial_sort leaves behind.
void TopKSort(unsigned int k, unsigned int* indices, const float* values, unsigned int numElement)
{
    // Strict "greater than" on the referenced values gives a descending ordering.
    const auto refersToLargerValue = [&values](unsigned int lhs, unsigned int rhs)
    {
        return values[lhs] > values[rhs];
    };

    unsigned int* const first  = indices;
    unsigned int* const middle = indices + k;
    unsigned int* const last   = indices + numElement;
    std::partial_sort(first, middle, last, refersToLargerValue);
}
28
// Computes the intersection-over-union (IoU) of two axis-aligned boxes.
//
// boxI, boxJ: pointers to four floats each, in box-corner format (ymin, xmin, ymax, xmax).
// Returns intersection area divided by union area. When either box has a non-positive
// area the boxes cannot overlap, so 0 is returned; this also avoids the 0/0 division
// (NaN) that two degenerate boxes would otherwise produce.
float IntersectionOverUnion(const float* boxI, const float* boxJ)
{
    // Box-corner format: ymin, xmin, ymax, xmax.
    const int yMin = 0;
    const int xMin = 1;
    const int yMax = 2;
    const int xMax = 3;

    const float areaI = (boxI[yMax] - boxI[yMin]) * (boxI[xMax] - boxI[xMin]);
    const float areaJ = (boxJ[yMax] - boxJ[yMin]) * (boxJ[xMax] - boxJ[xMin]);
    if (areaI <= 0.0f || areaJ <= 0.0f)
    {
        // Degenerate (zero or negative area) box: nothing to intersect with.
        return 0.0f;
    }

    const float yMinIntersection = std::max(boxI[yMin], boxJ[yMin]);
    const float xMinIntersection = std::max(boxI[xMin], boxJ[xMin]);
    const float yMaxIntersection = std::min(boxI[yMax], boxJ[yMax]);
    const float xMaxIntersection = std::min(boxI[xMax], boxJ[xMax]);

    // Clamp each extent at zero so that disjoint boxes give an empty intersection.
    const float areaIntersection = std::max(yMaxIntersection - yMinIntersection, 0.0f) *
                                   std::max(xMaxIntersection - xMinIntersection, 0.0f);
    const float areaUnion = areaI + areaJ - areaIntersection;
    return areaIntersection / areaUnion;
}
47
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010048std::vector<unsigned int> NonMaxSuppression(unsigned int numBoxes,
49 const std::vector<float>& boxCorners,
50 const std::vector<float>& scores,
51 float nmsScoreThreshold,
52 unsigned int maxDetection,
53 float nmsIouThreshold)
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000054{
55 // Select boxes that have scores above a given threshold.
56 std::vector<float> scoresAboveThreshold;
57 std::vector<unsigned int> indicesAboveThreshold;
58 for (unsigned int i = 0; i < numBoxes; ++i)
59 {
60 if (scores[i] >= nmsScoreThreshold)
61 {
62 scoresAboveThreshold.push_back(scores[i]);
63 indicesAboveThreshold.push_back(i);
64 }
65 }
66
67 // Sort the indices based on scores.
Matthew Sloyan171214c2020-09-09 09:07:37 +010068 unsigned int numAboveThreshold = armnn::numeric_cast<unsigned int>(scoresAboveThreshold.size());
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000069 std::vector<unsigned int> sortedIndices = GenerateRangeK(numAboveThreshold);
Aron Virginas-Tar6331f912019-06-03 17:10:02 +010070 TopKSort(numAboveThreshold, sortedIndices.data(), scoresAboveThreshold.data(), numAboveThreshold);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000071
72 // Number of output cannot be more than max detections specified in the option.
73 unsigned int numOutput = std::min(maxDetection, numAboveThreshold);
74 std::vector<unsigned int> outputIndices;
75 std::vector<bool> visited(numAboveThreshold, false);
76
77 // Prune out the boxes with high intersection over union by keeping the box with higher score.
78 for (unsigned int i = 0; i < numAboveThreshold; ++i)
79 {
80 if (outputIndices.size() >= numOutput)
81 {
82 break;
83 }
84 if (!visited[sortedIndices[i]])
85 {
86 outputIndices.push_back(indicesAboveThreshold[sortedIndices[i]]);
antkillerfarmdb6e8a92020-10-15 11:02:07 +080087 for (unsigned int j = i + 1; j < numAboveThreshold; ++j)
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000088 {
antkillerfarmdb6e8a92020-10-15 11:02:07 +080089 unsigned int iIndex = indicesAboveThreshold[sortedIndices[i]] * 4;
90 unsigned int jIndex = indicesAboveThreshold[sortedIndices[j]] * 4;
91 if (IntersectionOverUnion(&boxCorners[iIndex], &boxCorners[jIndex]) > nmsIouThreshold)
92 {
93 visited[sortedIndices[j]] = true;
94 }
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +000095 }
96 }
97 }
98 return outputIndices;
99}
100
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100101void AllocateOutputData(unsigned int numOutput,
102 unsigned int numSelected,
103 const std::vector<float>& boxCorners,
104 const std::vector<unsigned int>& outputIndices,
105 const std::vector<unsigned int>& selectedBoxes,
106 const std::vector<unsigned int>& selectedClasses,
107 const std::vector<float>& selectedScores,
108 float* detectionBoxes,
109 float* detectionScores,
110 float* detectionClasses,
111 float* numDetections)
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000112{
113 for (unsigned int i = 0; i < numOutput; ++i)
114 {
115 unsigned int boxIndex = i * 4;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000116 if (i < numSelected)
117 {
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +0000118 unsigned int boxCornorIndex = selectedBoxes[outputIndices[i]] * 4;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000119 detectionScores[i] = selectedScores[outputIndices[i]];
Matthew Sloyan24ac8592020-09-23 16:57:23 +0100120 detectionClasses[i] = armnn::numeric_cast<float>(selectedClasses[outputIndices[i]]);
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +0000121 detectionBoxes[boxIndex] = boxCorners[boxCornorIndex];
122 detectionBoxes[boxIndex + 1] = boxCorners[boxCornorIndex + 1];
123 detectionBoxes[boxIndex + 2] = boxCorners[boxCornorIndex + 2];
124 detectionBoxes[boxIndex + 3] = boxCorners[boxCornorIndex + 3];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000125 }
126 else
127 {
128 detectionScores[i] = 0.0f;
129 detectionClasses[i] = 0.0f;
130 detectionBoxes[boxIndex] = 0.0f;
131 detectionBoxes[boxIndex + 1] = 0.0f;
132 detectionBoxes[boxIndex + 2] = 0.0f;
133 detectionBoxes[boxIndex + 3] = 0.0f;
134 }
135 }
Matthew Sloyan24ac8592020-09-23 16:57:23 +0100136 numDetections[0] = armnn::numeric_cast<float>(numSelected);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000137}
138
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000139void DetectionPostProcess(const TensorInfo& boxEncodingsInfo,
140 const TensorInfo& scoresInfo,
Colm Donelanb4ef1632024-02-01 15:00:43 +0000141 const TensorInfo&,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000142 const TensorInfo& detectionBoxesInfo,
Colm Donelanb4ef1632024-02-01 15:00:43 +0000143 const TensorInfo&,
144 const TensorInfo&,
145 const TensorInfo&,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000146 const DetectionPostProcessDescriptor& desc,
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100147 Decoder<float>& boxEncodings,
148 Decoder<float>& scores,
149 Decoder<float>& anchors,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000150 float* detectionBoxes,
151 float* detectionClasses,
152 float* detectionScores,
153 float* numDetections)
154{
Derek Lamberti901ea112019-12-10 22:07:09 +0000155
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000156 // Transform center-size format which is (ycenter, xcenter, height, width) to box-corner format,
157 // which represents the lower left corner and the upper right corner (ymin, xmin, ymax, xmax)
158 std::vector<float> boxCorners(boxEncodingsInfo.GetNumElements());
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100159
160 const unsigned int numBoxes = boxEncodingsInfo.GetShape()[1];
161 const unsigned int numScores = scoresInfo.GetNumElements();
162
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000163 for (unsigned int i = 0; i < numBoxes; ++i)
164 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100165 // Y
166 float boxEncodingY = boxEncodings.Get();
167 float anchorY = anchors.Get();
168
169 ++boxEncodings;
170 ++anchors;
171
172 // X
173 float boxEncodingX = boxEncodings.Get();
174 float anchorX = anchors.Get();
175
176 ++boxEncodings;
177 ++anchors;
178
179 // H
180 float boxEncodingH = boxEncodings.Get();
181 float anchorH = anchors.Get();
182
183 ++boxEncodings;
184 ++anchors;
185
186 // W
187 float boxEncodingW = boxEncodings.Get();
188 float anchorW = anchors.Get();
189
190 ++boxEncodings;
191 ++anchors;
192
193 float yCentre = boxEncodingY / desc.m_ScaleY * anchorH + anchorY;
194 float xCentre = boxEncodingX / desc.m_ScaleX * anchorW + anchorX;
195
196 float halfH = 0.5f * expf(boxEncodingH / desc.m_ScaleH) * anchorH;
197 float halfW = 0.5f * expf(boxEncodingW / desc.m_ScaleW) * anchorW;
198
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000199 unsigned int indexY = i * 4;
200 unsigned int indexX = indexY + 1;
201 unsigned int indexH = indexX + 1;
202 unsigned int indexW = indexH + 1;
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100203
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000204 // ymin
205 boxCorners[indexY] = yCentre - halfH;
206 // xmin
207 boxCorners[indexX] = xCentre - halfW;
208 // ymax
209 boxCorners[indexH] = yCentre + halfH;
210 // xmax
211 boxCorners[indexW] = xCentre + halfW;
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000212 }
213
214 unsigned int numClassesWithBg = desc.m_NumClasses + 1;
215
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100216 // Decode scores
217 std::vector<float> decodedScores;
218 decodedScores.reserve(numScores);
219
220 for (unsigned int i = 0u; i < numScores; ++i)
221 {
222 decodedScores.emplace_back(scores.Get());
223 ++scores;
224 }
225
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000226 // Perform Non Max Suppression.
227 if (desc.m_UseRegularNms)
228 {
229 // Perform Regular NMS.
230 // For each class, perform NMS and select max detection numbers of the highest score across all classes.
231 std::vector<float> classScores(numBoxes);
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100232
233 std::vector<unsigned int> selectedBoxesAfterNms;
234 selectedBoxesAfterNms.reserve(numBoxes);
235
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000236 std::vector<float> selectedScoresAfterNms;
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100237 selectedBoxesAfterNms.reserve(numScores);
238
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000239 std::vector<unsigned int> selectedClasses;
240
241 for (unsigned int c = 0; c < desc.m_NumClasses; ++c)
242 {
243 // For each boxes, get scores of the boxes for the class c.
244 for (unsigned int i = 0; i < numBoxes; ++i)
245 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100246 classScores[i] = decodedScores[i * numClassesWithBg + c + 1];
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000247 }
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100248 std::vector<unsigned int> selectedIndices = NonMaxSuppression(numBoxes,
249 boxCorners,
250 classScores,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000251 desc.m_NmsScoreThreshold,
Narumol Prangnawarat4628d052019-02-25 17:26:05 +0000252 desc.m_DetectionsPerClass,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000253 desc.m_NmsIouThreshold);
254
255 for (unsigned int i = 0; i < selectedIndices.size(); ++i)
256 {
257 selectedBoxesAfterNms.push_back(selectedIndices[i]);
258 selectedScoresAfterNms.push_back(classScores[selectedIndices[i]]);
259 selectedClasses.push_back(c);
260 }
261 }
262
263 // Select max detection numbers of the highest score across all classes
Matthew Sloyan171214c2020-09-09 09:07:37 +0100264 unsigned int numSelected = armnn::numeric_cast<unsigned int>(selectedBoxesAfterNms.size());
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000265 unsigned int numOutput = std::min(desc.m_MaxDetections, numSelected);
266
267 // Sort the max scores among the selected indices.
268 std::vector<unsigned int> outputIndices = GenerateRangeK(numSelected);
269 TopKSort(numOutput, outputIndices.data(), selectedScoresAfterNms.data(), numSelected);
270
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +0000271 AllocateOutputData(detectionBoxesInfo.GetShape()[1], numOutput, boxCorners, outputIndices,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000272 selectedBoxesAfterNms, selectedClasses, selectedScoresAfterNms,
273 detectionBoxes, detectionScores, detectionClasses, numDetections);
274 }
275 else
276 {
277 // Perform Fast NMS.
278 // Select max scores of boxes and perform NMS on max scores,
279 // select max detection numbers of the highest score
280 unsigned int numClassesPerBox = std::min(desc.m_MaxClassesPerDetection, desc.m_NumClasses);
281 std::vector<float> maxScores;
282 std::vector<unsigned int>boxIndices;
283 std::vector<unsigned int>maxScoreClasses;
284
285 for (unsigned int box = 0; box < numBoxes; ++box)
286 {
287 unsigned int scoreIndex = box * numClassesWithBg + 1;
288
289 // Get the max scores of the box.
290 std::vector<unsigned int> maxScoreIndices = GenerateRangeK(desc.m_NumClasses);
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100291 TopKSort(numClassesPerBox, maxScoreIndices.data(),
292 decodedScores.data() + scoreIndex, desc.m_NumClasses);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000293
294 for (unsigned int i = 0; i < numClassesPerBox; ++i)
295 {
Aron Virginas-Tar6331f912019-06-03 17:10:02 +0100296 maxScores.push_back(decodedScores[scoreIndex + maxScoreIndices[i]]);
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000297 maxScoreClasses.push_back(maxScoreIndices[i]);
298 boxIndices.push_back(box);
299 }
300 }
301
302 // Perform NMS on max scores
303 std::vector<unsigned int> selectedIndices = NonMaxSuppression(numBoxes, boxCorners, maxScores,
304 desc.m_NmsScoreThreshold,
305 desc.m_MaxDetections,
306 desc.m_NmsIouThreshold);
307
Matthew Sloyan171214c2020-09-09 09:07:37 +0100308 unsigned int numSelected = armnn::numeric_cast<unsigned int>(selectedIndices.size());
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000309 unsigned int numOutput = std::min(desc.m_MaxDetections, numSelected);
310
Narumol Prangnawarat6d302bf2019-02-04 11:46:26 +0000311 AllocateOutputData(detectionBoxesInfo.GetShape()[1], numOutput, boxCorners, selectedIndices,
Narumol Prangnawaratbc67cef2019-01-31 15:31:54 +0000312 boxIndices, maxScoreClasses, maxScores,
313 detectionBoxes, detectionScores, detectionClasses, numDetections);
314 }
315}
316
317} // namespace armnn