/*
 * Copyright (c) 2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include "DetectorPostProcessing.hpp"
18
19#include <algorithm>
20#include <cmath>
21
22namespace arm {
23namespace app {
24namespace object_detection {
25
/**
 * @brief Constructor: stores the detector post-processing parameters for
 *        reuse on every subsequent RunPostProcessing() call.
 * @param threshold   Minimum objectness/class score for a box to be kept.
 * @param nms         IoU threshold used during non-maximum suppression.
 * @param numClasses  Number of classes the network predicts per box.
 * @param topN        Keep only the strongest topN candidates (<= 0 keeps all;
 *                    see the topN handling in GetNetworkBoxes).
 */
DetectorPostprocessing::DetectorPostprocessing(
    const float threshold,
    const float nms,
    int numClasses,
    int topN)
    :   m_threshold(threshold),
        m_nms(nms),
        m_numClasses(numClasses),
        m_topN(topN)
{}
36
37void DetectorPostprocessing::RunPostProcessing(
38 uint8_t* imgIn,
39 uint32_t imgRows,
40 uint32_t imgCols,
41 TfLiteTensor* modelOutput0,
42 TfLiteTensor* modelOutput1,
43 std::vector<DetectionResult>& resultsOut)
44{
45 /* init postprocessing */
46 Network net {
47 .inputWidth = static_cast<int>(imgCols),
48 .inputHeight = static_cast<int>(imgRows),
49 .numClasses = m_numClasses,
50 .branches = {
51 Branch {
52 .resolution = static_cast<int>(imgCols/32),
53 .numBox = 3,
54 .anchor = anchor1,
55 .modelOutput = modelOutput0->data.int8,
56 .scale = ((TfLiteAffineQuantization*)(modelOutput0->quantization.params))->scale->data[0],
57 .zeroPoint = ((TfLiteAffineQuantization*)(modelOutput0->quantization.params))->zero_point->data[0],
58 .size = modelOutput0->bytes
59 },
60 Branch {
61 .resolution = static_cast<int>(imgCols/16),
62 .numBox = 3,
63 .anchor = anchor2,
64 .modelOutput = modelOutput1->data.int8,
65 .scale = ((TfLiteAffineQuantization*)(modelOutput1->quantization.params))->scale->data[0],
66 .zeroPoint = ((TfLiteAffineQuantization*)(modelOutput1->quantization.params))->zero_point->data[0],
67 .size = modelOutput1->bytes
68 }
69 },
70 .topN = m_topN
71 };
72 /* End init */
73
74 /* Start postprocessing */
75 int originalImageWidth = originalImageSize;
76 int originalImageHeight = originalImageSize;
77
78 std::forward_list<Detection> detections;
79 GetNetworkBoxes(net, originalImageWidth, originalImageHeight, m_threshold, detections);
80
81 /* Do nms */
82 CalculateNMS(detections, net.numClasses, m_nms);
83
84 for (auto& it: detections) {
85 float xMin = it.bbox.x - it.bbox.w / 2.0f;
86 float xMax = it.bbox.x + it.bbox.w / 2.0f;
87 float yMin = it.bbox.y - it.bbox.h / 2.0f;
88 float yMax = it.bbox.y + it.bbox.h / 2.0f;
89
90 if (xMin < 0) {
91 xMin = 0;
92 }
93 if (yMin < 0) {
94 yMin = 0;
95 }
96 if (xMax > originalImageWidth) {
97 xMax = originalImageWidth;
98 }
99 if (yMax > originalImageHeight) {
100 yMax = originalImageHeight;
101 }
102
103 float boxX = xMin;
104 float boxY = yMin;
105 float boxWidth = xMax - xMin;
106 float boxHeight = yMax - yMin;
107
108 for (int j = 0; j < net.numClasses; ++j) {
109 if (it.prob[j] > 0) {
110
111 DetectionResult tmpResult = {};
112 tmpResult.m_normalisedVal = it.prob[j];
113 tmpResult.m_x0 = boxX;
114 tmpResult.m_y0 = boxY;
115 tmpResult.m_w = boxWidth;
116 tmpResult.m_h = boxHeight;
117
118 resultsOut.push_back(tmpResult);
119
120 /* TODO: Instead of draw on the image, return the boxes and draw on the LCD */
121 DrawBoxOnImage(imgIn, originalImageWidth, originalImageHeight, boxX, boxY, boxWidth, boxHeight);;
122 }
123 }
124 }
125}
126
127float DetectorPostprocessing::Sigmoid(float x)
128{
129 return 1.f/(1.f + exp(-x));
130}
131
/**
 * @brief Inserts a detection into a list that is already sorted in ascending
 *        objectness order and holds exactly topN elements, evicting the
 *        weakest element so the size stays constant.
 * @param detections  Sorted (ascending objectness) list of current top-N.
 * @param det         Candidate detection to insert.
 */
