blob: 7bcd9045a5aa072615f81c0900f07472a22171ee [file] [log] [blame]
Éanna Ó Catháin919c14e2020-09-14 17:36:49 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
#include "NonMaxSuppression.hpp"

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>
8
9namespace od
10{
11
/**
* @brief Builds the index sequence 0, 1, ..., k-1.
*
* @param[in] k Number of indices to generate.
* @return Vector of size k whose element i holds the value i.
*/
static std::vector<unsigned int> GenerateRangeK(unsigned int k)
{
    std::vector<unsigned int> range(k);
    // Seed with 0u so std::iota's running counter is an unsigned int, matching the element
    // type, instead of a signed int that would overflow (UB) for k > INT_MAX.
    std::iota(range.begin(), range.end(), 0u);
    return range;
}
18
19
20/**
21* @brief Returns the intersection over union for two bounding boxes
22*
23* @param[in] First detect containing bounding box.
24* @param[in] Second detect containing bounding box.
25* @return Calculated intersection over union.
26*
27*/
28static double IntersectionOverUnion(DetectedObject& detect1, DetectedObject& detect2)
29{
30 uint32_t area1 = (detect1.GetBoundingBox().GetHeight() * detect1.GetBoundingBox().GetWidth());
31 uint32_t area2 = (detect2.GetBoundingBox().GetHeight() * detect2.GetBoundingBox().GetWidth());
32
33 float yMinIntersection = std::max(detect1.GetBoundingBox().GetY(), detect2.GetBoundingBox().GetY());
34 float xMinIntersection = std::max(detect1.GetBoundingBox().GetX(), detect2.GetBoundingBox().GetX());
35
36 float yMaxIntersection = std::min(detect1.GetBoundingBox().GetY() + detect1.GetBoundingBox().GetHeight(),
37 detect2.GetBoundingBox().GetY() + detect2.GetBoundingBox().GetHeight());
38 float xMaxIntersection = std::min(detect1.GetBoundingBox().GetX() + detect1.GetBoundingBox().GetWidth(),
39 detect2.GetBoundingBox().GetX() + detect2.GetBoundingBox().GetWidth());
40
41 double areaIntersection = std::max(yMaxIntersection - yMinIntersection, 0.0f) *
42 std::max(xMaxIntersection - xMinIntersection, 0.0f);
43 double areaUnion = area1 + area2 - areaIntersection;
44
45 return areaIntersection / areaUnion;
46}
47
48std::vector<int> NonMaxSuppression(DetectedObjects& inputDetections, float iouThresh)
49{
50 // Sort indicies of detections by highest score to lowest.
51 std::vector<unsigned int> sortedIndicies = GenerateRangeK(inputDetections.size());
52 std::sort(sortedIndicies.begin(), sortedIndicies.end(),
53 [&inputDetections](int idx1, int idx2)
54 {
55 return inputDetections[idx1].GetScore() > inputDetections[idx2].GetScore();
56 });
57
58 std::vector<bool> visited(inputDetections.size(), false);
59 std::vector<int> outputIndiciesAfterNMS;
60
61 for (int i=0; i < inputDetections.size(); ++i)
62 {
63 // Each new unvisited detect should be kept.
64 if (!visited[sortedIndicies[i]])
65 {
66 outputIndiciesAfterNMS.emplace_back(sortedIndicies[i]);
67 visited[sortedIndicies[i]] = true;
68 }
69
70 // Look for detections to suppress.
71 for (int j=i+1; j<inputDetections.size(); ++j)
72 {
73 // Skip if already kept or suppressed.
74 if (!visited[sortedIndicies[j]])
75 {
76 // Detects must have the same label to be suppressed.
77 if (inputDetections[sortedIndicies[j]].GetLabel() == inputDetections[sortedIndicies[i]].GetLabel())
78 {
79 auto iou = IntersectionOverUnion(inputDetections[sortedIndicies[i]],
80 inputDetections[sortedIndicies[j]]);
81 if (iou > iouThresh)
82 {
83 visited[sortedIndicies[j]] = true;
84 }
85 }
86 }
87 }
88 }
89 return outputIndiciesAfterNMS;
90}
91
92} // namespace od