blob: 7610c4fa5cb752663d51a67610594b97b9d31c4a [file] [log] [blame]
/*
 * Copyright (c) 2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include "DetectorPostProcessing.hpp"
Richard Burtoned35a6f2022-02-14 11:55:35 +000018#include "PlatformMath.hpp"
Isabella Gottardi3107aa22022-01-27 16:39:37 +000019
Isabella Gottardi3107aa22022-01-27 16:39:37 +000020#include <cmath>
21
22namespace arm {
23namespace app {
Isabella Gottardi3107aa22022-01-27 16:39:37 +000024
Richard Burtonef904972022-04-27 17:24:36 +010025 DetectorPostProcess::DetectorPostProcess(
26 TfLiteTensor* modelOutput0,
27 TfLiteTensor* modelOutput1,
28 std::vector<object_detection::DetectionResult>& results,
29 int inputImgRows,
30 int inputImgCols,
31 const float threshold,
32 const float nms,
33 int numClasses,
34 int topN)
35 : m_outputTensor0{modelOutput0},
36 m_outputTensor1{modelOutput1},
37 m_results{results},
38 m_inputImgRows{inputImgRows},
39 m_inputImgCols{inputImgCols},
40 m_threshold(threshold),
41 m_nms(nms),
42 m_numClasses(numClasses),
43 m_topN(topN)
Isabella Gottardi3107aa22022-01-27 16:39:37 +000044{
Richard Burtonef904972022-04-27 17:24:36 +010045 /* Init PostProcessing */
Liam Barry213a5432022-05-09 17:06:19 +010046 this->m_net = object_detection::Network{
47 .inputWidth = inputImgCols,
Richard Burtonef904972022-04-27 17:24:36 +010048 .inputHeight = inputImgRows,
Liam Barry213a5432022-05-09 17:06:19 +010049 .numClasses = numClasses,
50 .branches =
51 {object_detection::Branch{.resolution = inputImgCols / 32,
52 .numBox = 3,
53 .anchor = arm::app::object_detection::anchor1,
54 .modelOutput = this->m_outputTensor0->data.int8,
55 .scale = (static_cast<TfLiteAffineQuantization*>(
56 this->m_outputTensor0->quantization.params))
57 ->scale->data[0],
58 .zeroPoint = (static_cast<TfLiteAffineQuantization*>(
59 this->m_outputTensor0->quantization.params))
60 ->zero_point->data[0],
61 .size = this->m_outputTensor0->bytes},
62 object_detection::Branch{.resolution = inputImgCols / 16,
63 .numBox = 3,
64 .anchor = arm::app::object_detection::anchor2,
65 .modelOutput = this->m_outputTensor1->data.int8,
66 .scale = (static_cast<TfLiteAffineQuantization*>(
67 this->m_outputTensor1->quantization.params))
68 ->scale->data[0],
69 .zeroPoint = (static_cast<TfLiteAffineQuantization*>(
70 this->m_outputTensor1->quantization.params))
71 ->zero_point->data[0],
72 .size = this->m_outputTensor1->bytes}},
73 .topN = m_topN};
Isabella Gottardi3107aa22022-01-27 16:39:37 +000074 /* End init */
Richard Burtonef904972022-04-27 17:24:36 +010075}
Isabella Gottardi3107aa22022-01-27 16:39:37 +000076
Richard Burtonef904972022-04-27 17:24:36 +010077bool DetectorPostProcess::DoPostProcess()
78{
Isabella Gottardi3107aa22022-01-27 16:39:37 +000079 /* Start postprocessing */
Liam Barry213a5432022-05-09 17:06:19 +010080 int originalImageWidth = arm::app::object_detection::originalImageSize;
81 int originalImageHeight = arm::app::object_detection::originalImageSize;
Isabella Gottardi3107aa22022-01-27 16:39:37 +000082
Richard Burtoned35a6f2022-02-14 11:55:35 +000083 std::forward_list<image::Detection> detections;
Richard Burtonef904972022-04-27 17:24:36 +010084 GetNetworkBoxes(this->m_net, originalImageWidth, originalImageHeight, m_threshold, detections);
Isabella Gottardi3107aa22022-01-27 16:39:37 +000085
86 /* Do nms */
Richard Burtonef904972022-04-27 17:24:36 +010087 CalculateNMS(detections, this->m_net.numClasses, m_nms);
Isabella Gottardi3107aa22022-01-27 16:39:37 +000088
89 for (auto& it: detections) {
90 float xMin = it.bbox.x - it.bbox.w / 2.0f;
91 float xMax = it.bbox.x + it.bbox.w / 2.0f;
92 float yMin = it.bbox.y - it.bbox.h / 2.0f;
93 float yMax = it.bbox.y + it.bbox.h / 2.0f;
94
95 if (xMin < 0) {
96 xMin = 0;
97 }
98 if (yMin < 0) {
99 yMin = 0;
100 }
101 if (xMax > originalImageWidth) {
102 xMax = originalImageWidth;
103 }
104 if (yMax > originalImageHeight) {
105 yMax = originalImageHeight;
106 }
107
108 float boxX = xMin;
109 float boxY = yMin;
110 float boxWidth = xMax - xMin;
111 float boxHeight = yMax - yMin;
112
Richard Burtonef904972022-04-27 17:24:36 +0100113 for (int j = 0; j < this->m_net.numClasses; ++j) {
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000114 if (it.prob[j] > 0) {
115
Richard Burtonef904972022-04-27 17:24:36 +0100116 object_detection::DetectionResult tmpResult = {};
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000117 tmpResult.m_normalisedVal = it.prob[j];
118 tmpResult.m_x0 = boxX;
119 tmpResult.m_y0 = boxY;
120 tmpResult.m_w = boxWidth;
121 tmpResult.m_h = boxHeight;
122
Richard Burtonef904972022-04-27 17:24:36 +0100123 this->m_results.push_back(tmpResult);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000124 }
125 }
126 }
Richard Burtonef904972022-04-27 17:24:36 +0100127 return true;
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000128}
129
Richard Burtonef904972022-04-27 17:24:36 +0100130void DetectorPostProcess::InsertTopNDetections(std::forward_list<image::Detection>& detections, image::Detection& det)
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000131{
Richard Burtoned35a6f2022-02-14 11:55:35 +0000132 std::forward_list<image::Detection>::iterator it;
133 std::forward_list<image::Detection>::iterator last_it;
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000134 for ( it = detections.begin(); it != detections.end(); ++it ) {
135 if(it->objectness > det.objectness)
136 break;
137 last_it = it;
138 }
139 if(it != detections.begin()) {
140 detections.emplace_after(last_it, det);
141 detections.pop_front();
142 }
143}
144
Richard Burtonef904972022-04-27 17:24:36 +0100145void DetectorPostProcess::GetNetworkBoxes(
146 object_detection::Network& net,
147 int imageWidth,
148 int imageHeight,
149 float threshold,
150 std::forward_list<image::Detection>& detections)
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000151{
152 int numClasses = net.numClasses;
153 int num = 0;
Richard Burtoned35a6f2022-02-14 11:55:35 +0000154 auto det_objectness_comparator = [](image::Detection& pa, image::Detection& pb) {
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000155 return pa.objectness < pb.objectness;
156 };
157 for (size_t i = 0; i < net.branches.size(); ++i) {
158 int height = net.branches[i].resolution;
159 int width = net.branches[i].resolution;
160 int channel = net.branches[i].numBox*(5+numClasses);
161
162 for (int h = 0; h < net.branches[i].resolution; h++) {
163 for (int w = 0; w < net.branches[i].resolution; w++) {
164 for (int anc = 0; anc < net.branches[i].numBox; anc++) {
165
166 /* Objectness score */
167 int bbox_obj_offset = h * width * channel + w * channel + anc * (numClasses + 5) + 4;
Richard Burton9c549902022-02-15 16:39:18 +0000168 float objectness = math::MathUtils::SigmoidF32(
169 (static_cast<float>(net.branches[i].modelOutput[bbox_obj_offset])
170 - net.branches[i].zeroPoint
171 ) * net.branches[i].scale);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000172
173 if(objectness > threshold) {
Richard Burtoned35a6f2022-02-14 11:55:35 +0000174 image::Detection det;
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000175 det.objectness = objectness;
176 /* Get bbox prediction data for each anchor, each feature point */
177 int bbox_x_offset = bbox_obj_offset -4;
178 int bbox_y_offset = bbox_x_offset + 1;
179 int bbox_w_offset = bbox_x_offset + 2;
180 int bbox_h_offset = bbox_x_offset + 3;
181 int bbox_scores_offset = bbox_x_offset + 5;
182
Richard Burtonef904972022-04-27 17:24:36 +0100183 det.bbox.x = (static_cast<float>(net.branches[i].modelOutput[bbox_x_offset])
184 - net.branches[i].zeroPoint) * net.branches[i].scale;
185 det.bbox.y = (static_cast<float>(net.branches[i].modelOutput[bbox_y_offset])
186 - net.branches[i].zeroPoint) * net.branches[i].scale;
187 det.bbox.w = (static_cast<float>(net.branches[i].modelOutput[bbox_w_offset])
188 - net.branches[i].zeroPoint) * net.branches[i].scale;
189 det.bbox.h = (static_cast<float>(net.branches[i].modelOutput[bbox_h_offset])
190 - net.branches[i].zeroPoint) * net.branches[i].scale;
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000191
192 float bbox_x, bbox_y;
193
194 /* Eliminate grid sensitivity trick involved in YOLOv4 */
Richard Burtoned35a6f2022-02-14 11:55:35 +0000195 bbox_x = math::MathUtils::SigmoidF32(det.bbox.x);
196 bbox_y = math::MathUtils::SigmoidF32(det.bbox.y);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000197 det.bbox.x = (bbox_x + w) / width;
198 det.bbox.y = (bbox_y + h) / height;
199
Richard Burton9c549902022-02-15 16:39:18 +0000200 det.bbox.w = std::exp(det.bbox.w) * net.branches[i].anchor[anc*2] / net.inputWidth;
201 det.bbox.h = std::exp(det.bbox.h) * net.branches[i].anchor[anc*2+1] / net.inputHeight;
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000202
203 for (int s = 0; s < numClasses; s++) {
Richard Burton9c549902022-02-15 16:39:18 +0000204 float sig = math::MathUtils::SigmoidF32(
205 (static_cast<float>(net.branches[i].modelOutput[bbox_scores_offset + s]) -
206 net.branches[i].zeroPoint) * net.branches[i].scale
207 ) * objectness;
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000208 det.prob.emplace_back((sig > threshold) ? sig : 0);
209 }
210
211 /* Correct_YOLO_boxes */
212 det.bbox.x *= imageWidth;
213 det.bbox.w *= imageWidth;
214 det.bbox.y *= imageHeight;
215 det.bbox.h *= imageHeight;
216
217 if (num < net.topN || net.topN <=0) {
218 detections.emplace_front(det);
219 num += 1;
220 } else if (num == net.topN) {
221 detections.sort(det_objectness_comparator);
222 InsertTopNDetections(detections,det);
223 num += 1;
224 } else {
225 InsertTopNDetections(detections,det);
226 }
227 }
228 }
229 }
230 }
231 }
232 if(num > net.topN)
233 num -=1;
234}
235
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000236} /* namespace app */
237} /* namespace arm */