blob: edfb1374933670621d841b9e3ad3b91b5f463eea [file] [log] [blame]
/*
* Copyright (c) 2022 Arm Limited. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "DetectorPostProcessing.hpp"
#include <algorithm>
#include <cmath>
namespace arm {
namespace app {
namespace object_detection {
/**
 * @brief Constructor: stores the post-processing parameters used by
 *        RunPostProcessing (see below).
 * @param threshold   Minimum objectness/class score for a detection to be kept.
 * @param nms         IoU threshold used by non-maximum suppression.
 * @param numClasses  Number of object classes the network predicts.
 * @param topN        Maximum number of detections retained before NMS
 *                    (values <= 0 appear to disable the cap — see GetNetworkBoxes).
 */
DetectorPostprocessing::DetectorPostprocessing(
    const float threshold,
    const float nms,
    int numClasses,
    int topN)
    : m_threshold(threshold),
    m_nms(nms),
    m_numClasses(numClasses),
    m_topN(topN)
{}
void DetectorPostprocessing::RunPostProcessing(
uint8_t* imgIn,
uint32_t imgRows,
uint32_t imgCols,
TfLiteTensor* modelOutput0,
TfLiteTensor* modelOutput1,
std::vector<DetectionResult>& resultsOut)
{
/* init postprocessing */
Network net {
.inputWidth = static_cast<int>(imgCols),
.inputHeight = static_cast<int>(imgRows),
.numClasses = m_numClasses,
.branches = {
Branch {
.resolution = static_cast<int>(imgCols/32),
.numBox = 3,
.anchor = anchor1,
.modelOutput = modelOutput0->data.int8,
.scale = ((TfLiteAffineQuantization*)(modelOutput0->quantization.params))->scale->data[0],
.zeroPoint = ((TfLiteAffineQuantization*)(modelOutput0->quantization.params))->zero_point->data[0],
.size = modelOutput0->bytes
},
Branch {
.resolution = static_cast<int>(imgCols/16),
.numBox = 3,
.anchor = anchor2,
.modelOutput = modelOutput1->data.int8,
.scale = ((TfLiteAffineQuantization*)(modelOutput1->quantization.params))->scale->data[0],
.zeroPoint = ((TfLiteAffineQuantization*)(modelOutput1->quantization.params))->zero_point->data[0],
.size = modelOutput1->bytes
}
},
.topN = m_topN
};
/* End init */
/* Start postprocessing */
int originalImageWidth = originalImageSize;
int originalImageHeight = originalImageSize;
std::forward_list<Detection> detections;
GetNetworkBoxes(net, originalImageWidth, originalImageHeight, m_threshold, detections);
/* Do nms */
CalculateNMS(detections, net.numClasses, m_nms);
for (auto& it: detections) {
float xMin = it.bbox.x - it.bbox.w / 2.0f;
float xMax = it.bbox.x + it.bbox.w / 2.0f;
float yMin = it.bbox.y - it.bbox.h / 2.0f;
float yMax = it.bbox.y + it.bbox.h / 2.0f;
if (xMin < 0) {
xMin = 0;
}
if (yMin < 0) {
yMin = 0;
}
if (xMax > originalImageWidth) {
xMax = originalImageWidth;
}
if (yMax > originalImageHeight) {
yMax = originalImageHeight;
}
float boxX = xMin;
float boxY = yMin;
float boxWidth = xMax - xMin;
float boxHeight = yMax - yMin;
for (int j = 0; j < net.numClasses; ++j) {
if (it.prob[j] > 0) {
DetectionResult tmpResult = {};
tmpResult.m_normalisedVal = it.prob[j];
tmpResult.m_x0 = boxX;
tmpResult.m_y0 = boxY;
tmpResult.m_w = boxWidth;
tmpResult.m_h = boxHeight;
resultsOut.push_back(tmpResult);
/* TODO: Instead of draw on the image, return the boxes and draw on the LCD */
DrawBoxOnImage(imgIn, originalImageWidth, originalImageHeight, boxX, boxY, boxWidth, boxHeight);;
}
}
}
}
/**
 * @brief Logistic sigmoid: 1 / (1 + e^-x).
 * @param x Raw (dequantized) network activation.
 * @return  Value in (0, 1).
 */
