blob: 2f9f29c2232ad7879cf24e9c1503d619b26b2033 [file] [log] [blame]
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
keidav011b3e2ea2019-02-21 10:07:37 +00006#include "ParserFlatbuffersFixture.hpp"
7#include "ParserPrototxtFixture.hpp"
keidav01222c7532019-03-14 17:12:10 +00008#include "ParserHelper.hpp"
Sadik Armagana097d2a2021-11-24 15:47:28 +00009#include <GraphUtils.hpp>
keidav011b3e2ea2019-02-21 10:07:37 +000010
Jan Eilersbb446e52020-04-02 13:56:54 +010011#include <armnn/utility/PolymorphicDowncast.hpp>
Colm Donelanc42a9872022-02-02 16:35:09 +000012#include <armnnUtils/QuantizeHelper.hpp>
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010013
Sadik Armagan1625efc2021-06-10 18:24:34 +010014TEST_SUITE("TensorflowLiteParser_DetectionPostProcess")
15{
keidav011b3e2ea2019-02-21 10:07:37 +000016struct DetectionPostProcessFixture : ParserFlatbuffersFixture
17{
keidav01222c7532019-03-14 17:12:10 +000018 explicit DetectionPostProcessFixture(const std::string& custom_options)
keidav011b3e2ea2019-02-21 10:07:37 +000019 {
20 /*
21 The following values were used for the custom_options:
keidav0107d58c72019-02-26 11:57:39 +000022 use_regular_nms = true
keidav011b3e2ea2019-02-21 10:07:37 +000023 max_classes_per_detection = 1
keidav01222c7532019-03-14 17:12:10 +000024 detections_per_class = 1
keidav011b3e2ea2019-02-21 10:07:37 +000025 nms_score_threshold = 0.0
26 nms_iou_threshold = 0.5
27 max_detections = 3
28 max_detections = 3
29 num_classes = 2
30 h_scale = 5
31 w_scale = 5
32 x_scale = 10
33 y_scale = 10
34 */
35 m_JsonString = R"(
36 {
37 "version": 3,
38 "operator_codes": [{
39 "builtin_code": "CUSTOM",
40 "custom_code": "TFLite_Detection_PostProcess"
41 }],
42 "subgraphs": [{
43 "tensors": [{
44 "shape": [1, 6, 4],
45 "type": "UINT8",
46 "buffer": 0,
47 "name": "box_encodings",
48 "quantization": {
49 "min": [0.0],
50 "max": [255.0],
51 "scale": [1.0],
52 "zero_point": [ 1 ]
53 }
54 },
55 {
56 "shape": [1, 6, 3],
57 "type": "UINT8",
58 "buffer": 1,
59 "name": "scores",
60 "quantization": {
61 "min": [0.0],
62 "max": [255.0],
63 "scale": [0.01],
64 "zero_point": [0]
65 }
66 },
67 {
68 "shape": [6, 4],
69 "type": "UINT8",
70 "buffer": 2,
71 "name": "anchors",
72 "quantization": {
73 "min": [0.0],
74 "max": [255.0],
75 "scale": [0.5],
76 "zero_point": [0]
77 }
78 },
79 {
keidav011b3e2ea2019-02-21 10:07:37 +000080 "type": "FLOAT32",
81 "buffer": 3,
82 "name": "detection_boxes",
83 "quantization": {}
84 },
85 {
keidav011b3e2ea2019-02-21 10:07:37 +000086 "type": "FLOAT32",
87 "buffer": 4,
88 "name": "detection_classes",
89 "quantization": {}
90 },
91 {
keidav011b3e2ea2019-02-21 10:07:37 +000092 "type": "FLOAT32",
93 "buffer": 5,
94 "name": "detection_scores",
95 "quantization": {}
96 },
97 {
keidav011b3e2ea2019-02-21 10:07:37 +000098 "type": "FLOAT32",
99 "buffer": 6,
100 "name": "num_detections",
101 "quantization": {}
102 }
103 ],
104 "inputs": [0, 1, 2],
105 "outputs": [3, 4, 5, 6],
106 "operators": [{
107 "opcode_index": 0,
108 "inputs": [0, 1, 2],
109 "outputs": [3, 4, 5, 6],
110 "builtin_options_type": 0,
keidav01222c7532019-03-14 17:12:10 +0000111 "custom_options": [)" + custom_options + R"(],
keidav011b3e2ea2019-02-21 10:07:37 +0000112 "custom_options_format": "FLEXBUFFERS"
113 }]
114 }],
115 "buffers": [{},
116 {},
117 { "data": [ 1, 1, 2, 2,
118 1, 1, 2, 2,
119 1, 1, 2, 2,
120 1, 21, 2, 2,
121 1, 21, 2, 2,
122 1, 201, 2, 2]},
123 {},
124 {},
125 {},
126 {},
127 ]
128 }
129 )";
130 }
131};
132
keidav01222c7532019-03-14 17:12:10 +0000133struct ParseDetectionPostProcessCustomOptions : DetectionPostProcessFixture
134{
135private:
136 static armnn::DetectionPostProcessDescriptor GenerateDescriptor()
137 {
138 static armnn::DetectionPostProcessDescriptor descriptor;
139 descriptor.m_UseRegularNms = true;
140 descriptor.m_MaxDetections = 3u;
141 descriptor.m_MaxClassesPerDetection = 1u;
142 descriptor.m_DetectionsPerClass = 1u;
143 descriptor.m_NumClasses = 2u;
144 descriptor.m_NmsScoreThreshold = 0.0f;
145 descriptor.m_NmsIouThreshold = 0.5f;
146 descriptor.m_ScaleH = 5.0f;
147 descriptor.m_ScaleW = 5.0f;
148 descriptor.m_ScaleX = 10.0f;
149 descriptor.m_ScaleY = 10.0f;
150
151 return descriptor;
152 }
153
154public:
155 ParseDetectionPostProcessCustomOptions()
156 : DetectionPostProcessFixture(
157 GenerateDetectionPostProcessJsonString(GenerateDescriptor()))
158 {}
159};
160
Sadik Armagan1625efc2021-06-10 18:24:34 +0100161TEST_CASE_FIXTURE(ParseDetectionPostProcessCustomOptions, "ParseDetectionPostProcess")
keidav011b3e2ea2019-02-21 10:07:37 +0000162{
163 Setup();
164
165 // Inputs
166 using UnquantizedContainer = std::vector<float>;
167 UnquantizedContainer boxEncodings =
168 {
169 0.0f, 0.0f, 0.0f, 0.0f,
170 0.0f, 1.0f, 0.0f, 0.0f,
171 0.0f, -1.0f, 0.0f, 0.0f,
172 0.0f, 0.0f, 0.0f, 0.0f,
173 0.0f, 1.0f, 0.0f, 0.0f,
174 0.0f, 0.0f, 0.0f, 0.0f
175 };
176
177 UnquantizedContainer scores =
178 {
179 0.0f, 0.9f, 0.8f,
180 0.0f, 0.75f, 0.72f,
181 0.0f, 0.6f, 0.5f,
182 0.0f, 0.93f, 0.95f,
183 0.0f, 0.5f, 0.4f,
184 0.0f, 0.3f, 0.2f
185 };
186
187 // Outputs
188 UnquantizedContainer detectionBoxes =
189 {
190 0.0f, 10.0f, 1.0f, 11.0f,
191 0.0f, 10.0f, 1.0f, 11.0f,
192 0.0f, 0.0f, 0.0f, 0.0f
193 };
194
195 UnquantizedContainer detectionClasses = { 1.0f, 0.0f, 0.0f };
196 UnquantizedContainer detectionScores = { 0.95f, 0.93f, 0.0f };
197
198 UnquantizedContainer numDetections = { 2.0f };
199
200 // Quantize inputs and outputs
201 using QuantizedContainer = std::vector<uint8_t>;
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100202
203 QuantizedContainer quantBoxEncodings = armnnUtils::QuantizedVector<uint8_t>(boxEncodings, 1.00f, 1);
204 QuantizedContainer quantScores = armnnUtils::QuantizedVector<uint8_t>(scores, 0.01f, 0);
keidav011b3e2ea2019-02-21 10:07:37 +0000205
206 std::map<std::string, QuantizedContainer> input =
207 {
208 { "box_encodings", quantBoxEncodings },
209 { "scores", quantScores }
210 };
211
212 std::map<std::string, UnquantizedContainer> output =
213 {
214 { "detection_boxes", detectionBoxes},
215 { "detection_classes", detectionClasses},
216 { "detection_scores", detectionScores},
217 { "num_detections", numDetections}
218 };
219
Derek Lambertif90c56d2020-01-10 17:14:08 +0000220 RunTest<armnn::DataType::QAsymmU8, armnn::DataType::Float32>(0, input, output);
keidav011b3e2ea2019-02-21 10:07:37 +0000221}
222
Sadik Armagan1625efc2021-06-10 18:24:34 +0100223TEST_CASE_FIXTURE(ParseDetectionPostProcessCustomOptions, "DetectionPostProcessGraphStructureTest")
keidav011b3e2ea2019-02-21 10:07:37 +0000224{
225 /*
226 Inputs: box_encodings scores
227 \ /
228 DetectionPostProcess
229 / / \ \
230 / / \ \
231 Outputs: detection detection detection num_detections
232 boxes classes scores
233 */
234
235 ReadStringToBinary();
236
237 armnn::INetworkPtr network = m_Parser->CreateNetworkFromBinary(m_GraphBinary);
238
239 auto optimized = Optimize(*network, { armnn::Compute::CpuRef }, m_Runtime->GetDeviceSpec());
240
Francis Murtagh3d2b4b22021-02-15 18:23:17 +0000241 armnn::Graph& graph = GetGraphForTesting(optimized.get());
keidav011b3e2ea2019-02-21 10:07:37 +0000242
243 // Check the number of layers in the graph
Sadik Armagan1625efc2021-06-10 18:24:34 +0100244 CHECK((graph.GetNumInputs() == 2));
245 CHECK((graph.GetNumOutputs() == 4));
246 CHECK((graph.GetNumLayers() == 7));
keidav011b3e2ea2019-02-21 10:07:37 +0000247
248 // Input layers
249 armnn::Layer* boxEncodingLayer = GetFirstLayerWithName(graph, "box_encodings");
Sadik Armagan1625efc2021-06-10 18:24:34 +0100250 CHECK((boxEncodingLayer->GetType() == armnn::LayerType::Input));
251 CHECK(CheckNumberOfInputSlot(boxEncodingLayer, 0));
252 CHECK(CheckNumberOfOutputSlot(boxEncodingLayer, 1));
keidav011b3e2ea2019-02-21 10:07:37 +0000253
254 armnn::Layer* scoresLayer = GetFirstLayerWithName(graph, "scores");
Sadik Armagan1625efc2021-06-10 18:24:34 +0100255 CHECK((scoresLayer->GetType() == armnn::LayerType::Input));
256 CHECK(CheckNumberOfInputSlot(scoresLayer, 0));
257 CHECK(CheckNumberOfOutputSlot(scoresLayer, 1));
keidav011b3e2ea2019-02-21 10:07:37 +0000258
259 // DetectionPostProcess layer
260 armnn::Layer* detectionPostProcessLayer = GetFirstLayerWithName(graph, "DetectionPostProcess:0:0");
Sadik Armagan1625efc2021-06-10 18:24:34 +0100261 CHECK((detectionPostProcessLayer->GetType() == armnn::LayerType::DetectionPostProcess));
262 CHECK(CheckNumberOfInputSlot(detectionPostProcessLayer, 2));
263 CHECK(CheckNumberOfOutputSlot(detectionPostProcessLayer, 4));
keidav011b3e2ea2019-02-21 10:07:37 +0000264
265 // Output layers
266 armnn::Layer* detectionBoxesLayer = GetFirstLayerWithName(graph, "detection_boxes");
Sadik Armagan1625efc2021-06-10 18:24:34 +0100267 CHECK((detectionBoxesLayer->GetType() == armnn::LayerType::Output));
268 CHECK(CheckNumberOfInputSlot(detectionBoxesLayer, 1));
269 CHECK(CheckNumberOfOutputSlot(detectionBoxesLayer, 0));
keidav011b3e2ea2019-02-21 10:07:37 +0000270
271 armnn::Layer* detectionClassesLayer = GetFirstLayerWithName(graph, "detection_classes");
Sadik Armagan1625efc2021-06-10 18:24:34 +0100272 CHECK((detectionClassesLayer->GetType() == armnn::LayerType::Output));
273 CHECK(CheckNumberOfInputSlot(detectionClassesLayer, 1));
274 CHECK(CheckNumberOfOutputSlot(detectionClassesLayer, 0));
keidav011b3e2ea2019-02-21 10:07:37 +0000275
276 armnn::Layer* detectionScoresLayer = GetFirstLayerWithName(graph, "detection_scores");
Sadik Armagan1625efc2021-06-10 18:24:34 +0100277 CHECK((detectionScoresLayer->GetType() == armnn::LayerType::Output));
278 CHECK(CheckNumberOfInputSlot(detectionScoresLayer, 1));
279 CHECK(CheckNumberOfOutputSlot(detectionScoresLayer, 0));
keidav011b3e2ea2019-02-21 10:07:37 +0000280
281 armnn::Layer* numDetectionsLayer = GetFirstLayerWithName(graph, "num_detections");
Sadik Armagan1625efc2021-06-10 18:24:34 +0100282 CHECK((numDetectionsLayer->GetType() == armnn::LayerType::Output));
283 CHECK(CheckNumberOfInputSlot(numDetectionsLayer, 1));
284 CHECK(CheckNumberOfOutputSlot(numDetectionsLayer, 0));
keidav011b3e2ea2019-02-21 10:07:37 +0000285
286 // Check the connections
Derek Lambertif90c56d2020-01-10 17:14:08 +0000287 armnn::TensorInfo boxEncodingTensor(armnn::TensorShape({ 1, 6, 4 }), armnn::DataType::QAsymmU8, 1, 1);
288 armnn::TensorInfo scoresTensor(armnn::TensorShape({ 1, 6, 3 }), armnn::DataType::QAsymmU8,
keidav011b3e2ea2019-02-21 10:07:37 +0000289 0.00999999978f, 0);
290
291 armnn::TensorInfo detectionBoxesTensor(armnn::TensorShape({ 1, 3, 4 }), armnn::DataType::Float32, 0, 0);
292 armnn::TensorInfo detectionClassesTensor(armnn::TensorShape({ 1, 3 }), armnn::DataType::Float32, 0, 0);
293 armnn::TensorInfo detectionScoresTensor(armnn::TensorShape({ 1, 3 }), armnn::DataType::Float32, 0, 0);
294 armnn::TensorInfo numDetectionsTensor(armnn::TensorShape({ 1} ), armnn::DataType::Float32, 0, 0);
295
Sadik Armagan1625efc2021-06-10 18:24:34 +0100296 CHECK(IsConnected(boxEncodingLayer, detectionPostProcessLayer, 0, 0, boxEncodingTensor));
297 CHECK(IsConnected(scoresLayer, detectionPostProcessLayer, 0, 1, scoresTensor));
298 CHECK(IsConnected(detectionPostProcessLayer, detectionBoxesLayer, 0, 0, detectionBoxesTensor));
299 CHECK(IsConnected(detectionPostProcessLayer, detectionClassesLayer, 1, 0, detectionClassesTensor));
300 CHECK(IsConnected(detectionPostProcessLayer, detectionScoresLayer, 2, 0, detectionScoresTensor));
301 CHECK(IsConnected(detectionPostProcessLayer, numDetectionsLayer, 3, 0, numDetectionsTensor));
keidav011b3e2ea2019-02-21 10:07:37 +0000302}
303
Sadik Armagan1625efc2021-06-10 18:24:34 +0100304}