blob: f5e3cf38af2455c86107e4f44009ec5a2db622f1 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#pragma once
7
8#include <ostream>
9#include <vector>
10
11namespace yolov3 {
/** Non-Maxima Suppression (NMS) configuration meta-data
 *
 * Every member carries a default, so a default-constructed NMSConfig is fully initialised.
 */
struct NMSConfig {
    unsigned int num_classes{0};             /**< Number of classes in the detected boxes */
    unsigned int num_boxes{0};               /**< Number of detected boxes */
    float        confidence_threshold{0.8f}; /**< Inclusion confidence threshold for a box */
    float        iou_threshold{0.8f};        /**< Inclusion threshold for Intersection-Over-Union */
};
19
/** Box representation structure
 *
 * Members are zero-initialised so that a default-constructed Box never holds
 * indeterminate values. Member order (xmin, xmax, ymin, ymax) is kept so that
 * positional aggregate initialisation keeps its meaning.
 */
struct Box {
    float xmin{0.0f}; /**< X position of the lower-left coordinate */
    float xmax{0.0f}; /**< X position of the top-right coordinate */
    float ymin{0.0f}; /**< Y position of the lower-left coordinate */
    float ymax{0.0f}; /**< Y position of the top-right coordinate */
};
27
28/** Detection structure */
29struct Detection {
30 Box box; /**< Detection box */
31 float confidence; /**< Confidence of detection */
32 std::vector<float> classes; /**< Probability of classes */
33};
34
35/** Print identified yolo detections
36 *
37 * @param[in, out] os Output stream to print to
38 * @param[in] detections Detections to print
39 */
40void print_detection(std::ostream& os,
41 const std::vector<Detection>& detections);
42
Ryan OShea74af0932020-08-07 16:27:34 +010043/** Compare a detection object with a vector of float values
44 *
45 * @param detection [in] Detection object
46 * @param expected [in] Vector of expected float values
47 * @return Boolean to represent if they match or not
48 */
49bool compare_detection(const yolov3::Detection& detection,
50 const std::vector<float>& expected);
51
Derek Lambertid6cb30e2020-04-28 13:31:29 +010052/** Perform Non-Maxima Supression on a list of given detections
53 *
54 * @param[in] config Configuration metadata for NMS
55 * @param[in] detected_boxes Detected boxes
56 *
57 * @return A vector with the final detections
58 */
59std::vector<Detection> nms(const NMSConfig& config,
60 const std::vector<float>& detected_boxes);
61} // namespace yolov3