blob: cdb14f59d082afed97d491ff15dbec4b1b0baf21 [file] [log] [blame]
/*
 * Copyright (c) 2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#ifndef DETECTOR_POST_PROCESSING_HPP
18#define DETECTOR_POST_PROCESSING_HPP
19
20#include "UseCaseCommonUtils.hpp"
Richard Burtoned35a6f2022-02-14 11:55:35 +000021#include "ImageUtils.hpp"
Isabella Gottardi3107aa22022-01-27 16:39:37 +000022#include "DetectionResult.hpp"
23#include "YoloFastestModel.hpp"
24
25#include <forward_list>
26
27namespace arm {
28namespace app {
29namespace object_detection {
30
/**
 * @brief   Plain aggregate describing one output branch of the detection
 *          network.
 *
 * Every member carries a zero/null default so that a default-constructed
 * Branch never exposes indeterminate values (in particular the two raw
 * pointers). The type remains an aggregate, so brace/designated
 * initialisation by existing callers is unaffected.
 *
 * NOTE(review): the per-field descriptions below are inferred from the
 * member names only - confirm against the post-processing implementation.
 */
struct Branch {
    int          resolution{0};         /* Presumably the grid resolution of this branch - confirm. */
    int          numBox{0};             /* Presumably the number of boxes per grid cell - confirm. */
    const float* anchor{nullptr};       /* Non-owning pointer to anchor values (read-only through this pointer). */
    int8_t*      modelOutput{nullptr};  /* Non-owning pointer to int8 output data - presumably the raw tensor buffer. */
    float        scale{0.0f};           /* Presumably the quantisation scale of modelOutput - confirm. */
    int          zeroPoint{0};          /* Presumably the quantisation zero point of modelOutput - confirm. */
    size_t       size{0};               /* Presumably the size of the modelOutput buffer; units not visible here - confirm. */
};
40
41 struct Network {
42 int inputWidth;
43 int inputHeight;
44 int numClasses;
45 std::vector<Branch> branches;
46 int topN;
47 };
48
Isabella Gottardi3107aa22022-01-27 16:39:37 +000049 /**
50 * @brief Helper class to manage tensor post-processing for "object_detection"
51 * output.
52 */
53 class DetectorPostprocessing {
54 public:
55 /**
56 * @brief Constructor.
57 * @param[in] threshold Post-processing threshold.
58 * @param[in] nms Non-maximum Suppression threshold.
59 * @param[in] numClasses Number of classes.
60 * @param[in] topN Top N for each class.
61 **/
Richard Burton9c549902022-02-15 16:39:18 +000062 explicit DetectorPostprocessing(float threshold = 0.5f,
63 float nms = 0.45f,
64 int numClasses = 1,
65 int topN = 0);
Isabella Gottardi3107aa22022-01-27 16:39:37 +000066
67 /**
Richard Burton9c549902022-02-15 16:39:18 +000068 * @brief Post processing part of YOLO object detection CNN.
Isabella Gottardi3107aa22022-01-27 16:39:37 +000069 * @param[in] imgRows Number of rows in the input image.
70 * @param[in] imgCols Number of columns in the input image.
71 * @param[in] modelOutput Output tensors after CNN invoked.
72 * @param[out] resultsOut Vector of detected results.
73 **/
Richard Burton9c549902022-02-15 16:39:18 +000074 void RunPostProcessing(uint32_t imgRows,
Isabella Gottardi3107aa22022-01-27 16:39:37 +000075 uint32_t imgCols,
76 TfLiteTensor* modelOutput0,
77 TfLiteTensor* modelOutput1,
78 std::vector<DetectionResult>& resultsOut);
79
80 private:
81 float m_threshold; /* Post-processing threshold */
82 float m_nms; /* NMS threshold */
83 int m_numClasses; /* Number of classes */
84 int m_topN; /* TopN */
85
86 /**
Isabella Gottardi3107aa22022-01-27 16:39:37 +000087 * @brief Insert the given Detection in the list.
88 * @param[in] detections List of detections.
89 * @param[in] det Detection to be inserted.
90 **/
Richard Burtoned35a6f2022-02-14 11:55:35 +000091 void InsertTopNDetections(std::forward_list<image::Detection>& detections, image::Detection& det);
Isabella Gottardi3107aa22022-01-27 16:39:37 +000092
93 /**
94 * @brief Given a Network calculate the detection boxes.
95 * @param[in] net Network.
96 * @param[in] imageWidth Original image width.
97 * @param[in] imageHeight Original image height.
98 * @param[in] threshold Detections threshold.
99 * @param[out] detections Detection boxes.
100 **/
101 void GetNetworkBoxes(Network& net,
102 int imageWidth,
103 int imageHeight,
104 float threshold,
Richard Burtoned35a6f2022-02-14 11:55:35 +0000105 std::forward_list<image::Detection>& detections);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000106
107 /**
108 * @brief Draw on the given image a bounding box starting at (boxX, boxY).
109 * @param[in/out] imgIn Image.
110 * @param[in] imWidth Image width.
111 * @param[in] imHeight Image height.
112 * @param[in] boxX Axis X starting point.
113 * @param[in] boxY Axis Y starting point.
114 * @param[in] boxWidth Box width.
115 * @param[in] boxHeight Box height.
116 **/
117 void DrawBoxOnImage(uint8_t* imgIn,
118 int imWidth,
119 int imHeight,
120 int boxX,
121 int boxY,
122 int boxWidth,
123 int boxHeight);
124 };
125
126} /* namespace object_detection */
127} /* namespace app */
128} /* namespace arm */
129
130#endif /* DETECTOR_POST_PROCESSING_HPP */