float DetectorPostprocessing::Sigmoid(float x)
{
    /* std::exp selects the float overload; the previous unqualified exp(-x)
     * silently promoted to double and back. */
    return 1.f/(1.f + std::exp(-x));
}
/**
 * @brief Inserts a detection into a list sorted ascending by objectness,
 *        evicting the current weakest element to keep the size constant.
 *        If det is weaker than every element already present, it is dropped.
 * @param detections Sorted (ascending objectness) candidate list.
 * @param det        Candidate detection to insert.
 */
void DetectorPostprocessing::InsertTopNDetections(std::forward_list<Detection>& detections, Detection& det)
{
    /* Walk to the last element whose objectness does not exceed det's;
     * that element (if any) is the insertion point. */
    auto insertionPoint = detections.end();
    for (auto it = detections.begin(); it != detections.end(); ++it) {
        if (it->objectness > det.objectness) {
            break;
        }
        insertionPoint = it;
    }

    /* insertionPoint == end() means det is weaker than the whole list: skip it.
     * Otherwise splice det in and drop the weakest (front) element. */
    if (insertionPoint != detections.end()) {
        detections.emplace_after(insertionPoint, det);
        detections.pop_front();
    }
}
/**
 * @brief Decodes raw int8 branch outputs into Detection candidates above
 *        the given objectness threshold, keeping at most net.topN of them
 *        (topN <= 0 appears to disable the cap).
 * @param net         Network descriptor (branches, anchors, quantization).
 * @param imageWidth  Target image width the boxes are scaled to.
 * @param imageHeight Target image height the boxes are scaled to.
 * @param threshold   Minimum objectness / class score to keep a value.
 * @param detections  Output list of decoded candidates.
 */
void DetectorPostprocessing::GetNetworkBoxes(Network& net, int imageWidth, int imageHeight, float threshold, std::forward_list<Detection>& detections)
{
    int numClasses = net.numClasses;
    int num = 0;
    /* Ascending-objectness comparator; keeps the weakest candidate at the front
     * so InsertTopNDetections can evict it cheaply. */
    auto det_objectness_comparator = [](Detection& pa, Detection& pb) {
        return pa.objectness < pb.objectness;
    };
    for (size_t i = 0; i < net.branches.size(); ++i) {
        /* Branch output layout: height x width x (numBox * (4 bbox + 1 obj + numClasses)). */
        int height = net.branches[i].resolution;
        int width = net.branches[i].resolution;
        int channel = net.branches[i].numBox*(5+numClasses);

        for (int h = 0; h < net.branches[i].resolution; h++) {
            for (int w = 0; w < net.branches[i].resolution; w++) {
                for (int anc = 0; anc < net.branches[i].numBox; anc++) {
                    /* Objectness score: dequantize (value - zeroPoint) * scale, then sigmoid. */
                    int bbox_obj_offset = h * width * channel + w * channel + anc * (numClasses + 5) + 4;
                    float objectness = Sigmoid(((float)net.branches[i].modelOutput[bbox_obj_offset] - net.branches[i].zeroPoint) * net.branches[i].scale);

                    if(objectness > threshold) {
                        Detection det;
                        det.objectness = objectness;
                        /* Get bbox prediction data for each anchor, each feature point.
                         * Channel order per anchor is [x, y, w, h, obj, class scores...]. */
                        int bbox_x_offset = bbox_obj_offset -4;
                        int bbox_y_offset = bbox_x_offset + 1;
                        int bbox_w_offset = bbox_x_offset + 2;
                        int bbox_h_offset = bbox_x_offset + 3;
                        int bbox_scores_offset = bbox_x_offset + 5;

                        det.bbox.x = ((float)net.branches[i].modelOutput[bbox_x_offset] - net.branches[i].zeroPoint) * net.branches[i].scale;
                        det.bbox.y = ((float)net.branches[i].modelOutput[bbox_y_offset] - net.branches[i].zeroPoint) * net.branches[i].scale;
                        det.bbox.w = ((float)net.branches[i].modelOutput[bbox_w_offset] - net.branches[i].zeroPoint) * net.branches[i].scale;
                        det.bbox.h = ((float)net.branches[i].modelOutput[bbox_h_offset] - net.branches[i].zeroPoint) * net.branches[i].scale;

                        float bbox_x, bbox_y;

                        /* Eliminate grid sensitivity trick involved in YOLOv4:
                         * sigmoid the centre offsets, add the cell index, normalise by grid size. */
                        bbox_x = Sigmoid(det.bbox.x);
                        bbox_y = Sigmoid(det.bbox.y);
                        det.bbox.x = (bbox_x + w) / width;
                        det.bbox.y = (bbox_y + h) / height;
                        /* Width/height: exp of prediction scaled by the anchor, normalised by input size. */
                        det.bbox.w = exp(det.bbox.w) * net.branches[i].anchor[anc*2] / net.inputWidth;
                        det.bbox.h = exp(det.bbox.h) * net.branches[i].anchor[anc*2+1] / net.inputHeight;

                        /* Class scores are gated by objectness; values at or below
                         * threshold are stored as 0 so NMS can skip them. */
                        for (int s = 0; s < numClasses; s++) {
                            float sig = Sigmoid(((float)net.branches[i].modelOutput[bbox_scores_offset + s] - net.branches[i].zeroPoint) * net.branches[i].scale)*objectness;
                            det.prob.emplace_back((sig > threshold) ? sig : 0);
                        }

                        /* Correct_YOLO_boxes: scale normalised coords to the target image. */
                        det.bbox.x *= imageWidth;
                        det.bbox.w *= imageWidth;
                        det.bbox.y *= imageHeight;
                        det.bbox.h *= imageHeight;

                        /* Keep the first topN candidates unsorted; once full, sort ascending
                         * by objectness and insert further candidates by eviction. */
                        if (num < net.topN || net.topN <=0) {
                            detections.emplace_front(det);
                            num += 1;
                        } else if (num == net.topN) {
                            detections.sort(det_objectness_comparator);
                            InsertTopNDetections(detections,det);
                            num += 1;
                        } else {
                            InsertTopNDetections(detections,det);
                        }
                    }
                }
            }
        }
    }
    /* NOTE(review): num is advanced to topN+1 to mark the "sorted/full" state;
     * this drops it back so num reflects the actual list size — confirm intent. */
    if(num > net.topN)
        num -=1;
}
/**
 * @brief Overlap of two 1-D segments, each given by its centre and width.
 * @param x1Center Centre of the first segment.
 * @param width1   Width of the first segment.
 * @param x2Center Centre of the second segment.
 * @param width2   Width of the second segment.
 * @return Length of the overlapping region; negative when the segments
 *         are disjoint (callers test for < 0).
 */
