blob: 3e9c8198a3512ed1e0cd933af3709abac4b6f911 [file] [log] [blame]
Isabella Gottardi3107aa22022-01-27 16:39:37 +00001/*
2 * Copyright (c) 2022 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#ifndef DETECTOR_POST_PROCESSING_HPP
18#define DETECTOR_POST_PROCESSING_HPP
19
20#include "UseCaseCommonUtils.hpp"
21#include "DetectionResult.hpp"
22#include "YoloFastestModel.hpp"
23
24#include <forward_list>
25
26namespace arm {
27namespace app {
28namespace object_detection {
29
/**
 * @brief Plain aggregate bundling the data kept for one output branch
 *        of the detection network (see Network::branches).
 *
 * NOTE(review): the field meanings below are inferred from the field
 * names and types only; confirm against the code that fills the struct.
 **/
struct Branch {
    int resolution;       /* Presumably the grid resolution of this branch's output - TODO confirm */
    int numBox;           /* Presumably the number of boxes predicted per grid cell - TODO confirm */
    const float* anchor;  /* Read-only pointer to the anchor values used by this branch */
    int8_t* modelOutput;  /* Pointer to the signed 8-bit output data of this branch */
    float scale;          /* Presumably the quantisation scale of modelOutput - TODO confirm */
    int zeroPoint;        /* Presumably the quantisation zero point of modelOutput - TODO confirm */
    size_t size;          /* Size associated with modelOutput; unit (bytes/elements) not visible here */
};
39
40 struct Network {
41 int inputWidth;
42 int inputHeight;
43 int numClasses;
44 std::vector<Branch> branches;
45 int topN;
46 };
47
48
/**
 * @brief Plain aggregate holding a bounding box as a position (x, y)
 *        and an extent (w, h).
 *
 * NOTE(review): (x, y) is presumably the box centre, since
 * Calculate1DOverlap() works on centre/width pairs - confirm against
 * the implementation.
 **/
struct Box {
    float x;  /* Position on the X axis */
    float y;  /* Position on the Y axis */
    float w;  /* Box width */
    float h;  /* Box height */
};
55
56 struct Detection {
57 Box bbox;
58 std::vector<float> prob;
59 float objectness;
60 };
61
62 /**
63 * @brief Helper class to manage tensor post-processing for "object_detection"
64 * output.
65 */
66 class DetectorPostprocessing {
67 public:
68 /**
69 * @brief Constructor.
70 * @param[in] threshold Post-processing threshold.
71 * @param[in] nms Non-maximum Suppression threshold.
72 * @param[in] numClasses Number of classes.
73 * @param[in] topN Top N for each class.
74 **/
75 DetectorPostprocessing(float threshold = 0.5f,
76 float nms = 0.45f,
77 int numClasses = 1,
78 int topN = 0);
79
80 /**
81 * @brief Post processing part of Yolo object detection CNN.
82 * @param[in] imgIn Pointer to the input image,detection bounding boxes drown on it.
83 * @param[in] imgRows Number of rows in the input image.
84 * @param[in] imgCols Number of columns in the input image.
85 * @param[in] modelOutput Output tensors after CNN invoked.
86 * @param[out] resultsOut Vector of detected results.
87 **/
88 void RunPostProcessing(uint8_t* imgIn,
89 uint32_t imgRows,
90 uint32_t imgCols,
91 TfLiteTensor* modelOutput0,
92 TfLiteTensor* modelOutput1,
93 std::vector<DetectionResult>& resultsOut);
94
95 private:
96 float m_threshold; /* Post-processing threshold */
97 float m_nms; /* NMS threshold */
98 int m_numClasses; /* Number of classes */
99 int m_topN; /* TopN */
100
101 /**
102 * @brief Calculate the Sigmoid function of the give value.
103 * @param[in] x Value.
104 * @return Sigmoid value of the input.
105 **/
106 float Sigmoid(float x);
107
108 /**
109 * @brief Insert the given Detection in the list.
110 * @param[in] detections List of detections.
111 * @param[in] det Detection to be inserted.
112 **/
113 void InsertTopNDetections(std::forward_list<Detection>& detections, Detection& det);
114
115 /**
116 * @brief Given a Network calculate the detection boxes.
117 * @param[in] net Network.
118 * @param[in] imageWidth Original image width.
119 * @param[in] imageHeight Original image height.
120 * @param[in] threshold Detections threshold.
121 * @param[out] detections Detection boxes.
122 **/
123 void GetNetworkBoxes(Network& net,
124 int imageWidth,
125 int imageHeight,
126 float threshold,
127 std::forward_list<Detection>& detections);
128
129 /**
130 * @brief Calculate the 1D overlap.
131 * @param[in] x1Center First center point.
132 * @param[in] width1 First width.
133 * @param[in] x2Center Second center point.
134 * @param[in] width2 Second width.
135 * @return The overlap between the two lines.
136 **/
137 float Calculate1DOverlap(float x1Center, float width1, float x2Center, float width2);
138
139 /**
140 * @brief Calculate the intersection between the two given boxes.
141 * @param[in] box1 First box.
142 * @param[in] box2 Second box.
143 * @return The intersection value.
144 **/
145 float CalculateBoxIntersect(Box& box1, Box& box2);
146
147 /**
148 * @brief Calculate the union between the two given boxes.
149 * @param[in] box1 First box.
150 * @param[in] box2 Second box.
151 * @return The two given boxes union value.
152 **/
153 float CalculateBoxUnion(Box& box1, Box& box2);
154 /**
155 * @brief Calculate the intersection over union between the two given boxes.
156 * @param[in] box1 First box.
157 * @param[in] box2 Second box.
158 * @return The intersection over union value.
159 **/
160 float CalculateBoxIOU(Box& box1, Box& box2);
161
162 /**
163 * @brief Calculate the Non-Maxima suppression on the given detection boxes.
164 * @param[in] detections Detection boxes.
165 * @param[in] classes Number of classes.
166 * @param[in] iouThreshold Intersection over union threshold.
167 * @return true or false based on execution success.
168 **/
169 void CalculateNMS(std::forward_list<Detection>& detections, int classes, float iouThreshold);
170
171 /**
172 * @brief Draw on the given image a bounding box starting at (boxX, boxY).
173 * @param[in/out] imgIn Image.
174 * @param[in] imWidth Image width.
175 * @param[in] imHeight Image height.
176 * @param[in] boxX Axis X starting point.
177 * @param[in] boxY Axis Y starting point.
178 * @param[in] boxWidth Box width.
179 * @param[in] boxHeight Box height.
180 **/
181 void DrawBoxOnImage(uint8_t* imgIn,
182 int imWidth,
183 int imHeight,
184 int boxX,
185 int boxY,
186 int boxWidth,
187 int boxHeight);
188 };
189
190} /* namespace object_detection */
191} /* namespace app */
192} /* namespace arm */
193
194#endif /* DETECTOR_POST_PROCESSING_HPP */