blob: b3ddb2c37f15cfb6da617c12b09a664589daab9b [file] [log] [blame]
Isabella Gottardi3107aa22022-01-27 16:39:37 +00001/*
2 * Copyright (c) 2022 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#ifndef DETECTOR_POST_PROCESSING_HPP
18#define DETECTOR_POST_PROCESSING_HPP
19
#include "UseCaseCommonUtils.hpp"
#include "ImageUtils.hpp"
#include "DetectionResult.hpp"
#include "YoloFastestModel.hpp"
#include "BaseProcessing.hpp"

#include <cstddef>
#include <cstdint>
#include <forward_list>
#include <vector>
27
28namespace arm {
29namespace app {
Richard Burtonef904972022-04-27 17:24:36 +010030
namespace object_detection {

    /**
     * @brief Describes one output branch of the detection network.
     *
     * Every member has a default member initializer, so a default-initialized
     * Branch (e.g. as part of a default-constructed Network) never holds
     * indeterminate values. The struct remains an aggregate (C++14 rules), so
     * positional or designated brace-initialization with explicit values
     * behaves exactly as before.
     *
     * NOTE(review): the per-member meanings below are inferred from member
     * names; confirm against the code that fills this struct.
     */
    struct Branch {
        int resolution{0};             /* Presumably the output grid size of this branch. */
        int numBox{0};                 /* Presumably the number of boxes per grid cell. */
        const float* anchor{nullptr};  /* Anchor values for this branch; never freed by Branch. */
        int8_t* modelOutput{nullptr};  /* Raw int8 model output; never freed by Branch. */
        float scale{0.0f};             /* Assumed quantization scale of modelOutput. */
        int zeroPoint{0};              /* Assumed quantization zero point of modelOutput. */
        size_t size{0};                /* Size of modelOutput (bytes vs. elements: TODO confirm). */
    };

    /**
     * @brief Parameters of the YOLO network whose output branches are decoded.
     *
     * Scalar members default to 0 and @p branches starts empty, so a
     * default-constructed Network is fully initialized.
     */
    struct Network {
        int inputWidth{0};             /* Network input width. */
        int inputHeight{0};            /* Network input height. */
        int numClasses{0};             /* Number of object classes. */
        std::vector<Branch> branches;  /* Output branches to decode. */
        int topN{0};                   /* Top-N limit; meaning of 0 not visible here. */
    };

} /* namespace object_detection */
52
Isabella Gottardi3107aa22022-01-27 16:39:37 +000053 /**
Richard Burtonef904972022-04-27 17:24:36 +010054 * @brief Post-processing class for Object Detection use case.
55 * Implements methods declared by BasePostProcess and anything else needed
56 * to populate result vector.
Isabella Gottardi3107aa22022-01-27 16:39:37 +000057 */
Richard Burtonef904972022-04-27 17:24:36 +010058 class DetectorPostProcess : public BasePostProcess {
Isabella Gottardi3107aa22022-01-27 16:39:37 +000059 public:
60 /**
Richard Burtonef904972022-04-27 17:24:36 +010061 * @brief Constructor.
62 * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0.
63 * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1.
64 * @param[out] results Vector of detected results.
65 * @param[in] inputImgRows Number of rows in the input image.
66 * @param[in] inputImgCols Number of columns in the input image.
67 * @param[in] threshold Post-processing threshold.
68 * @param[in] nms Non-maximum Suppression threshold.
69 * @param[in] numClasses Number of classes.
70 * @param[in] topN Top N for each class.
Isabella Gottardi3107aa22022-01-27 16:39:37 +000071 **/
Richard Burtonef904972022-04-27 17:24:36 +010072 explicit DetectorPostProcess(TfLiteTensor* outputTensor0,
73 TfLiteTensor* outputTensor1,
74 std::vector<object_detection::DetectionResult>& results,
75 int inputImgRows,
76 int inputImgCols,
77 float threshold = 0.5f,
78 float nms = 0.45f,
79 int numClasses = 1,
80 int topN = 0);
Isabella Gottardi3107aa22022-01-27 16:39:37 +000081
82 /**
Richard Burtonef904972022-04-27 17:24:36 +010083 * @brief Should perform YOLO post-processing of the result of inference then
84 * populate Detection result data for any later use.
85 * @return true if successful, false otherwise.
Isabella Gottardi3107aa22022-01-27 16:39:37 +000086 **/
Richard Burtonef904972022-04-27 17:24:36 +010087 bool DoPostProcess() override;
Isabella Gottardi3107aa22022-01-27 16:39:37 +000088
89 private:
Richard Burtonef904972022-04-27 17:24:36 +010090 TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */
91 TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */
92 std::vector<object_detection::DetectionResult>& m_results; /* Single inference results. */
93 int m_inputImgRows; /* Number of rows for model input. */
94 int m_inputImgCols; /* Number of cols for model input. */
95 float m_threshold; /* Post-processing threshold. */
96 float m_nms; /* NMS threshold. */
97 int m_numClasses; /* Number of classes. */
98 int m_topN; /* TopN. */
99 object_detection::Network m_net; /* YOLO network object. */
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000100
101 /**
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000102 * @brief Insert the given Detection in the list.
103 * @param[in] detections List of detections.
104 * @param[in] det Detection to be inserted.
105 **/
Richard Burtoned35a6f2022-02-14 11:55:35 +0000106 void InsertTopNDetections(std::forward_list<image::Detection>& detections, image::Detection& det);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000107
108 /**
109 * @brief Given a Network calculate the detection boxes.
110 * @param[in] net Network.
111 * @param[in] imageWidth Original image width.
112 * @param[in] imageHeight Original image height.
113 * @param[in] threshold Detections threshold.
114 * @param[out] detections Detection boxes.
115 **/
Richard Burtonef904972022-04-27 17:24:36 +0100116 void GetNetworkBoxes(object_detection::Network& net,
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000117 int imageWidth,
118 int imageHeight,
119 float threshold,
Richard Burtoned35a6f2022-02-14 11:55:35 +0000120 std::forward_list<image::Detection>& detections);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000121 };
122
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000123} /* namespace app */
124} /* namespace arm */
125
126#endif /* DETECTOR_POST_PROCESSING_HPP */