blob: e7ef7402f2a9cc5fd66da63ce60fd0ff2f42054f [file] [log] [blame]
keidav011b3e2ea2019-02-21 10:07:37 +00001//
2// Copyright © 2017 Arm Ltd. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#include "../TfLiteParser.hpp"
keidav011b3e2ea2019-02-21 10:07:37 +00007#include "ParserFlatbuffersFixture.hpp"
8#include "ParserPrototxtFixture.hpp"
keidav01222c7532019-03-14 17:12:10 +00009#include "ParserHelper.hpp"
Jan Eilersbb446e52020-04-02 13:56:54 +010010#include "test/GraphUtils.hpp"
keidav011b3e2ea2019-02-21 10:07:37 +000011
Jan Eilersbb446e52020-04-02 13:56:54 +010012#include <armnn/utility/PolymorphicDowncast.hpp>
Aron Virginas-Tar48623a02019-10-22 10:00:28 +010013#include <QuantizeHelper.hpp>
14
Sadik Armagan1625efc2021-06-10 18:24:34 +010015TEST_SUITE("TensorflowLiteParser_DetectionPostProcess")
16{
keidav011b3e2ea2019-02-21 10:07:37 +000017struct DetectionPostProcessFixture : ParserFlatbuffersFixture
18{
keidav01222c7532019-03-14 17:12:10 +000019 explicit DetectionPostProcessFixture(const std::string& custom_options)
keidav011b3e2ea2019-02-21 10:07:37 +000020 {
21 /*
22 The following values were used for the custom_options:
keidav0107d58c72019-02-26 11:57:39 +000023 use_regular_nms = true
keidav011b3e2ea2019-02-21 10:07:37 +000024 max_classes_per_detection = 1
keidav01222c7532019-03-14 17:12:10 +000025 detections_per_class = 1
keidav011b3e2ea2019-02-21 10:07:37 +000026 nms_score_threshold = 0.0
27 nms_iou_threshold = 0.5
28 max_detections = 3
29 max_detections = 3
30 num_classes = 2
31 h_scale = 5
32 w_scale = 5
33 x_scale = 10
34 y_scale = 10
35 */
36 m_JsonString = R"(
37 {
38 "version": 3,
39 "operator_codes": [{
40 "builtin_code": "CUSTOM",
41 "custom_code": "TFLite_Detection_PostProcess"
42 }],
43 "subgraphs": [{
44 "tensors": [{
45 "shape": [1, 6, 4],
46 "type": "UINT8",
47 "buffer": 0,
48 "name": "box_encodings",
49 "quantization": {
50 "min": [0.0],
51 "max": [255.0],
52 "scale": [1.0],
53 "zero_point": [ 1 ]
54 }
55 },
56 {
57 "shape": [1, 6, 3],
58 "type": "UINT8",
59 "buffer": 1,
60 "name": "scores",
61 "quantization": {
62 "min": [0.0],
63 "max": [255.0],
64 "scale": [0.01],
65 "zero_point": [0]
66 }
67 },
68 {
69 "shape": [6, 4],
70 "type": "UINT8",
71 "buffer": 2,
72 "name": "anchors",
73 "quantization": {
74 "min": [0.0],
75 "max": [255.0],
76 "scale": [0.5],
77 "zero_point": [0]
78 }
79 },
80 {
keidav011b3e2ea2019-02-21 10:07:37 +000081 "type": "FLOAT32",
82 "buffer": 3,
83 "name": "detection_boxes",
84 "quantization": {}
85 },
86 {
keidav011b3e2ea2019-02-21 10:07:37 +000087 "type": "FLOAT32",
88 "buffer": 4,
89 "name": "detection_classes",
90 "quantization": {}
91 },
92 {
keidav011b3e2ea2019-02-21 10:07:37 +000093 "type": "FLOAT32",
94 "buffer": 5,
95 "name": "detection_scores",
96 "quantization": {}
97 },
98 {
keidav011b3e2ea2019-02-21 10:07:37 +000099 "type": "FLOAT32",
100 "buffer": 6,
101 "name": "num_detections",
102 "quantization": {}
103 }
104 ],
105 "inputs": [0, 1, 2],
106 "outputs": [3, 4, 5, 6],
107 "operators": [{
108 "opcode_index": 0,
109 "inputs": [0, 1, 2],
110 "outputs": [3, 4, 5, 6],
111 "builtin_options_type": 0,
keidav01222c7532019-03-14 17:12:10 +0000112 "custom_options": [)" + custom_options + R"(],
keidav011b3e2ea2019-02-21 10:07:37 +0000113 "custom_options_format": "FLEXBUFFERS"
114 }]
115 }],
116 "buffers": [{},
117 {},
118 { "data": [ 1, 1, 2, 2,
119 1, 1, 2, 2,
120 1, 1, 2, 2,
121 1, 21, 2, 2,
122 1, 21, 2, 2,
123 1, 201, 2, 2]},
124 {},
125 {},
126 {},
127 {},
128 ]
129 }
130 )";
131 }
132};
133
keidav01222c7532019-03-14 17:12:10 +0000134struct ParseDetectionPostProcessCustomOptions : DetectionPostProcessFixture
135{
136private:
137 static armnn::DetectionPostProcessDescriptor GenerateDescriptor()
138 {
139 static armnn::DetectionPostProcessDescriptor descriptor;
140 descriptor.m_UseRegularNms = true;
141 descriptor.m_MaxDetections = 3u;
142 descriptor.m_MaxClassesPerDetection = 1u;
143 descriptor.m_DetectionsPerClass = 1u;
144 descriptor.m_NumClasses = 2u;
145 descriptor.m_NmsScoreThreshold = 0.0f;
146 descriptor.m_NmsIouThreshold = 0.5f;
147 descriptor.m_ScaleH = 5.0f;
148 descriptor.m_ScaleW = 5.0f;
149 descriptor.m_ScaleX = 10.0f;
150 descriptor.m_ScaleY = 10.0f;
151
152 return descriptor;
153 }
154
155public:
156 ParseDetectionPostProcessCustomOptions()
157 : DetectionPostProcessFixture(
158 GenerateDetectionPostProcessJsonString(GenerateDescriptor()))
159 {}
160};
161
Sadik Armagan1625efc2021-06-10 18:24:34 +0100162TEST_CASE_FIXTURE(ParseDetectionPostProcessCustomOptions, "ParseDetectionPostProcess")
keidav011b3e2ea2019-02-21 10:07:37 +0000163{
164 Setup();
165
166 // Inputs
167 using UnquantizedContainer = std::vector<float>;
168 UnquantizedContainer boxEncodings =
169 {
170 0.0f, 0.0f, 0.0f, 0.0f,
171 0.0f, 1.0f, 0.0f, 0.0f,
172 0.0f, -1.0f, 0.0f, 0.0f,
173 0.0f, 0.0f, 0.0f, 0.0f,
174 0.0f, 1.0f, 0.0f, 0.0f,
175 0.0f, 0.0f, 0.0f, 0.0f
176 };
177
178 UnquantizedContainer scores =
179 {
180 0.0f, 0.9f, 0.8f,
181 0.0f, 0.75f, 0.72f,
182 0.0f, 0.6f, 0.5f,
183 0.0f, 0.93f, 0.95f,
184 0.0f, 0.5f, 0.4f,
185 0.0f, 0.3f, 0.2f
186 };
187
188 // Outputs
189 UnquantizedContainer detectionBoxes =
190 {
191 0.0f, 10.0f, 1.0f, 11.0f,
192 0.0f, 10.0f, 1.0f, 11.0f,
193 0.0f, 0.0f, 0.0f, 0.0f
194 };
195
196 UnquantizedContainer detectionClasses = { 1.0f, 0.0f, 0.0f };
197 UnquantizedContainer detectionScores = { 0.95f, 0.93f, 0.0f };
198
199 UnquantizedContainer numDetections = { 2.0f };
200
201 // Quantize inputs and outputs
202 using QuantizedContainer = std::vector<uint8_t>;
Aron Virginas-Tar48623a02019-10-22 10:00:28 +0100203
204 QuantizedContainer quantBoxEncodings = armnnUtils::QuantizedVector<uint8_t>(boxEncodings, 1.00f, 1);
205 QuantizedContainer quantScores = armnnUtils::QuantizedVector<uint8_t>(scores, 0.01f, 0);
keidav011b3e2ea2019-02-21 10:07:37 +0000206
207 std::map<std::string, QuantizedContainer> input =
208 {
209 { "box_encodings", quantBoxEncodings },
210 { "scores", quantScores }
211 };
212
213 std::map<std::string, UnquantizedContainer> output =
214 {
215 { "detection_boxes", detectionBoxes},
216 { "detection_classes", detectionClasses},
217 { "detection_scores", detectionScores},
218 { "num_detections", numDetections}
219 };
220
Derek Lambertif90c56d2020-01-10 17:14:08 +0000221 RunTest<armnn::DataType::QAsymmU8, armnn::DataType::Float32>(0, input, output);
keidav011b3e2ea2019-02-21 10:07:37 +0000222}
223
Sadik Armagan1625efc2021-06-10 18:24:34 +0100224TEST_CASE_FIXTURE(ParseDetectionPostProcessCustomOptions, "DetectionPostProcessGraphStructureTest")
keidav011b3e2ea2019-02-21 10:07:37 +0000225{
226 /*
227 Inputs: box_encodings scores
228 \ /
229 DetectionPostProcess
230 / / \ \
231 / / \ \
232 Outputs: detection detection detection num_detections
233 boxes classes scores
234 */
235
236 ReadStringToBinary();
237
238 armnn::INetworkPtr network = m_Parser->CreateNetworkFromBinary(m_GraphBinary);
239
240 auto optimized = Optimize(*network, { armnn::Compute::CpuRef }, m_Runtime->GetDeviceSpec());
241
Francis Murtagh3d2b4b22021-02-15 18:23:17 +0000242 armnn::Graph& graph = GetGraphForTesting(optimized.get());
keidav011b3e2ea2019-02-21 10:07:37 +0000243
244 // Check the number of layers in the graph
Sadik Armagan1625efc2021-06-10 18:24:34 +0100245 CHECK((graph.GetNumInputs() == 2));
246 CHECK((graph.GetNumOutputs() == 4));
247 CHECK((graph.GetNumLayers() == 7));
keidav011b3e2ea2019-02-21 10:07:37 +0000248
249 // Input layers
250 armnn::Layer* boxEncodingLayer = GetFirstLayerWithName(graph, "box_encodings");
Sadik Armagan1625efc2021-06-10 18:24:34 +0100251 CHECK((boxEncodingLayer->GetType() == armnn::LayerType::Input));
252 CHECK(CheckNumberOfInputSlot(boxEncodingLayer, 0));
253 CHECK(CheckNumberOfOutputSlot(boxEncodingLayer, 1));
keidav011b3e2ea2019-02-21 10:07:37 +0000254
255 armnn::Layer* scoresLayer = GetFirstLayerWithName(graph, "scores");
Sadik Armagan1625efc2021-06-10 18:24:34 +0100256 CHECK((scoresLayer->GetType() == armnn::LayerType::Input));
257 CHECK(CheckNumberOfInputSlot(scoresLayer, 0));
258 CHECK(CheckNumberOfOutputSlot(scoresLayer, 1));
keidav011b3e2ea2019-02-21 10:07:37 +0000259
260 // DetectionPostProcess layer
261 armnn::Layer* detectionPostProcessLayer = GetFirstLayerWithName(graph, "DetectionPostProcess:0:0");
Sadik Armagan1625efc2021-06-10 18:24:34 +0100262 CHECK((detectionPostProcessLayer->GetType() == armnn::LayerType::DetectionPostProcess));
263 CHECK(CheckNumberOfInputSlot(detectionPostProcessLayer, 2));
264 CHECK(CheckNumberOfOutputSlot(detectionPostProcessLayer, 4));
keidav011b3e2ea2019-02-21 10:07:37 +0000265
266 // Output layers
267 armnn::Layer* detectionBoxesLayer = GetFirstLayerWithName(graph, "detection_boxes");
Sadik Armagan1625efc2021-06-10 18:24:34 +0100268 CHECK((detectionBoxesLayer->GetType() == armnn::LayerType::Output));
269 CHECK(CheckNumberOfInputSlot(detectionBoxesLayer, 1));
270 CHECK(CheckNumberOfOutputSlot(detectionBoxesLayer, 0));
keidav011b3e2ea2019-02-21 10:07:37 +0000271
272 armnn::Layer* detectionClassesLayer = GetFirstLayerWithName(graph, "detection_classes");
Sadik Armagan1625efc2021-06-10 18:24:34 +0100273 CHECK((detectionClassesLayer->GetType() == armnn::LayerType::Output));
274 CHECK(CheckNumberOfInputSlot(detectionClassesLayer, 1));
275 CHECK(CheckNumberOfOutputSlot(detectionClassesLayer, 0));
keidav011b3e2ea2019-02-21 10:07:37 +0000276
277 armnn::Layer* detectionScoresLayer = GetFirstLayerWithName(graph, "detection_scores");
Sadik Armagan1625efc2021-06-10 18:24:34 +0100278 CHECK((detectionScoresLayer->GetType() == armnn::LayerType::Output));
279 CHECK(CheckNumberOfInputSlot(detectionScoresLayer, 1));
280 CHECK(CheckNumberOfOutputSlot(detectionScoresLayer, 0));
keidav011b3e2ea2019-02-21 10:07:37 +0000281
282 armnn::Layer* numDetectionsLayer = GetFirstLayerWithName(graph, "num_detections");
Sadik Armagan1625efc2021-06-10 18:24:34 +0100283 CHECK((numDetectionsLayer->GetType() == armnn::LayerType::Output));
284 CHECK(CheckNumberOfInputSlot(numDetectionsLayer, 1));
285 CHECK(CheckNumberOfOutputSlot(numDetectionsLayer, 0));
keidav011b3e2ea2019-02-21 10:07:37 +0000286
287 // Check the connections
Derek Lambertif90c56d2020-01-10 17:14:08 +0000288 armnn::TensorInfo boxEncodingTensor(armnn::TensorShape({ 1, 6, 4 }), armnn::DataType::QAsymmU8, 1, 1);
289 armnn::TensorInfo scoresTensor(armnn::TensorShape({ 1, 6, 3 }), armnn::DataType::QAsymmU8,
keidav011b3e2ea2019-02-21 10:07:37 +0000290 0.00999999978f, 0);
291
292 armnn::TensorInfo detectionBoxesTensor(armnn::TensorShape({ 1, 3, 4 }), armnn::DataType::Float32, 0, 0);
293 armnn::TensorInfo detectionClassesTensor(armnn::TensorShape({ 1, 3 }), armnn::DataType::Float32, 0, 0);
294 armnn::TensorInfo detectionScoresTensor(armnn::TensorShape({ 1, 3 }), armnn::DataType::Float32, 0, 0);
295 armnn::TensorInfo numDetectionsTensor(armnn::TensorShape({ 1} ), armnn::DataType::Float32, 0, 0);
296
Sadik Armagan1625efc2021-06-10 18:24:34 +0100297 CHECK(IsConnected(boxEncodingLayer, detectionPostProcessLayer, 0, 0, boxEncodingTensor));
298 CHECK(IsConnected(scoresLayer, detectionPostProcessLayer, 0, 1, scoresTensor));
299 CHECK(IsConnected(detectionPostProcessLayer, detectionBoxesLayer, 0, 0, detectionBoxesTensor));
300 CHECK(IsConnected(detectionPostProcessLayer, detectionClassesLayer, 1, 0, detectionClassesTensor));
301 CHECK(IsConnected(detectionPostProcessLayer, detectionScoresLayer, 2, 0, detectionScoresTensor));
302 CHECK(IsConnected(detectionPostProcessLayer, numDetectionsLayer, 3, 0, numDetectionsTensor));
keidav011b3e2ea2019-02-21 10:07:37 +0000303}
304
Sadik Armagan1625efc2021-06-10 18:24:34 +0100305}