blob: 5393f895c40d7a661ed8e30114f3f2b45148ae77 [file] [log] [blame]
/*
 * Copyright (c) 2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#ifndef DETECTOR_POST_PROCESSING_HPP
18#define DETECTOR_POST_PROCESSING_HPP
19
20#include "UseCaseCommonUtils.hpp"
Richard Burtoned35a6f2022-02-14 11:55:35 +000021#include "ImageUtils.hpp"
Isabella Gottardi3107aa22022-01-27 16:39:37 +000022#include "DetectionResult.hpp"
23#include "YoloFastestModel.hpp"
24
25#include <forward_list>
26
27namespace arm {
28namespace app {
29namespace object_detection {
30
31 struct Branch {
32 int resolution;
33 int numBox;
34 const float* anchor;
35 int8_t* modelOutput;
36 float scale;
37 int zeroPoint;
38 size_t size;
39 };
40
41 struct Network {
42 int inputWidth;
43 int inputHeight;
44 int numClasses;
45 std::vector<Branch> branches;
46 int topN;
47 };
48
Isabella Gottardi3107aa22022-01-27 16:39:37 +000049 /**
50 * @brief Helper class to manage tensor post-processing for "object_detection"
51 * output.
52 */
53 class DetectorPostprocessing {
54 public:
55 /**
56 * @brief Constructor.
57 * @param[in] threshold Post-processing threshold.
58 * @param[in] nms Non-maximum Suppression threshold.
59 * @param[in] numClasses Number of classes.
60 * @param[in] topN Top N for each class.
61 **/
62 DetectorPostprocessing(float threshold = 0.5f,
63 float nms = 0.45f,
64 int numClasses = 1,
65 int topN = 0);
66
67 /**
68 * @brief Post processing part of Yolo object detection CNN.
69 * @param[in] imgIn Pointer to the input image,detection bounding boxes drown on it.
70 * @param[in] imgRows Number of rows in the input image.
71 * @param[in] imgCols Number of columns in the input image.
72 * @param[in] modelOutput Output tensors after CNN invoked.
73 * @param[out] resultsOut Vector of detected results.
74 **/
75 void RunPostProcessing(uint8_t* imgIn,
76 uint32_t imgRows,
77 uint32_t imgCols,
78 TfLiteTensor* modelOutput0,
79 TfLiteTensor* modelOutput1,
80 std::vector<DetectionResult>& resultsOut);
81
82 private:
83 float m_threshold; /* Post-processing threshold */
84 float m_nms; /* NMS threshold */
85 int m_numClasses; /* Number of classes */
86 int m_topN; /* TopN */
87
88 /**
Isabella Gottardi3107aa22022-01-27 16:39:37 +000089 * @brief Insert the given Detection in the list.
90 * @param[in] detections List of detections.
91 * @param[in] det Detection to be inserted.
92 **/
Richard Burtoned35a6f2022-02-14 11:55:35 +000093 void InsertTopNDetections(std::forward_list<image::Detection>& detections, image::Detection& det);
Isabella Gottardi3107aa22022-01-27 16:39:37 +000094
95 /**
96 * @brief Given a Network calculate the detection boxes.
97 * @param[in] net Network.
98 * @param[in] imageWidth Original image width.
99 * @param[in] imageHeight Original image height.
100 * @param[in] threshold Detections threshold.
101 * @param[out] detections Detection boxes.
102 **/
103 void GetNetworkBoxes(Network& net,
104 int imageWidth,
105 int imageHeight,
106 float threshold,
Richard Burtoned35a6f2022-02-14 11:55:35 +0000107 std::forward_list<image::Detection>& detections);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000108
109 /**
110 * @brief Draw on the given image a bounding box starting at (boxX, boxY).
111 * @param[in/out] imgIn Image.
112 * @param[in] imWidth Image width.
113 * @param[in] imHeight Image height.
114 * @param[in] boxX Axis X starting point.
115 * @param[in] boxY Axis Y starting point.
116 * @param[in] boxWidth Box width.
117 * @param[in] boxHeight Box height.
118 **/
119 void DrawBoxOnImage(uint8_t* imgIn,
120 int imWidth,
121 int imHeight,
122 int boxX,
123 int boxY,
124 int boxWidth,
125 int boxHeight);
126 };
127
128} /* namespace object_detection */
129} /* namespace app */
130} /* namespace arm */
131
132#endif /* DETECTOR_POST_PROCESSING_HPP */