void DetectorPostprocessing::InsertTopNDetections(std::forward_list<Detection>& detections, Detection& det)
{
    std::forward_list<Detection>::iterator it;
    std::forward_list<Detection>::iterator last_it;
    /* Walk to the first element stronger than det; last_it trails one behind
     * and is only read below when the loop advanced at least once. */
    for ( it = detections.begin(); it != detections.end(); ++it ) {
        if(it->objectness > det.objectness)
            break;
        last_it = it;
    }
    /* it == begin() means det is weaker than every kept detection: drop it.
     * Otherwise insert det in order and evict the weakest (front) element. */
    if(it != detections.begin()) {
        detections.emplace_after(last_it, det);
        detections.pop_front();
    }
}
146
/**
 * @brief Scans both network branches and collects every box whose objectness
 *        exceeds the threshold, keeping at most net.topN of the strongest
 *        (when topN > 0). Boxes are decoded to image-space coordinates.
 * @param net          Network description (branches, anchors, quantization).
 * @param imageWidth   Target image width used to scale decoded boxes.
 * @param imageHeight  Target image height used to scale decoded boxes.
 * @param threshold    Minimum objectness / class score to keep a value.
 * @param detections   Receives the surviving detections.
 */
void DetectorPostprocessing::GetNetworkBoxes(Network& net, int imageWidth, int imageHeight, float threshold, std::forward_list<Detection>& detections)
{
    int numClasses = net.numClasses;
    int num = 0;
    auto det_objectness_comparator = [](Detection& pa, Detection& pb) {
        return pa.objectness < pb.objectness;
    };
    for (size_t i = 0; i < net.branches.size(); ++i) {
        int height = net.branches[i].resolution;
        int width = net.branches[i].resolution;
        /* Tensor layout per grid cell: numBox anchors, each storing
         * [x, y, w, h, objectness, class scores...] = 5 + numClasses values. */
        int channel = net.branches[i].numBox*(5+numClasses);

        for (int h = 0; h < net.branches[i].resolution; h++) {
            for (int w = 0; w < net.branches[i].resolution; w++) {
                for (int anc = 0; anc < net.branches[i].numBox; anc++) {

                    /* Objectness score (index 4 within the anchor record),
                     * dequantized then squashed through the sigmoid. */
                    int bbox_obj_offset = h * width * channel + w * channel + anc * (numClasses + 5) + 4;
                    float objectness = Sigmoid(((float)net.branches[i].modelOutput[bbox_obj_offset] - net.branches[i].zeroPoint) * net.branches[i].scale);

                    if(objectness > threshold) {
                        Detection det;
                        det.objectness = objectness;
                        /* Get bbox prediction data for each anchor, each feature point */
                        int bbox_x_offset = bbox_obj_offset -4;
                        int bbox_y_offset = bbox_x_offset + 1;
                        int bbox_w_offset = bbox_x_offset + 2;
                        int bbox_h_offset = bbox_x_offset + 3;
                        int bbox_scores_offset = bbox_x_offset + 5;

                        /* Dequantize the four raw box regressions. */
                        det.bbox.x = ((float)net.branches[i].modelOutput[bbox_x_offset] - net.branches[i].zeroPoint) * net.branches[i].scale;
                        det.bbox.y = ((float)net.branches[i].modelOutput[bbox_y_offset] - net.branches[i].zeroPoint) * net.branches[i].scale;
                        det.bbox.w = ((float)net.branches[i].modelOutput[bbox_w_offset] - net.branches[i].zeroPoint) * net.branches[i].scale;
                        det.bbox.h = ((float)net.branches[i].modelOutput[bbox_h_offset] - net.branches[i].zeroPoint) * net.branches[i].scale;


                        float bbox_x, bbox_y;

                        /* Eliminate grid sensitivity trick involved in YOLOv4:
                         * sigmoid the centre offsets, then normalise by adding
                         * the cell index and dividing by the grid resolution. */
                        bbox_x = Sigmoid(det.bbox.x);
                        bbox_y = Sigmoid(det.bbox.y);
                        det.bbox.x = (bbox_x + w) / width;
                        det.bbox.y = (bbox_y + h) / height;

                        /* Width/height: exponential of the regression scaled by
                         * the matching anchor, normalised by the input size. */
                        det.bbox.w = exp(det.bbox.w) * net.branches[i].anchor[anc*2] / net.inputWidth;
                        det.bbox.h = exp(det.bbox.h) * net.branches[i].anchor[anc*2+1] / net.inputHeight;

                        /* Per-class score = sigmoid(class logit) * objectness;
                         * zeroed when below threshold so NMS can skip it. */
                        for (int s = 0; s < numClasses; s++) {
                            float sig = Sigmoid(((float)net.branches[i].modelOutput[bbox_scores_offset + s] - net.branches[i].zeroPoint) * net.branches[i].scale)*objectness;
                            det.prob.emplace_back((sig > threshold) ? sig : 0);
                        }

                        /* Correct_YOLO_boxes: scale normalised box to image. */
                        det.bbox.x *= imageWidth;
                        det.bbox.w *= imageWidth;
                        det.bbox.y *= imageHeight;
                        det.bbox.h *= imageHeight;

                        /* Keep everything until topN is reached; then sort once
                         * (ascending objectness) and maintain the top-N set via
                         * ordered insertion + eviction of the weakest. */
                        if (num < net.topN || net.topN <=0) {
                            detections.emplace_front(det);
                            num += 1;
                        } else if (num == net.topN) {
                            detections.sort(det_objectness_comparator);
                            InsertTopNDetections(detections,det);
                            num += 1;
                        } else {
                            InsertTopNDetections(detections,det);
                        }
                    }
                }
            }
        }
    }
    /* NOTE(review): num is a local and is not returned, so this final
     * adjustment has no observable effect. */
    if(num > net.topN)
        num -=1;
}
223
224float DetectorPostprocessing::Calculate1DOverlap(float x1Center, float width1, float x2Center, float width2)
225{
226 float left_1 = x1Center - width1/2;
227 float left_2 = x2Center - width2/2;
228 float leftest = left_1 > left_2 ? left_1 : left_2;
229
230 float right_1 = x1Center + width1/2;
231 float right_2 = x2Center + width2/2;
232 float rightest = right_1 < right_2 ? right_1 : right_2;
233
234 return rightest - leftest;
235}
236
237float DetectorPostprocessing::CalculateBoxIntersect(Box& box1, Box& box2)
238{
239 float width = Calculate1DOverlap(box1.x, box1.w, box2.x, box2.w);
240 if (width < 0) {
241 return 0;
242 }
243 float height = Calculate1DOverlap(box1.y, box1.h, box2.y, box2.h);
244 if (height < 0) {
245 return 0;
246 }
247
248 float total_area = width*height;
249 return total_area;
250}
251
252float DetectorPostprocessing::CalculateBoxUnion(Box& box1, Box& box2)
253{
254 float boxes_intersection = CalculateBoxIntersect(box1, box2);
255 float boxes_union = box1.w * box1.h + box2.w * box2.h - boxes_intersection;
256 return boxes_union;
257}
258
259
260float DetectorPostprocessing::CalculateBoxIOU(Box& box1, Box& box2)
261{
262 float boxes_intersection = CalculateBoxIntersect(box1, box2);
263 if (boxes_intersection == 0) {
264 return 0;
265 }
266
267 float boxes_union = CalculateBoxUnion(box1, box2);
268 if (boxes_union == 0) {
269 return 0;
270 }
271
272 return boxes_intersection / boxes_union;
273}
274
275void DetectorPostprocessing::CalculateNMS(std::forward_list<Detection>& detections, int classes, float iouThreshold)
276{
277 int idxClass{0};
278 auto CompareProbs = [idxClass](Detection& prob1, Detection& prob2) {
279 return prob1.prob[idxClass] > prob2.prob[idxClass];
280 };
281
282 for (idxClass = 0; idxClass < classes; ++idxClass) {
283 detections.sort(CompareProbs);
284
285 for (std::forward_list<Detection>::iterator it=detections.begin(); it != detections.end(); ++it) {
286 if (it->prob[idxClass] == 0) continue;
287 for (std::forward_list<Detection>::iterator itc=std::next(it, 1); itc != detections.end(); ++itc) {
288 if (itc->prob[idxClass] == 0) {
289 continue;
290 }
291 if (CalculateBoxIOU(it->bbox, itc->bbox) > iouThreshold) {
292 itc->prob[idxClass] = 0;
293 }
294 }
295 }
296 }
297}
298
299void DetectorPostprocessing::DrawBoxOnImage(uint8_t* imgIn, int imWidth, int imHeight, int boxX,int boxY, int boxWidth, int boxHeight)
300{
301 auto CheckAndFixOffset = [](int im_width,int im_height,int& offset) {
302 if ( (offset) >= im_width*im_height*channelsImageDisplayed) {
303 offset = im_width * im_height * channelsImageDisplayed -1;
304 }
305 else if ( (offset) < 0) {
306 offset = 0;
307 }
308 };
309
310 /* Consistency checks */
311 if (!imgIn) {
312 return;
313 }
314
315 int offset=0;
316 for (int i=0; i < boxWidth; i++) {
317 /* Draw two horizontal lines */
318 for (int line=0; line < 2; line++) {
319 /*top*/
320 offset =(i + (boxY + line)*imWidth + boxX) * channelsImageDisplayed; /* channelsImageDisplayed for rgb or grayscale*/
321 CheckAndFixOffset(imWidth,imHeight,offset);
322 imgIn[offset] = 0xFF;
323 /*bottom*/
324 offset = (i + (boxY + boxHeight - line)*imWidth + boxX) * channelsImageDisplayed;
325 CheckAndFixOffset(imWidth,imHeight,offset);
326 imgIn[offset] = 0xFF;
327 }
328 }
329
330 for (int i=0; i < boxHeight; i++) {
331 /* Draw two vertical lines */
332 for (int line=0; line < 2; line++) {
333 /*left*/
334 offset = ((i + boxY)*imWidth + boxX + line)*channelsImageDisplayed;
335 CheckAndFixOffset(imWidth,imHeight,offset);
336 imgIn[offset] = 0xFF;
337 /*right*/
338 offset = ((i + boxY)*imWidth + boxX + boxWidth - line)*channelsImageDisplayed;
339 CheckAndFixOffset(imWidth,imHeight, offset);
340 imgIn[offset] = 0xFF;
341 }
342 }
343
344}
345
346} /* namespace object_detection */
347} /* namespace app */
348} /* namespace arm */