/*
 * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#ifndef DETECTOR_POST_PROCESSING_HPP
18#define DETECTOR_POST_PROCESSING_HPP
19
#include "ImageUtils.hpp"
#include "DetectionResult.hpp"
#include "YoloFastestModel.hpp"
#include "BaseProcessing.hpp"

#include <cstddef>
#include <cstdint>
#include <forward_list>
#include <vector>
26
27namespace arm {
28namespace app {
29namespace object_detection {
30
/**
 * @brief Parameters consumed by DetectorPostProcess.
 *        Every member has a default initializer, so a default-constructed
 *        instance never holds indeterminate values (anchor pointers start
 *        as nullptr and must be set before use).
 */
struct PostProcessParams {
    int inputImgRows{};             /* Model input image rows. */
    int inputImgCols{};             /* Model input image columns. */
    int originalImageSize{};        /* NOTE(review): size of the source image; exact meaning defined by the implementation — confirm. */
    const float* anchor1{nullptr};  /* Anchor table, presumably for the branch fed by output tensor 0 — confirm. */
    const float* anchor2{nullptr};  /* Anchor table, presumably for the branch fed by output tensor 1 — confirm. */
    float threshold = 0.5f;         /* Detection threshold. */
    float nms = 0.45f;              /* Presumably the IoU threshold for non-maximum suppression — confirm. */
    int numClasses = 1;             /* Number of object classes. */
    int topN = 0;                   /* Presumably the maximum number of detections kept — confirm semantics of 0. */
};
42
/**
 * @brief Describes one output branch of the YOLO network.
 *        All members are value-initialized so a default-constructed Branch
 *        holds null pointers and zero sizes rather than garbage.
 */
struct Branch {
    int resolution{};              /* Presumably the grid size of this branch — confirm. */
    int numBox{};                  /* Presumably boxes (anchors) per grid cell — confirm. */
    const float* anchor{nullptr};  /* Anchor table used by this branch. */
    int8_t* modelOutput{nullptr};  /* int8 model output data for this branch. */
    float scale{};                 /* Quantization scale applied to modelOutput. */
    int zeroPoint{};               /* Quantization zero point applied to modelOutput. */
    size_t size{};                 /* Size of modelOutput (presumably bytes — confirm). */
};
52
53 struct Network {
54 int inputWidth;
55 int inputHeight;
56 int numClasses;
57 std::vector<Branch> branches;
58 int topN;
59 };
60
Richard Burtonef904972022-04-27 17:24:36 +010061} /* namespace object_detection */
62
Isabella Gottardi3107aa22022-01-27 16:39:37 +000063 /**
Richard Burtonef904972022-04-27 17:24:36 +010064 * @brief Post-processing class for Object Detection use case.
65 * Implements methods declared by BasePostProcess and anything else needed
66 * to populate result vector.
Isabella Gottardi3107aa22022-01-27 16:39:37 +000067 */
Richard Burtonef904972022-04-27 17:24:36 +010068 class DetectorPostProcess : public BasePostProcess {
Isabella Gottardi3107aa22022-01-27 16:39:37 +000069 public:
70 /**
Richard Burtonef904972022-04-27 17:24:36 +010071 * @brief Constructor.
Richard Burton6f6df092022-05-17 12:52:50 +010072 * @param[in] outputTensor0 Pointer to the TFLite Micro output Tensor at index 0.
73 * @param[in] outputTensor1 Pointer to the TFLite Micro output Tensor at index 1.
74 * @param[out] results Vector of detected results.
75 * @param[in] postProcessParams Struct of various parameters used in post-processing.
Isabella Gottardi3107aa22022-01-27 16:39:37 +000076 **/
Richard Burtonef904972022-04-27 17:24:36 +010077 explicit DetectorPostProcess(TfLiteTensor* outputTensor0,
78 TfLiteTensor* outputTensor1,
79 std::vector<object_detection::DetectionResult>& results,
Richard Burton6f6df092022-05-17 12:52:50 +010080 const object_detection::PostProcessParams& postProcessParams);
Isabella Gottardi3107aa22022-01-27 16:39:37 +000081
82 /**
Richard Burtonef904972022-04-27 17:24:36 +010083 * @brief Should perform YOLO post-processing of the result of inference then
84 * populate Detection result data for any later use.
85 * @return true if successful, false otherwise.
Isabella Gottardi3107aa22022-01-27 16:39:37 +000086 **/
Richard Burtonef904972022-04-27 17:24:36 +010087 bool DoPostProcess() override;
Isabella Gottardi3107aa22022-01-27 16:39:37 +000088
89 private:
Richard Burton6f6df092022-05-17 12:52:50 +010090 TfLiteTensor* m_outputTensor0; /* Output tensor index 0 */
91 TfLiteTensor* m_outputTensor1; /* Output tensor index 1 */
92 std::vector<object_detection::DetectionResult>& m_results; /* Single inference results. */
93 const object_detection::PostProcessParams& m_postProcessParams; /* Post processing param struct. */
94 object_detection::Network m_net; /* YOLO network object. */
Isabella Gottardi3107aa22022-01-27 16:39:37 +000095
96 /**
Isabella Gottardi3107aa22022-01-27 16:39:37 +000097 * @brief Insert the given Detection in the list.
98 * @param[in] detections List of detections.
99 * @param[in] det Detection to be inserted.
100 **/
Richard Burtoned35a6f2022-02-14 11:55:35 +0000101 void InsertTopNDetections(std::forward_list<image::Detection>& detections, image::Detection& det);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000102
103 /**
104 * @brief Given a Network calculate the detection boxes.
105 * @param[in] net Network.
106 * @param[in] imageWidth Original image width.
107 * @param[in] imageHeight Original image height.
108 * @param[in] threshold Detections threshold.
109 * @param[out] detections Detection boxes.
110 **/
Richard Burtonef904972022-04-27 17:24:36 +0100111 void GetNetworkBoxes(object_detection::Network& net,
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000112 int imageWidth,
113 int imageHeight,
114 float threshold,
Richard Burtoned35a6f2022-02-14 11:55:35 +0000115 std::forward_list<image::Detection>& detections);
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000116 };
117
Isabella Gottardi3107aa22022-01-27 16:39:37 +0000118} /* namespace app */
119} /* namespace arm */
120
121#endif /* DETECTOR_POST_PROCESSING_HPP */