blob: d067f1a004b394ec208137bc09ee6bb738c072f0 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6
7#include "NMS.hpp"
8
Ryan OShea74af0932020-08-07 16:27:34 +01009#include <cmath>
Derek Lambertid6cb30e2020-04-28 13:31:29 +010010#include <algorithm>
11#include <cstddef>
12#include <numeric>
13#include <ostream>
14
15namespace yolov3 {
16namespace {
/** How many consecutive values describe one box */
constexpr int box_elements{4};
/** How many consecutive values describe one confidence factor */
constexpr int confidence_elements{1};
21
22/** Calculate Intersection Over Union of two boxes
23 *
24 * @param[in] box1 First box
25 * @param[in] box2 Second box
26 *
27 * @return The IoU of the two boxes
28 */
29float iou(const Box& box1, const Box& box2)
30{
31 const float area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin);
32 const float area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin);
33 float overlap;
34 if (area1 <= 0 || area2 <= 0)
35 {
36 overlap = 0.0f;
37 }
38 else
39 {
40 const auto y_min_intersection = std::max<float>(box1.ymin, box2.ymin);
41 const auto x_min_intersection = std::max<float>(box1.xmin, box2.xmin);
42 const auto y_max_intersection = std::min<float>(box1.ymax, box2.ymax);
43 const auto x_max_intersection = std::min<float>(box1.xmax, box2.xmax);
44 const auto area_intersection =
45 std::max<float>(y_max_intersection - y_min_intersection, 0.0f) *
46 std::max<float>(x_max_intersection - x_min_intersection, 0.0f);
47 overlap = area_intersection / (area1 + area2 - area_intersection);
48 }
49 return overlap;
50}
51
52std::vector<Detection> convert_to_detections(const NMSConfig& config,
53 const std::vector<float>& detected_boxes)
54{
55 const size_t element_step = static_cast<size_t>(
56 box_elements + confidence_elements + config.num_classes);
57 std::vector<Detection> detections;
58
59 for (unsigned int i = 0; i < config.num_boxes; ++i)
60 {
61 const float* cur_box = &detected_boxes[i * element_step];
62 if (cur_box[4] > config.confidence_threshold)
63 {
64 Detection det;
65 det.box = {cur_box[0], cur_box[0] + cur_box[2], cur_box[1],
66 cur_box[1] + cur_box[3]};
67 det.confidence = cur_box[4];
68 det.classes.resize(static_cast<size_t>(config.num_classes), 0);
69 for (unsigned int c = 0; c < config.num_classes; ++c)
70 {
71 const float class_prob = det.confidence * cur_box[5 + c];
72 if (class_prob > config.confidence_threshold)
73 {
74 det.classes[c] = class_prob;
75 }
76 }
77 detections.emplace_back(std::move(det));
78 }
79 }
80 return detections;
81}
82} // namespace
83
Ryan OShea74af0932020-08-07 16:27:34 +010084bool compare_detection(const yolov3::Detection& detection,
85 const std::vector<float>& expected)
86{
87 float tolerance = 0.001f;
88 return (std::fabs(detection.classes[0] - expected[0]) < tolerance &&
89 std::fabs(detection.box.xmin - expected[1]) < tolerance &&
90 std::fabs(detection.box.ymin - expected[2]) < tolerance &&
91 std::fabs(detection.box.xmax - expected[3]) < tolerance &&
92 std::fabs(detection.box.ymax - expected[4]) < tolerance &&
93 std::fabs(detection.confidence - expected[5]) < tolerance );
94}
95
Derek Lambertid6cb30e2020-04-28 13:31:29 +010096void print_detection(std::ostream& os,
97 const std::vector<Detection>& detections)
98{
99 for (const auto& detection : detections)
100 {
101 for (unsigned int c = 0; c < detection.classes.size(); ++c)
102 {
103 if (detection.classes[c] != 0.0f)
104 {
105 os << c << " " << detection.classes[c] << " " << detection.box.xmin
106 << " " << detection.box.ymin << " " << detection.box.xmax << " "
107 << detection.box.ymax << std::endl;
108 }
109 }
110 }
111}
112
113std::vector<Detection> nms(const NMSConfig& config,
114 const std::vector<float>& detected_boxes) {
115 // Get detections that comply with the expected confidence threshold
116 std::vector<Detection> detections =
117 convert_to_detections(config, detected_boxes);
118
119 const unsigned int num_detections = static_cast<unsigned int>(detections.size());
120 for (unsigned int c = 0; c < config.num_classes; ++c)
121 {
122 // Sort classes
123 std::sort(detections.begin(), detections.begin() + static_cast<std::ptrdiff_t>(num_detections),
124 [c](Detection& detection1, Detection& detection2)
125 {
126 return (detection1.classes[c] - detection2.classes[c]) > 0;
127 });
128 // Clear detections with high IoU
129 for (unsigned int d = 0; d < num_detections; ++d)
130 {
131 // Check if class is already cleared/invalidated
132 if (detections[d].classes[c] == 0.f)
133 {
134 continue;
135 }
136
137 // Filter out boxes on IoU threshold
138 const Box& box1 = detections[d].box;
139 for (unsigned int b = d + 1; b < num_detections; ++b)
140 {
141 const Box& box2 = detections[b].box;
142 if (iou(box1, box2) > config.iou_threshold)
143 {
144 detections[b].classes[c] = 0.f;
145 }
146 }
147 }
148 }
149 return detections;
150}
151} // namespace yolov3