blob: 3ef840f8750f0c38494175e996fff9e7c9038e47 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6
7#include "NMS.hpp"
8
9#include <algorithm>
10#include <cstddef>
11#include <numeric>
12#include <ostream>
13
14namespace yolov3 {
15namespace {
/** Count of floats describing the geometry of one box in the raw output */
constexpr int box_elements{4};
/** Count of floats holding the confidence value of one box in the raw output */
constexpr int confidence_elements{1};
20
21/** Calculate Intersection Over Union of two boxes
22 *
23 * @param[in] box1 First box
24 * @param[in] box2 Second box
25 *
26 * @return The IoU of the two boxes
27 */
28float iou(const Box& box1, const Box& box2)
29{
30 const float area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin);
31 const float area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin);
32 float overlap;
33 if (area1 <= 0 || area2 <= 0)
34 {
35 overlap = 0.0f;
36 }
37 else
38 {
39 const auto y_min_intersection = std::max<float>(box1.ymin, box2.ymin);
40 const auto x_min_intersection = std::max<float>(box1.xmin, box2.xmin);
41 const auto y_max_intersection = std::min<float>(box1.ymax, box2.ymax);
42 const auto x_max_intersection = std::min<float>(box1.xmax, box2.xmax);
43 const auto area_intersection =
44 std::max<float>(y_max_intersection - y_min_intersection, 0.0f) *
45 std::max<float>(x_max_intersection - x_min_intersection, 0.0f);
46 overlap = area_intersection / (area1 + area2 - area_intersection);
47 }
48 return overlap;
49}
50
51std::vector<Detection> convert_to_detections(const NMSConfig& config,
52 const std::vector<float>& detected_boxes)
53{
54 const size_t element_step = static_cast<size_t>(
55 box_elements + confidence_elements + config.num_classes);
56 std::vector<Detection> detections;
57
58 for (unsigned int i = 0; i < config.num_boxes; ++i)
59 {
60 const float* cur_box = &detected_boxes[i * element_step];
61 if (cur_box[4] > config.confidence_threshold)
62 {
63 Detection det;
64 det.box = {cur_box[0], cur_box[0] + cur_box[2], cur_box[1],
65 cur_box[1] + cur_box[3]};
66 det.confidence = cur_box[4];
67 det.classes.resize(static_cast<size_t>(config.num_classes), 0);
68 for (unsigned int c = 0; c < config.num_classes; ++c)
69 {
70 const float class_prob = det.confidence * cur_box[5 + c];
71 if (class_prob > config.confidence_threshold)
72 {
73 det.classes[c] = class_prob;
74 }
75 }
76 detections.emplace_back(std::move(det));
77 }
78 }
79 return detections;
80}
81} // namespace
82
83void print_detection(std::ostream& os,
84 const std::vector<Detection>& detections)
85{
86 for (const auto& detection : detections)
87 {
88 for (unsigned int c = 0; c < detection.classes.size(); ++c)
89 {
90 if (detection.classes[c] != 0.0f)
91 {
92 os << c << " " << detection.classes[c] << " " << detection.box.xmin
93 << " " << detection.box.ymin << " " << detection.box.xmax << " "
94 << detection.box.ymax << std::endl;
95 }
96 }
97 }
98}
99
100std::vector<Detection> nms(const NMSConfig& config,
101 const std::vector<float>& detected_boxes) {
102 // Get detections that comply with the expected confidence threshold
103 std::vector<Detection> detections =
104 convert_to_detections(config, detected_boxes);
105
106 const unsigned int num_detections = static_cast<unsigned int>(detections.size());
107 for (unsigned int c = 0; c < config.num_classes; ++c)
108 {
109 // Sort classes
110 std::sort(detections.begin(), detections.begin() + static_cast<std::ptrdiff_t>(num_detections),
111 [c](Detection& detection1, Detection& detection2)
112 {
113 return (detection1.classes[c] - detection2.classes[c]) > 0;
114 });
115 // Clear detections with high IoU
116 for (unsigned int d = 0; d < num_detections; ++d)
117 {
118 // Check if class is already cleared/invalidated
119 if (detections[d].classes[c] == 0.f)
120 {
121 continue;
122 }
123
124 // Filter out boxes on IoU threshold
125 const Box& box1 = detections[d].box;
126 for (unsigned int b = d + 1; b < num_detections; ++b)
127 {
128 const Box& box2 = detections[b].box;
129 if (iou(box1, box2) > config.iou_threshold)
130 {
131 detections[b].classes[c] = 0.f;
132 }
133 }
134 }
135 }
136 return detections;
137}
138} // namespace yolov3