blob: 6a53688e7f09a4bba59fb9b5ac0c84599492bfb1 [file] [log] [blame]
Isabella Gottardi3107aa22022-01-27 16:39:37 +00001/*
2 * Copyright (c) 2022 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#ifndef DETECTOR_POST_PROCESSING_HPP
18#define DETECTOR_POST_PROCESSING_HPP
19
Richard Burtoned35a6f2022-02-14 11:55:35 +000020#include "ImageUtils.hpp"
Isabella Gottardi3107aa22022-01-27 16:39:37 +000021#include "DetectionResult.hpp"
22#include "YoloFastestModel.hpp"
Richard Burtonef904972022-04-27 17:24:36 +010023#include "BaseProcessing.hpp"
Isabella Gottardi3107aa22022-01-27 16:39:37 +000024
25#include <forward_list>
26
27namespace arm {
28namespace app {
namespace object_detection {

/**
 * @brief Description of one output branch of the YOLO detection network.
 *
 * Aggregate type: it is meant to be filled by brace (optionally designated)
 * initialisation, so the declaration order of the members below must be
 * preserved. Every member has a default initialiser so that a
 * default-initialised Branch holds zeros / null pointers instead of
 * indeterminate values.
 */
struct Branch {
    int resolution{0};            /* NOTE(review): presumably the output grid size of this branch — confirm. */
    int numBox{0};                /* Number of boxes; presumably boxes (anchors) per grid cell — confirm. */
    const float* anchor{nullptr}; /* Anchor values for this branch (presumably not owned). */
    int8_t* modelOutput{nullptr}; /* Quantised model output data for this branch (presumably not owned). */
    float scale{0.0f};            /* Quantisation scale associated with modelOutput. */
    int zeroPoint{0};             /* Quantisation zero point associated with modelOutput. */
    size_t size{0};               /* Size of modelOutput; presumably in bytes — confirm. */
};

/**
 * @brief Description of the YOLO network used during post-processing.
 *
 * Aggregate type with default member initialisers; keep member order stable
 * because callers may rely on (designated) brace initialisation.
 */
struct Network {
    int inputWidth{0};             /* Width of the network input. */
    int inputHeight{0};            /* Height of the network input. */
    int numClasses{0};             /* Number of detectable classes. */
    std::vector<Branch> branches;  /* One entry per output branch. */
    int topN{0};                   /* Top N detections kept per class. */
};

} /* namespace object_detection */
50
Isabella Gottardi3107aa22022-01-27 16:39:37 +000051 /**
Richard Burtonef904972022-04-27 17:24:36 +010052 * @brief Post-processing class for Object Detection use case.
53 * Implements methods declared by BasePostProcess and anything else needed
54 * to populate result vector.
Isabella Gottardi3107aa22022-01-27 16:39:37 +000055 */
Richard Burtonef904972022-04-27 17:24:36 +010056 class DetectorPostProcess : public BasePostProcess {
Isabella Gottardi3107aa22022-01-27 16:39:37 +000057 public:
58 /**
Richard Burtonef904972022-04-27 17:24:36 +010059 * @brief Constructor.
60 * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0.
61 * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1.
62 * @param[out] results Vector of detected results.
63 * @param[in] inputImgRows Number of rows in the input image.
64 * @param[in] inputImgCols Number of columns in the input image.
65 * @param[in] threshold Post-processing threshold.
66 * @param[in] nms Non-maximum Suppression threshold.
67 * @param[in] numClasses Number of classes.
68 * @param[in] topN Top N for each class.
Isabella Gottardi3107aa22022-01-27 16:39:37 +000069 **/
Richard Burtonef904972022-04-27 17:24:36 +010070 explicit DetectorPostProcess(TfLiteTensor* outputTensor0,
71 TfLiteTensor* outputTensor1,
72 std::vector<object_detection::DetectionResult>& results,
73 int inputImgRows,
74 int inputImgCols,
75 float threshold = 0.5f,
76 float nms = 0.45f,
77 int numClasses = 1,
78 int topN = 0);
Isabella Gottardi3107aa22022-01-27 16:39:37 +000079
80 /**
Richard Burtonef904972022-04-27 17:24:36 +010081 * @brief Should perform YOLO post-processing of the result of inference then
82 * populate Detection result data for any later use.
83 * @return true if successful, false otherwise.
Isabella Gottardi3107aa22022-01-27 16:39:37 +000084 **/
Richard Burtonef904972022-04-27 17:24:36 +010085 bool DoPostProcess() override;
Isabella Gottardi3107aa22022-01-27 16:39:37 +000086
87 private:
Richard Burtonef904972022-04-27 17:24:36 +010088 TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */
89 TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */
90 std::vector<object_detection::DetectionResult>& m_results; /* Single inference results. */
91 int m_inputImgRows; /* Number of rows for model input. */
92 int m_inputImgCols; /* Number of cols for model input. */
93 float m_threshold; /* Post-processing threshold. */
94 float m_nms; /* NMS threshold. */
95 int m_numClasses; /* Number of classes. */
96 int m_topN; /* TopN. */
97 object_detection::Network m_net; /* YOLO network object. */
Isabella Gottardi3107aa22022-01-27 16:39:37 +000098
99 /**
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000100 * @brief Insert the given Detection in the list.
101 * @param[in] detections List of detections.
102 * @param[in] det Detection to be inserted.
103 **/
Richard Burtoned35a6f2022-02-14 11:55:35 +0000104 void InsertTopNDetections(std::forward_list<image::Detection>& detections, image::Detection& det);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000105
106 /**
107 * @brief Given a Network calculate the detection boxes.
108 * @param[in] net Network.
109 * @param[in] imageWidth Original image width.
110 * @param[in] imageHeight Original image height.
111 * @param[in] threshold Detections threshold.
112 * @param[out] detections Detection boxes.
113 **/
Richard Burtonef904972022-04-27 17:24:36 +0100114 void GetNetworkBoxes(object_detection::Network& net,
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000115 int imageWidth,
116 int imageHeight,
117 float threshold,
Richard Burtoned35a6f2022-02-14 11:55:35 +0000118 std::forward_list<image::Detection>& detections);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000119 };
120
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000121} /* namespace app */
122} /* namespace arm */
123
124#endif /* DETECTOR_POST_PROCESSING_HPP */