float DetectorPostprocessing::Calculate1DOverlap(float x1Center, float width1, float x2Center, float width2)
{
    /* std::max/std::min replace the hand-rolled ternary max/min. */
    const float leftest  = std::max(x1Center - width1/2, x2Center - width2/2);
    const float rightest = std::min(x1Center + width1/2, x2Center + width2/2);
    return rightest - leftest;
}
/**
 * @brief Intersection area of two boxes (centre/size representation).
 * @param box1 First box.
 * @param box2 Second box.
 * @return Overlap area, or 0 when the boxes do not intersect.
 */
float DetectorPostprocessing::CalculateBoxIntersect(Box& box1, Box& box2)
{
    /* The boxes intersect only when they overlap on both axes. */
    const float overlapX = Calculate1DOverlap(box1.x, box1.w, box2.x, box2.w);
    const float overlapY = Calculate1DOverlap(box1.y, box1.h, box2.y, box2.h);
    if (overlapX < 0 || overlapY < 0) {
        return 0;
    }
    return overlapX * overlapY;
}
/**
 * @brief Union area of two boxes via inclusion-exclusion:
 *        area1 + area2 - intersection.
 * @param box1 First box.
 * @param box2 Second box.
 * @return Area of the union of the two boxes.
 */
float DetectorPostprocessing::CalculateBoxUnion(Box& box1, Box& box2)
{
    const float area1 = box1.w * box1.h;
    const float area2 = box2.w * box2.h;
    return area1 + area2 - CalculateBoxIntersect(box1, box2);
}
/**
 * @brief Intersection-over-Union of two boxes.
 * @param box1 First box.
 * @param box2 Second box.
 * @return IoU in [0, 1]; 0 when the boxes are disjoint or the union
 *         is degenerate (guards against division by zero).
 */
