blob: 30bc12399fc83e1056917ea8e711d36f2f72f19b [file] [log] [blame]
Isabella Gottardi3107aa22022-01-27 16:39:37 +00001/*
2 * Copyright (c) 2022 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#ifndef DETECTOR_POST_PROCESSING_HPP
18#define DETECTOR_POST_PROCESSING_HPP
19
Richard Burtoned35a6f2022-02-14 11:55:35 +000020#include "ImageUtils.hpp"
Isabella Gottardi3107aa22022-01-27 16:39:37 +000021#include "DetectionResult.hpp"
22#include "YoloFastestModel.hpp"
Richard Burtonef904972022-04-27 17:24:36 +010023#include "BaseProcessing.hpp"
Isabella Gottardi3107aa22022-01-27 16:39:37 +000024
25#include <forward_list>
26
27namespace arm {
28namespace app {
Richard Burtonef904972022-04-27 17:24:36 +010029
namespace object_detection {

/**
 * @brief Description of one output branch of the detection network.
 *
 * Every member has a default member initializer so that a default-constructed
 * Branch never holds indeterminate values. The struct remains an aggregate
 * (C++14 and later), so brace / designated initialization keeps working.
 * Pointer members are non-owning: Branch has no destructor that releases them.
 */
struct Branch {
    int resolution{0};             /* Branch resolution (NOTE(review): presumably output grid size — confirm). */
    int numBox{0};                 /* Number of boxes (NOTE(review): presumably per grid cell — confirm). */
    const float* anchor{nullptr};  /* Non-owning pointer to this branch's anchor values. */
    int8_t* modelOutput{nullptr};  /* Non-owning pointer to this branch's raw int8 model output. */
    float scale{0.0f};             /* Presumably the quantization scale used to dequantize modelOutput. */
    int zeroPoint{0};              /* Presumably the quantization zero point of modelOutput. */
    size_t size{0};                /* Size of modelOutput (NOTE(review): bytes vs elements — confirm). */
};

/**
 * @brief Parameters describing the detection network used by post-processing.
 *
 * Members are default-initialized so a default-constructed Network is fully
 * defined before being assigned; the struct remains an aggregate.
 */
struct Network {
    int inputWidth{0};               /* Model input width. */
    int inputHeight{0};              /* Model input height. */
    int numClasses{0};               /* Number of classes. */
    std::vector<Branch> branches{};  /* One entry per network output branch. */
    int topN{0};                     /* Top-N limit applied when collecting detections. */
};

} /* namespace object_detection */
51
Isabella Gottardi3107aa22022-01-27 16:39:37 +000052 /**
Richard Burtonef904972022-04-27 17:24:36 +010053 * @brief Post-processing class for Object Detection use case.
54 * Implements methods declared by BasePostProcess and anything else needed
55 * to populate result vector.
Isabella Gottardi3107aa22022-01-27 16:39:37 +000056 */
Richard Burtonef904972022-04-27 17:24:36 +010057 class DetectorPostProcess : public BasePostProcess {
Isabella Gottardi3107aa22022-01-27 16:39:37 +000058 public:
59 /**
Richard Burtonef904972022-04-27 17:24:36 +010060 * @brief Constructor.
61 * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0.
62 * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1.
63 * @param[out] results Vector of detected results.
64 * @param[in] inputImgRows Number of rows in the input image.
65 * @param[in] inputImgCols Number of columns in the input image.
66 * @param[in] threshold Post-processing threshold.
67 * @param[in] nms Non-maximum Suppression threshold.
68 * @param[in] numClasses Number of classes.
69 * @param[in] topN Top N for each class.
Isabella Gottardi3107aa22022-01-27 16:39:37 +000070 **/
Richard Burtonef904972022-04-27 17:24:36 +010071 explicit DetectorPostProcess(TfLiteTensor* outputTensor0,
72 TfLiteTensor* outputTensor1,
73 std::vector<object_detection::DetectionResult>& results,
74 int inputImgRows,
75 int inputImgCols,
76 float threshold = 0.5f,
77 float nms = 0.45f,
78 int numClasses = 1,
79 int topN = 0);
Isabella Gottardi3107aa22022-01-27 16:39:37 +000080
81 /**
Richard Burtonef904972022-04-27 17:24:36 +010082 * @brief Should perform YOLO post-processing of the result of inference then
83 * populate Detection result data for any later use.
84 * @return true if successful, false otherwise.
Isabella Gottardi3107aa22022-01-27 16:39:37 +000085 **/
Richard Burtonef904972022-04-27 17:24:36 +010086 bool DoPostProcess() override;
Isabella Gottardi3107aa22022-01-27 16:39:37 +000087
88 private:
Richard Burtonef904972022-04-27 17:24:36 +010089 TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */
90 TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */
91 std::vector<object_detection::DetectionResult>& m_results; /* Single inference results. */
92 int m_inputImgRows; /* Number of rows for model input. */
93 int m_inputImgCols; /* Number of cols for model input. */
94 float m_threshold; /* Post-processing threshold. */
95 float m_nms; /* NMS threshold. */
96 int m_numClasses; /* Number of classes. */
97 int m_topN; /* TopN. */
98 object_detection::Network m_net; /* YOLO network object. */
Isabella Gottardi3107aa22022-01-27 16:39:37 +000099
100 /**
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000101 * @brief Insert the given Detection in the list.
102 * @param[in] detections List of detections.
103 * @param[in] det Detection to be inserted.
104 **/
Richard Burtoned35a6f2022-02-14 11:55:35 +0000105 void InsertTopNDetections(std::forward_list<image::Detection>& detections, image::Detection& det);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000106
107 /**
108 * @brief Given a Network calculate the detection boxes.
109 * @param[in] net Network.
110 * @param[in] imageWidth Original image width.
111 * @param[in] imageHeight Original image height.
112 * @param[in] threshold Detections threshold.
113 * @param[out] detections Detection boxes.
114 **/
Richard Burtonef904972022-04-27 17:24:36 +0100115 void GetNetworkBoxes(object_detection::Network& net,
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000116 int imageWidth,
117 int imageHeight,
118 float threshold,
Richard Burtoned35a6f2022-02-14 11:55:35 +0000119 std::forward_list<image::Detection>& detections);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000120 };
121
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000122} /* namespace app */
123} /* namespace arm */
124
125#endif /* DETECTOR_POST_PROCESSING_HPP */