blob: a890c9ee38eab4252299148880d856fcee8a1cab [file] [log] [blame]
/*
 * Copyright (c) 2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include "DetectorPostProcessing.hpp"
Richard Burtoned35a6f2022-02-14 11:55:35 +000018#include "PlatformMath.hpp"
Isabella Gottardi3107aa22022-01-27 16:39:37 +000019
Isabella Gottardi3107aa22022-01-27 16:39:37 +000020#include <cmath>
21
22namespace arm {
23namespace app {
24namespace object_detection {
25
26DetectorPostprocessing::DetectorPostprocessing(
27 const float threshold,
28 const float nms,
29 int numClasses,
30 int topN)
31 : m_threshold(threshold),
32 m_nms(nms),
33 m_numClasses(numClasses),
34 m_topN(topN)
35{}
36
37void DetectorPostprocessing::RunPostProcessing(
Isabella Gottardi3107aa22022-01-27 16:39:37 +000038 uint32_t imgRows,
39 uint32_t imgCols,
40 TfLiteTensor* modelOutput0,
41 TfLiteTensor* modelOutput1,
42 std::vector<DetectionResult>& resultsOut)
43{
44 /* init postprocessing */
45 Network net {
46 .inputWidth = static_cast<int>(imgCols),
47 .inputHeight = static_cast<int>(imgRows),
48 .numClasses = m_numClasses,
49 .branches = {
50 Branch {
51 .resolution = static_cast<int>(imgCols/32),
52 .numBox = 3,
53 .anchor = anchor1,
54 .modelOutput = modelOutput0->data.int8,
55 .scale = ((TfLiteAffineQuantization*)(modelOutput0->quantization.params))->scale->data[0],
56 .zeroPoint = ((TfLiteAffineQuantization*)(modelOutput0->quantization.params))->zero_point->data[0],
57 .size = modelOutput0->bytes
58 },
59 Branch {
60 .resolution = static_cast<int>(imgCols/16),
61 .numBox = 3,
62 .anchor = anchor2,
63 .modelOutput = modelOutput1->data.int8,
64 .scale = ((TfLiteAffineQuantization*)(modelOutput1->quantization.params))->scale->data[0],
65 .zeroPoint = ((TfLiteAffineQuantization*)(modelOutput1->quantization.params))->zero_point->data[0],
66 .size = modelOutput1->bytes
67 }
68 },
69 .topN = m_topN
70 };
71 /* End init */
72
73 /* Start postprocessing */
74 int originalImageWidth = originalImageSize;
75 int originalImageHeight = originalImageSize;
76
Richard Burtoned35a6f2022-02-14 11:55:35 +000077 std::forward_list<image::Detection> detections;
Isabella Gottardi3107aa22022-01-27 16:39:37 +000078 GetNetworkBoxes(net, originalImageWidth, originalImageHeight, m_threshold, detections);
79
80 /* Do nms */
81 CalculateNMS(detections, net.numClasses, m_nms);
82
83 for (auto& it: detections) {
84 float xMin = it.bbox.x - it.bbox.w / 2.0f;
85 float xMax = it.bbox.x + it.bbox.w / 2.0f;
86 float yMin = it.bbox.y - it.bbox.h / 2.0f;
87 float yMax = it.bbox.y + it.bbox.h / 2.0f;
88
89 if (xMin < 0) {
90 xMin = 0;
91 }
92 if (yMin < 0) {
93 yMin = 0;
94 }
95 if (xMax > originalImageWidth) {
96 xMax = originalImageWidth;
97 }
98 if (yMax > originalImageHeight) {
99 yMax = originalImageHeight;
100 }
101
102 float boxX = xMin;
103 float boxY = yMin;
104 float boxWidth = xMax - xMin;
105 float boxHeight = yMax - yMin;
106
107 for (int j = 0; j < net.numClasses; ++j) {
108 if (it.prob[j] > 0) {
109
110 DetectionResult tmpResult = {};
111 tmpResult.m_normalisedVal = it.prob[j];
112 tmpResult.m_x0 = boxX;
113 tmpResult.m_y0 = boxY;
114 tmpResult.m_w = boxWidth;
115 tmpResult.m_h = boxHeight;
116
117 resultsOut.push_back(tmpResult);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000118 }
119 }
120 }
121}
122
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000123
Richard Burtoned35a6f2022-02-14 11:55:35 +0000124void DetectorPostprocessing::InsertTopNDetections(std::forward_list<image::Detection>& detections, image::Detection& det)
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000125{
Richard Burtoned35a6f2022-02-14 11:55:35 +0000126 std::forward_list<image::Detection>::iterator it;
127 std::forward_list<image::Detection>::iterator last_it;
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000128 for ( it = detections.begin(); it != detections.end(); ++it ) {
129 if(it->objectness > det.objectness)
130 break;
131 last_it = it;
132 }
133 if(it != detections.begin()) {
134 detections.emplace_after(last_it, det);
135 detections.pop_front();
136 }
137}
138
Richard Burtoned35a6f2022-02-14 11:55:35 +0000139void DetectorPostprocessing::GetNetworkBoxes(Network& net, int imageWidth, int imageHeight, float threshold, std::forward_list<image::Detection>& detections)
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000140{
141 int numClasses = net.numClasses;
142 int num = 0;
Richard Burtoned35a6f2022-02-14 11:55:35 +0000143 auto det_objectness_comparator = [](image::Detection& pa, image::Detection& pb) {
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000144 return pa.objectness < pb.objectness;
145 };
146 for (size_t i = 0; i < net.branches.size(); ++i) {
147 int height = net.branches[i].resolution;
148 int width = net.branches[i].resolution;
149 int channel = net.branches[i].numBox*(5+numClasses);
150
151 for (int h = 0; h < net.branches[i].resolution; h++) {
152 for (int w = 0; w < net.branches[i].resolution; w++) {
153 for (int anc = 0; anc < net.branches[i].numBox; anc++) {
154
155 /* Objectness score */
156 int bbox_obj_offset = h * width * channel + w * channel + anc * (numClasses + 5) + 4;
Richard Burton9c549902022-02-15 16:39:18 +0000157 float objectness = math::MathUtils::SigmoidF32(
158 (static_cast<float>(net.branches[i].modelOutput[bbox_obj_offset])
159 - net.branches[i].zeroPoint
160 ) * net.branches[i].scale);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000161
162 if(objectness > threshold) {
Richard Burtoned35a6f2022-02-14 11:55:35 +0000163 image::Detection det;
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000164 det.objectness = objectness;
165 /* Get bbox prediction data for each anchor, each feature point */
166 int bbox_x_offset = bbox_obj_offset -4;
167 int bbox_y_offset = bbox_x_offset + 1;
168 int bbox_w_offset = bbox_x_offset + 2;
169 int bbox_h_offset = bbox_x_offset + 3;
170 int bbox_scores_offset = bbox_x_offset + 5;
171
Richard Burton9c549902022-02-15 16:39:18 +0000172 det.bbox.x = (static_cast<float>(net.branches[i].modelOutput[bbox_x_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale;
173 det.bbox.y = (static_cast<float>(net.branches[i].modelOutput[bbox_y_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale;
174 det.bbox.w = (static_cast<float>(net.branches[i].modelOutput[bbox_w_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale;
175 det.bbox.h = (static_cast<float>(net.branches[i].modelOutput[bbox_h_offset]) - net.branches[i].zeroPoint) * net.branches[i].scale;
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000176
177 float bbox_x, bbox_y;
178
179 /* Eliminate grid sensitivity trick involved in YOLOv4 */
Richard Burtoned35a6f2022-02-14 11:55:35 +0000180 bbox_x = math::MathUtils::SigmoidF32(det.bbox.x);
181 bbox_y = math::MathUtils::SigmoidF32(det.bbox.y);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000182 det.bbox.x = (bbox_x + w) / width;
183 det.bbox.y = (bbox_y + h) / height;
184
Richard Burton9c549902022-02-15 16:39:18 +0000185 det.bbox.w = std::exp(det.bbox.w) * net.branches[i].anchor[anc*2] / net.inputWidth;
186 det.bbox.h = std::exp(det.bbox.h) * net.branches[i].anchor[anc*2+1] / net.inputHeight;
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000187
188 for (int s = 0; s < numClasses; s++) {
Richard Burton9c549902022-02-15 16:39:18 +0000189 float sig = math::MathUtils::SigmoidF32(
190 (static_cast<float>(net.branches[i].modelOutput[bbox_scores_offset + s]) -
191 net.branches[i].zeroPoint) * net.branches[i].scale
192 ) * objectness;
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000193 det.prob.emplace_back((sig > threshold) ? sig : 0);
194 }
195
196 /* Correct_YOLO_boxes */
197 det.bbox.x *= imageWidth;
198 det.bbox.w *= imageWidth;
199 det.bbox.y *= imageHeight;
200 det.bbox.h *= imageHeight;
201
202 if (num < net.topN || net.topN <=0) {
203 detections.emplace_front(det);
204 num += 1;
205 } else if (num == net.topN) {
206 detections.sort(det_objectness_comparator);
207 InsertTopNDetections(detections,det);
208 num += 1;
209 } else {
210 InsertTopNDetections(detections,det);
211 }
212 }
213 }
214 }
215 }
216 }
217 if(num > net.topN)
218 num -=1;
219}
220
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000221} /* namespace object_detection */
222} /* namespace app */
223} /* namespace arm */