/*
 * Copyright (c) 2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "DetectorPostProcessing.hpp"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <forward_list>
#include <iterator>
/* Axis-aligned box described by its absolute edge coordinates. */
typedef struct boxabs {
    float left, right, top, bot;
} boxabs;

/* One detector output branch: feature-map resolution, anchor set and the
 * quantisation parameters of the corresponding TFLite output tensor. */
typedef struct branch {
    int resolution;     /* feature map width/height (square) */
    int num_box;        /* anchors per grid cell */
    float *anchor;      /* anchor (w, h) pairs, num_box * 2 values */
    int8_t *tf_output;  /* quantised tensor data */
    float scale;        /* quantisation scale */
    int zero_point;     /* quantisation zero point */
    size_t size;        /* tensor size in bytes */
    float scale_x_y;    /* grid sensitivity factor (currently unused) */
} branch;

/* Detector network descriptor: input geometry plus its output branches. */
typedef struct network {
    int input_w;
    int input_h;
    int num_classes;
    int num_branch;
    branch *branchs;
    int topN;           /* keep only the topN highest-objectness boxes; <= 0 keeps all */
} network;

typedef struct box {
    float x, y, w, h;
} box;

typedef struct detection {
    box bbox;
    float *prob;        /* per-class score, allocated with calloc */
    float objectness;
} detection;

/* Class index used by CompareProbs() while sorting during NMS. */
static int sort_class;
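/* Release the per-class probability arrays owned by each detection in the list. */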
static void free_dets(std::forward_list<detection> &dets)
{
    std::forward_list<detection>::iterator it;
    for (it = dets.begin(); it != dets.end(); ++it) {
        free(it->prob);
    }
}

float sigmoid(float x)
{
    return 1.f / (1.f + std::exp(-x));
}

static bool det_objectness_comparator(detection &pa, detection &pb)
{
    return pa.objectness < pb.objectness;
}
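/* Insert a detection into a list that is kept sorted by ascending objectness and already
 * holds topN elements: the new detection is placed in order and the current minimum (the
 * front element) is evicted, so the list always keeps the topN best candidates. A detection
 * that does not beat the current minimum is freed and discarded. */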
static void insert_topN_det(std::forward_list<detection> &dets, detection det)
{
    std::forward_list<detection>::iterator it;
    std::forward_list<detection>::iterator last_it;
    for (it = dets.begin(); it != dets.end(); ++it) {
        if (it->objectness > det.objectness) {
            break;
        }
        last_it = it;
    }
    if (it != dets.begin()) {
        dets.emplace_after(last_it, det);
        free(dets.begin()->prob);
        dets.pop_front();
    } else {
        free(det.prob);
    }
}
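/* Decode the raw, quantised YOLO output branches into a list of candidate detections.
 * For every grid cell and anchor the int8 values are dequantised as (value - zero_point) * scale,
 * the objectness logit is passed through the sigmoid and the candidate is kept only if it
 * exceeds `thresh`. Box centres are offset by the cell position and normalised by the
 * feature-map size; widths and heights are exp()-scaled anchors normalised by the network
 * input size, and everything is finally rescaled to the original image dimensions. When
 * net->topN > 0 the list is capped at the topN highest-objectness detections via
 * insert_topN_det(). */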
static std::forward_list<detection> get_network_boxes(network *net, int image_w, int image_h, float thresh, int *num)
{
    std::forward_list<detection> dets;
    int num_classes = net->num_classes;
    *num = 0;

    for (int i = 0; i < net->num_branch; ++i) {
        int height  = net->branchs[i].resolution;
        int width   = net->branchs[i].resolution;
        int channel = net->branchs[i].num_box * (5 + num_classes);

        for (int h = 0; h < net->branchs[i].resolution; h++) {
            for (int w = 0; w < net->branchs[i].resolution; w++) {
                for (int anc = 0; anc < net->branchs[i].num_box; anc++) {

                    /* Objectness score. */
                    int bbox_obj_offset = h * width * channel + w * channel + anc * (num_classes + 5) + 4;
                    float objectness = sigmoid(((float)net->branchs[i].tf_output[bbox_obj_offset] - net->branchs[i].zero_point) * net->branchs[i].scale);

                    if (objectness > thresh) {
                        detection det;
                        det.prob = (float*)calloc(num_classes, sizeof(float));
                        det.objectness = objectness;

                        /* Get bbox prediction data for this anchor at this feature point. */
                        int bbox_x_offset      = bbox_obj_offset - 4;
                        int bbox_y_offset      = bbox_x_offset + 1;
                        int bbox_w_offset      = bbox_x_offset + 2;
                        int bbox_h_offset      = bbox_x_offset + 3;
                        int bbox_scores_offset = bbox_x_offset + 5;

                        det.bbox.x = ((float)net->branchs[i].tf_output[bbox_x_offset] - net->branchs[i].zero_point) * net->branchs[i].scale;
                        det.bbox.y = ((float)net->branchs[i].tf_output[bbox_y_offset] - net->branchs[i].zero_point) * net->branchs[i].scale;
                        det.bbox.w = ((float)net->branchs[i].tf_output[bbox_w_offset] - net->branchs[i].zero_point) * net->branchs[i].scale;
                        det.bbox.h = ((float)net->branchs[i].tf_output[bbox_h_offset] - net->branchs[i].zero_point) * net->branchs[i].scale;

                        /* Eliminate the grid sensitivity trick used in YOLOv4. */
                        float bbox_x = sigmoid(det.bbox.x);
                        float bbox_y = sigmoid(det.bbox.y);
                        det.bbox.x = (bbox_x + w) / width;
                        det.bbox.y = (bbox_y + h) / height;

                        det.bbox.w = std::exp(det.bbox.w) * net->branchs[i].anchor[anc*2]     / net->input_w;
                        det.bbox.h = std::exp(det.bbox.h) * net->branchs[i].anchor[anc*2 + 1] / net->input_h;

                        for (int s = 0; s < num_classes; s++) {
                            det.prob[s] = sigmoid(((float)net->branchs[i].tf_output[bbox_scores_offset + s] - net->branchs[i].zero_point) * net->branchs[i].scale) * objectness;
                            det.prob[s] = (det.prob[s] > thresh) ? det.prob[s] : 0;
                        }

                        /* Correct boxes to the original image size. */
                        det.bbox.x *= image_w;
                        det.bbox.w *= image_w;
                        det.bbox.y *= image_h;
                        det.bbox.h *= image_h;

                        if (*num < net->topN || net->topN <= 0) {
                            dets.emplace_front(det);
                            *num += 1;
                        } else if (*num == net->topN) {
                            dets.sort(det_objectness_comparator);
                            insert_topN_det(dets, det);
                            *num += 1;
                        } else {
                            insert_topN_det(dets, det);
                        }
                    }
                }
            }
        }
    }

    /* When topN capping is active the counter was advanced to topN + 1 on the first
     * eviction, so bring the reported count back to topN. */
    if (net->topN > 0 && *num > net->topN) {
        *num -= 1;
    }
    return dets;
}

// Init part
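/* Helpers that populate the descriptor structs from per-branch tensor pointers and
 * quantisation parameters; the tensor data itself is referenced, not copied. */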
static branch create_brach(int resolution, int num_box, float *anchor, int8_t *tf_output, size_t size, float scale, int zero_point)
{
    branch b;
    b.resolution = resolution;
    b.num_box = num_box;
    b.anchor = anchor;
    b.tf_output = tf_output;
    b.size = size;
    b.scale = scale;
    b.zero_point = zero_point;
    return b;
}

static network creat_network(int input_w, int input_h, int num_classes, int num_branch, branch* branchs, int topN)
{
    network net;
    net.input_w = input_w;
    net.input_h = input_h;
    net.num_classes = num_classes;
    net.num_branch = num_branch;
    net.branchs = branchs;
    net.topN = topN;
    return net;
}

// NMS part
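/* Overlap of two 1-D segments given by their centres and widths: the distance between the
 * right-most left edge and the left-most right edge. A negative result means the segments
 * do not overlap at all. */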
static float Calc1DOverlap(float x1_center, float width1, float x2_center, float width2)
{
    float left_1 = x1_center - width1 / 2;
    float left_2 = x2_center - width2 / 2;
    float leftest = (left_1 > left_2) ? left_1 : left_2;

    float right_1 = x1_center + width1 / 2;
    float right_2 = x2_center + width2 / 2;
    float rightest = (right_1 < right_2) ? right_1 : right_2;

    return rightest - leftest;
}

/* Intersection area of two boxes given as (centre x, centre y, width, height). */
static float CalcBoxIntersect(box box1, box box2)
{
    float width = Calc1DOverlap(box1.x, box1.w, box2.x, box2.w);
    if (width < 0) {
        return 0;
    }
    float height = Calc1DOverlap(box1.y, box1.h, box2.y, box2.h);
    if (height < 0) {
        return 0;
    }
    return width * height;
}

/* Union area of two boxes: sum of their areas minus the intersection. */
static float CalcBoxUnion(box box1, box box2)
{
    float boxes_intersection = CalcBoxIntersect(box1, box2);
    float boxes_union = box1.w * box1.h + box2.w * box2.h - boxes_intersection;
    return boxes_union;
}

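/* Intersection-over-Union of two boxes; returns 0 when the boxes do not overlap or the
 * union area degenerates to 0, avoiding a division by zero. */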
static float CalcBoxIOU(box box1, box box2)
{
    float boxes_intersection = CalcBoxIntersect(box1, box2);
    if (boxes_intersection == 0) {
        return 0;
    }

    float boxes_union = CalcBoxUnion(box1, box2);
    if (boxes_union == 0) {
        return 0;
    }

    return boxes_intersection / boxes_union;
}

/* Sort predicate: the detection with the higher probability for the class selected by
 * sort_class comes first. */
static bool CompareProbs(detection &prob1, detection &prob2)
{
    return prob1.prob[sort_class] > prob2.prob[sort_class];
}

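/* Greedy per-class non-maximum suppression: for each class the detections are sorted by
 * descending class probability, and any lower-scoring box whose IoU with a kept box exceeds
 * iou_threshold has its probability for that class zeroed out (i.e. it is suppressed). */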
static void CalcNMS(std::forward_list<detection> &detections, int classes, float iou_threshold)
{
    for (int k = 0; k < classes; ++k) {
        sort_class = k;
        detections.sort(CompareProbs);

        for (std::forward_list<detection>::iterator it = detections.begin(); it != detections.end(); ++it) {
            if (it->prob[k] == 0) {
                continue;
            }
            for (std::forward_list<detection>::iterator itc = std::next(it, 1); itc != detections.end(); ++itc) {
                if (itc->prob[k] == 0) {
                    continue;
                }
                if (CalcBoxIOU(it->bbox, itc->bbox) > iou_threshold) {
                    itc->prob[k] = 0;
                }
            }
        }
    }
}

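/* Clamp a byte offset so it always addresses a valid pixel inside the
 * im_w x im_h x FORMAT_MULTIPLY_FACTOR image buffer. */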
static inline void check_and_fix_offset(int im_w, int im_h, int *offset)
{
    if (!offset) {
        return;
    }

    if ((*offset) >= im_w * im_h * FORMAT_MULTIPLY_FACTOR) {
        (*offset) = im_w * im_h * FORMAT_MULTIPLY_FACTOR - 1;
    } else if ((*offset) < 0) {
        *offset = 0;
    }
}

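/* Draw the outline of a bounding box (two-pixel-thick edges) directly into the image buffer
 * by setting the first channel of the affected pixels to 0xFF. */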
static void DrawBoxOnImage(uint8_t *img_in, int im_w, int im_h, int bx, int by, int bw, int bh)
{
    if (!img_in) {
        return;
    }

    int offset = 0;
    for (int i = 0; i < bw; i++) {
        /* Draw the two horizontal lines. */
        for (int line = 0; line < 2; line++) {
            /* Top edge. */
            offset = (i + (by + line) * im_w + bx) * FORMAT_MULTIPLY_FACTOR;
            check_and_fix_offset(im_w, im_h, &offset);
            img_in[offset] = 0xFF; /* FORMAT_MULTIPLY_FACTOR accounts for RGB or grayscale data. */
            /* Bottom edge. */
            offset = (i + (by + bh - line) * im_w + bx) * FORMAT_MULTIPLY_FACTOR;
            check_and_fix_offset(im_w, im_h, &offset);
            img_in[offset] = 0xFF;
        }
    }

    for (int i = 0; i < bh; i++) {
        /* Draw the two vertical lines. */
        for (int line = 0; line < 2; line++) {
            /* Left edge. */
            offset = ((i + by) * im_w + bx + line) * FORMAT_MULTIPLY_FACTOR;
            check_and_fix_offset(im_w, im_h, &offset);
            img_in[offset] = 0xFF;
            /* Right edge. */
            offset = ((i + by) * im_w + bx + bw - line) * FORMAT_MULTIPLY_FACTOR;
            check_and_fix_offset(im_w, im_h, &offset);
            img_in[offset] = 0xFF;
        }
    }
}

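/* Detector post-processing entry point: wraps the two quantised output tensors of the
 * YOLO-style detector into a two-branch network descriptor, decodes candidate boxes, runs
 * per-class NMS, clips the surviving boxes to the original image, draws them into img_in
 * and appends one DetectionResult per surviving class score to results_out.
 *
 * A minimal usage sketch (placeholder names; it assumes the caller owns the frame buffer
 * and fetches the two output tensors from the interpreter in the branch order used below):
 *
 *     std::vector<arm::app::DetectionResult> results;
 *     TfLiteTensor* outputs[2] = {outputTensor0, outputTensor1};
 *     arm::app::RunPostProcessing(imageBuffer, outputs, results);
 */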
void arm::app::RunPostProcessing(uint8_t *img_in, TfLiteTensor* model_output[2], std::vector<arm::app::DetectionResult> &results_out)
{
    TfLiteTensor* output[2] = {nullptr, nullptr};
    int input_w = INPUT_IMAGE_WIDTH;
    int input_h = INPUT_IMAGE_HEIGHT;

    for (int anchor = 0; anchor < 2; anchor++) {
        output[anchor] = model_output[anchor];
    }

    /* Init postprocessing. */
    int num_classes = 1;
    int num_branch = 2;
    int topN = 0;

    branch* branchs = (branch*)calloc(num_branch, sizeof(branch));

    /* NOTE: anchors differ for any given input model size; they are estimated during the training phase. */
    float anchor1[] = {38, 77, 47, 97, 61, 126};
    float anchor2[] = {14, 26, 19, 37, 28, 55};

    branchs[0] = create_brach(INPUT_IMAGE_WIDTH/32, 3, anchor1, output[0]->data.int8, output[0]->bytes,
                              ((TfLiteAffineQuantization*)(output[0]->quantization.params))->scale->data[0],
                              ((TfLiteAffineQuantization*)(output[0]->quantization.params))->zero_point->data[0]);

    branchs[1] = create_brach(INPUT_IMAGE_WIDTH/16, 3, anchor2, output[1]->data.int8, output[1]->bytes,
                              ((TfLiteAffineQuantization*)(output[1]->quantization.params))->scale->data[0],
                              ((TfLiteAffineQuantization*)(output[1]->quantization.params))->zero_point->data[0]);

    network net = creat_network(input_w, input_h, num_classes, num_branch, branchs, topN);
    /* End init. */

    /* Start postprocessing. */
    int nboxes = 0;
    float thresh = 0.5f; /* 50% */
    float nms = 0.45f;
    int orig_image_width = ORIGINAL_IMAGE_WIDTH;
    int orig_image_height = ORIGINAL_IMAGE_HEIGHT;
    std::forward_list<detection> dets = get_network_boxes(&net, orig_image_width, orig_image_height, thresh, &nboxes);

    /* Do NMS. */
    CalcNMS(dets, net.num_classes, nms);

    uint8_t temp_unsuppressed_counter = 0;
    for (std::forward_list<detection>::iterator it = dets.begin(); it != dets.end(); ++it) {
        float xmin = it->bbox.x - it->bbox.w / 2.0f;
        float xmax = it->bbox.x + it->bbox.w / 2.0f;
        float ymin = it->bbox.y - it->bbox.h / 2.0f;
        float ymax = it->bbox.y + it->bbox.h / 2.0f;

        if (xmin < 0) xmin = 0;
        if (ymin < 0) ymin = 0;
        if (xmax > orig_image_width) xmax = orig_image_width;
        if (ymax > orig_image_height) ymax = orig_image_height;

        float bx = xmin;
        float by = ymin;
        float bw = xmax - xmin;
        float bh = ymax - ymin;

        for (int j = 0; j < net.num_classes; ++j) {
            if (it->prob[j] > 0) {

                arm::app::DetectionResult tmp_result = {};

                tmp_result.m_normalisedVal = it->prob[j];
                tmp_result.m_x0 = bx;
                tmp_result.m_y0 = by;
                tmp_result.m_w = bw;
                tmp_result.m_h = bh;

                results_out.push_back(tmp_result);

                DrawBoxOnImage(img_in, orig_image_width, orig_image_height, bx, by, bw, bh);

                temp_unsuppressed_counter++;
            }
        }
    }

    free_dets(dets);
    free(branchs);
}

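/* Convert an interleaved RGB888 buffer to 8-bit grayscale using the standard BT.601 luma
 * weights (0.299 R + 0.587 G + 0.114 B), clipping the result to UINT8_MAX. */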
void arm::app::RgbToGrayscale(const uint8_t *rgb, uint8_t *gray, int im_w, int im_h)
{
    const float R = 0.299;
    const float G = 0.587;
    const float B = 0.114;
    for (int i = 0; i < im_w * im_h; i++) {
        uint32_t int_gray = rgb[i*3 + 0] * R + rgb[i*3 + 1] * G + rgb[i*3 + 2] * B;
        /* Clip if needed. */
        if (int_gray <= UINT8_MAX) {
            gray[i] = int_gray;
        } else {
            gray[i] = UINT8_MAX;
        }
    }
}