float DetectorPostprocessing::CalculateBoxIOU(Box& box1, Box& box2)
{
    const float intersectArea = CalculateBoxIntersect(box1, box2);
    if (intersectArea == 0) {
        return 0;
    }
    const float unionArea = CalculateBoxUnion(box1, box2);
    return (unionArea == 0) ? 0 : intersectArea / unionArea;
}
/**
 * @brief Per-class non-maximum suppression: for each class, sorts the
 *        detections by that class's probability and zeroes the probability
 *        of any lower-ranked box whose IoU with a higher-ranked box exceeds
 *        iouThreshold. Suppression is in-place (prob[class] = 0).
 * @param detections   Candidate detections (modified in place).
 * @param classes      Number of classes in each detection's prob vector.
 * @param iouThreshold IoU above which the weaker box is suppressed.
 */
void DetectorPostprocessing::CalculateNMS(std::forward_list<Detection>& detections, int classes, float iouThreshold)
{
    int idxClass{0};
    /* BUG FIX: idxClass must be captured by reference. Capturing by value
     * snapshotted idxClass at 0, so the sort below ordered every class pass
     * by class-0 probabilities instead of the class being processed. */
    auto CompareProbs = [&idxClass](Detection& prob1, Detection& prob2) {
        return prob1.prob[idxClass] > prob2.prob[idxClass];
    };

    for (idxClass = 0; idxClass < classes; ++idxClass) {
        /* Highest-probability boxes first for the current class. */
        detections.sort(CompareProbs);

        for (std::forward_list<Detection>::iterator it=detections.begin(); it != detections.end(); ++it) {
            if (it->prob[idxClass] == 0) continue;
            for (std::forward_list<Detection>::iterator itc=std::next(it, 1); itc != detections.end(); ++itc) {
                if (itc->prob[idxClass] == 0) {
                    continue;
                }
                /* itc ranks below it; suppress it when the overlap is too large. */
                if (CalculateBoxIOU(it->bbox, itc->bbox) > iouThreshold) {
                    itc->prob[idxClass] = 0;
                }
            }
        }
    }
}
/**
 * @brief Draws a two-pixel-thick white rectangle outline onto the image
 *        buffer by writing 0xFF into the first channel of each edge pixel.
 * @param imgIn     Image buffer (channelsImageDisplayed bytes per pixel);
 *                  no-op when null.
 * @param imWidth   Image width in pixels.
 * @param imHeight  Image height in pixels.
 * @param boxX      Box left edge in pixels.
 * @param boxY      Box top edge in pixels.
 * @param boxWidth  Box width in pixels.
 * @param boxHeight Box height in pixels.
 */
void DetectorPostprocessing::DrawBoxOnImage(uint8_t* imgIn, int imWidth, int imHeight, int boxX,int boxY, int boxWidth, int boxHeight)
{
    /* Consistency check: nothing to draw on a null buffer. */
    if (imgIn == nullptr) {
        return;
    }

    /* Total bytes in the displayed image; every write is clamped into bounds. */
    const int bufferSize = imWidth * imHeight * channelsImageDisplayed;

    /* Clamp a byte offset into [0, bufferSize) and paint that byte white. */
    auto paintPixel = [imgIn, bufferSize](int pixOffset) {
        if (pixOffset >= bufferSize) {
            pixOffset = bufferSize - 1;
        } else if (pixOffset < 0) {
            pixOffset = 0;
        }
        imgIn[pixOffset] = 0xFF;
    };

    /* Top and bottom edges, two rows thick. channelsImageDisplayed converts
     * a pixel index to a byte offset (rgb or grayscale). */
    for (int i = 0; i < boxWidth; ++i) {
        for (int line = 0; line < 2; ++line) {
            paintPixel((i + (boxY + line) * imWidth + boxX) * channelsImageDisplayed);
            paintPixel((i + (boxY + boxHeight - line) * imWidth + boxX) * channelsImageDisplayed);
        }
    }

    /* Left and right edges, two columns thick. */
    for (int i = 0; i < boxHeight; ++i) {
        for (int line = 0; line < 2; ++line) {
            paintPixel(((i + boxY) * imWidth + boxX + line) * channelsImageDisplayed);
            paintPixel(((i + boxY) * imWidth + boxX + boxWidth - line) * channelsImageDisplayed);
        }
    }
}
} /* namespace object_detection */
} /* namespace app */
} /* namespace arm */