Derek Lamberti | d6cb30e | 2020-04-28 13:31:29 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2020 Arm Ltd. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | |
| 7 | #include "NMS.hpp" |
| 8 | |
Ryan OShea | 74af093 | 2020-08-07 16:27:34 +0100 | [diff] [blame] | 9 | #include <cmath> |
Derek Lamberti | d6cb30e | 2020-04-28 13:31:29 +0100 | [diff] [blame] | 10 | #include <algorithm> |
| 11 | #include <cstddef> |
| 12 | #include <numeric> |
| 13 | #include <ostream> |
| 14 | |
| 15 | namespace yolov3 { |
| 16 | namespace { |
/** Number of elements needed to represent a box.
 *  convert_to_detections() reads them as the first four floats of each record (x, y, width, height). */
constexpr int box_elements = 4;
/** Number of elements needed to represent a confidence factor.
 *  convert_to_detections() reads it as the single float that follows the box elements. */
constexpr int confidence_elements = 1;
| 21 | |
| 22 | /** Calculate Intersection Over Union of two boxes |
| 23 | * |
| 24 | * @param[in] box1 First box |
| 25 | * @param[in] box2 Second box |
| 26 | * |
| 27 | * @return The IoU of the two boxes |
| 28 | */ |
| 29 | float iou(const Box& box1, const Box& box2) |
| 30 | { |
| 31 | const float area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin); |
| 32 | const float area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin); |
| 33 | float overlap; |
| 34 | if (area1 <= 0 || area2 <= 0) |
| 35 | { |
| 36 | overlap = 0.0f; |
| 37 | } |
| 38 | else |
| 39 | { |
| 40 | const auto y_min_intersection = std::max<float>(box1.ymin, box2.ymin); |
| 41 | const auto x_min_intersection = std::max<float>(box1.xmin, box2.xmin); |
| 42 | const auto y_max_intersection = std::min<float>(box1.ymax, box2.ymax); |
| 43 | const auto x_max_intersection = std::min<float>(box1.xmax, box2.xmax); |
| 44 | const auto area_intersection = |
| 45 | std::max<float>(y_max_intersection - y_min_intersection, 0.0f) * |
| 46 | std::max<float>(x_max_intersection - x_min_intersection, 0.0f); |
| 47 | overlap = area_intersection / (area1 + area2 - area_intersection); |
| 48 | } |
| 49 | return overlap; |
| 50 | } |
| 51 | |
| 52 | std::vector<Detection> convert_to_detections(const NMSConfig& config, |
| 53 | const std::vector<float>& detected_boxes) |
| 54 | { |
| 55 | const size_t element_step = static_cast<size_t>( |
| 56 | box_elements + confidence_elements + config.num_classes); |
| 57 | std::vector<Detection> detections; |
| 58 | |
| 59 | for (unsigned int i = 0; i < config.num_boxes; ++i) |
| 60 | { |
| 61 | const float* cur_box = &detected_boxes[i * element_step]; |
| 62 | if (cur_box[4] > config.confidence_threshold) |
| 63 | { |
| 64 | Detection det; |
| 65 | det.box = {cur_box[0], cur_box[0] + cur_box[2], cur_box[1], |
| 66 | cur_box[1] + cur_box[3]}; |
| 67 | det.confidence = cur_box[4]; |
| 68 | det.classes.resize(static_cast<size_t>(config.num_classes), 0); |
| 69 | for (unsigned int c = 0; c < config.num_classes; ++c) |
| 70 | { |
| 71 | const float class_prob = det.confidence * cur_box[5 + c]; |
| 72 | if (class_prob > config.confidence_threshold) |
| 73 | { |
| 74 | det.classes[c] = class_prob; |
| 75 | } |
| 76 | } |
| 77 | detections.emplace_back(std::move(det)); |
| 78 | } |
| 79 | } |
| 80 | return detections; |
| 81 | } |
| 82 | } // namespace |
| 83 | |
Ryan OShea | 74af093 | 2020-08-07 16:27:34 +0100 | [diff] [blame] | 84 | bool compare_detection(const yolov3::Detection& detection, |
| 85 | const std::vector<float>& expected) |
| 86 | { |
| 87 | float tolerance = 0.001f; |
| 88 | return (std::fabs(detection.classes[0] - expected[0]) < tolerance && |
| 89 | std::fabs(detection.box.xmin - expected[1]) < tolerance && |
| 90 | std::fabs(detection.box.ymin - expected[2]) < tolerance && |
| 91 | std::fabs(detection.box.xmax - expected[3]) < tolerance && |
| 92 | std::fabs(detection.box.ymax - expected[4]) < tolerance && |
| 93 | std::fabs(detection.confidence - expected[5]) < tolerance ); |
| 94 | } |
| 95 | |
Derek Lamberti | d6cb30e | 2020-04-28 13:31:29 +0100 | [diff] [blame] | 96 | void print_detection(std::ostream& os, |
| 97 | const std::vector<Detection>& detections) |
| 98 | { |
| 99 | for (const auto& detection : detections) |
| 100 | { |
| 101 | for (unsigned int c = 0; c < detection.classes.size(); ++c) |
| 102 | { |
| 103 | if (detection.classes[c] != 0.0f) |
| 104 | { |
| 105 | os << c << " " << detection.classes[c] << " " << detection.box.xmin |
| 106 | << " " << detection.box.ymin << " " << detection.box.xmax << " " |
| 107 | << detection.box.ymax << std::endl; |
| 108 | } |
| 109 | } |
| 110 | } |
| 111 | } |
| 112 | |
| 113 | std::vector<Detection> nms(const NMSConfig& config, |
| 114 | const std::vector<float>& detected_boxes) { |
| 115 | // Get detections that comply with the expected confidence threshold |
| 116 | std::vector<Detection> detections = |
| 117 | convert_to_detections(config, detected_boxes); |
| 118 | |
| 119 | const unsigned int num_detections = static_cast<unsigned int>(detections.size()); |
| 120 | for (unsigned int c = 0; c < config.num_classes; ++c) |
| 121 | { |
| 122 | // Sort classes |
| 123 | std::sort(detections.begin(), detections.begin() + static_cast<std::ptrdiff_t>(num_detections), |
| 124 | [c](Detection& detection1, Detection& detection2) |
| 125 | { |
| 126 | return (detection1.classes[c] - detection2.classes[c]) > 0; |
| 127 | }); |
| 128 | // Clear detections with high IoU |
| 129 | for (unsigned int d = 0; d < num_detections; ++d) |
| 130 | { |
| 131 | // Check if class is already cleared/invalidated |
| 132 | if (detections[d].classes[c] == 0.f) |
| 133 | { |
| 134 | continue; |
| 135 | } |
| 136 | |
| 137 | // Filter out boxes on IoU threshold |
| 138 | const Box& box1 = detections[d].box; |
| 139 | for (unsigned int b = d + 1; b < num_detections; ++b) |
| 140 | { |
| 141 | const Box& box2 = detections[b].box; |
| 142 | if (iou(box1, box2) > config.iou_threshold) |
| 143 | { |
| 144 | detections[b].classes[c] = 0.f; |
| 145 | } |
| 146 | } |
| 147 | } |
| 148 | } |
| 149 | return detections; |
| 150 | } |
| 151 | } // namespace yolov3 |