blob: fcebc5c4183a8db2593b354b0c53dda461f8213c [file] [log] [blame]
Georgios Pinitas28705162018-03-21 20:10:53 +00001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2018-2020 Arm Limited.
Georgios Pinitas28705162018-03-21 20:10:53 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrouf4643372019-11-29 16:17:13 +000024#ifndef ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H
25#define ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H
Georgios Pinitas28705162018-03-21 20:10:53 +000026
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010027#include "arm_compute/graph/Logger.h"
28#include "arm_compute/graph/Tensor.h"
29#include "arm_compute/graph/Types.h"
30#include "arm_compute/graph/nodes/Nodes.h"
Georgios Pinitas28705162018-03-21 20:10:53 +000031
32#include "arm_compute/core/Error.h"
Georgios Pinitascac13b12018-04-27 19:07:19 +010033#include "arm_compute/core/Helpers.h"
Georgios Pinitas28705162018-03-21 20:10:53 +000034#include "arm_compute/core/ITensorInfo.h"
35
36namespace arm_compute
37{
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010038namespace graph
Georgios Pinitas28705162018-03-21 20:10:53 +000039{
40namespace backends
41{
42namespace detail
43{
44/** Returns backing tensor info of a given tensor
45 *
46 * @param[in] tensor Tensor to extract the backing tensor from
47 *
48 * @return Backing tensor tensor info if present else nullptr
49 */
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010050inline arm_compute::ITensorInfo *get_backing_tensor_info(arm_compute::graph::Tensor *tensor)
Georgios Pinitas28705162018-03-21 20:10:53 +000051{
52 return ((tensor == nullptr) || (tensor->handle() == nullptr)) ? nullptr : tensor->handle()->tensor().info();
53}
54
Manuel Bottinid2048ce2018-10-23 17:00:42 +010055/** Validates a Bounding Box Transform layer node
56 *
57 * @tparam BoundingBoxTransformLayer Bounding Box Transform layer function type
58 *
59 * @param[in] node Node to validate
60 *
61 * @return Status
62 */
63template <typename BoundingBoxTransformLayer>
64Status validate_bounding_box_transform_layer(BoundingBoxTransformLayerNode &node)
65{
66 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating BoundingBoxTransformLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
67 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 2);
68 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
69
70 // Extract IO and info
71 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
72 arm_compute::ITensorInfo *deltas = get_backing_tensor_info(node.input(1));
73 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
74 const BoundingBoxTransformInfo bbox_info = node.info();
75
76 return BoundingBoxTransformLayer::validate(input, output, deltas, bbox_info);
77}
78
Georgios Pinitas087eaf62018-05-16 15:52:35 +010079/** Validates a Channel Shuffle layer node
80 *
81 * @tparam ChannelShuffleLayer Channel Shuffle layer function type
82 *
83 * @param[in] node Node to validate
84 *
85 * @return Status
86 */
87template <typename ChannelShuffleLayer>
88Status validate_channel_shuffle_layer(ChannelShuffleLayerNode &node)
89{
90 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ChannelShuffle node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
91 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
92 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
93
94 // Extract IO and info
95 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
96 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
97 const unsigned int num_groups = node.num_groups();
98
99 return ChannelShuffleLayer::validate(input, output, num_groups);
100}
101
Georgios Pinitas28705162018-03-21 20:10:53 +0000102/** Validates a Convolution layer node
103 *
104 * @tparam ConvolutionLayer Default Convolution layer function type
105 * @tparam DirectConvolutionLayer Direct Convolution layer function type
106 * @tparam GEMMConvolutionLayer GEMM Convolution layer function type
107 * @tparam WinogradConvolutionLayer Winograd Convolution layer function type
108 *
109 * @param[in] node Node to validate
110 *
111 * @return Status
112 */
113template <typename ConvolutionLayer, typename DirectConvolutionLayer, typename GEMMConvolutionLayer, typename WinogradConvolutionLayer>
114Status validate_convolution_layer(ConvolutionLayerNode &node)
115{
116 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
117 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
118 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
119
120 // Extract IO and info
Giorgio Arenabb54e4e2018-04-05 17:20:34 +0100121 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
122 arm_compute::ITensorInfo *weights = get_backing_tensor_info(node.input(1));
123 arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2));
124 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
125
126 if(is_data_type_quantized_asymmetric(input->data_type()))
127 {
128 biases->set_data_type(DataType::S32);
129 }
130
131 const PadStrideInfo conv_info = node.convolution_info();
132 const ConvolutionMethod conv_algorithm = node.convolution_method();
Georgios Pinitase2220552018-07-20 13:23:44 +0100133 const bool fast_math = node.fast_math_hint() == FastMathHint::Enabled;
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100134 const unsigned int num_groups = node.num_groups();
Georgios Pinitas28705162018-03-21 20:10:53 +0000135
136 // Validate function
137 Status status{};
138 switch(conv_algorithm)
139 {
Georgios Pinitase2220552018-07-20 13:23:44 +0100140 case ConvolutionMethod::Direct:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100141 ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups != 1, "DirectConvolutionLayer does not support grouping!");
Georgios Pinitas28705162018-03-21 20:10:53 +0000142 status = DirectConvolutionLayer::validate(input, weights, biases, output, conv_info);
143 break;
144 case ConvolutionMethod::GEMM:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100145 status = GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info,
146 WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), num_groups);
Georgios Pinitas28705162018-03-21 20:10:53 +0000147 break;
Georgios Pinitase2220552018-07-20 13:23:44 +0100148 case ConvolutionMethod::Winograd:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100149 ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups != 1, "WinogradConvolutionLayer does not support grouping!");
Georgios Pinitase2220552018-07-20 13:23:44 +0100150 status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info, ActivationLayerInfo(), fast_math);
Georgios Pinitas28705162018-03-21 20:10:53 +0000151 break;
Georgios Pinitase2220552018-07-20 13:23:44 +0100152 case ConvolutionMethod::Default:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100153 status = ConvolutionLayer::validate(input, weights, biases, output, conv_info,
154 WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), fast_math, num_groups);
Georgios Pinitas54d6fae2018-05-10 15:50:14 +0100155 break;
Georgios Pinitas28705162018-03-21 20:10:53 +0000156 default:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100157 ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported convolution method");
Georgios Pinitas28705162018-03-21 20:10:53 +0000158 }
159
160 return status;
161}
162
163/** Validates a Depthwise Convolution layer node
164 *
165 * @tparam DepthwiseConvolutionLayer Default Depthwise Convolution layer type
Georgios Pinitas28705162018-03-21 20:10:53 +0000166 *
167 * @param[in] node Node to validate
168 *
169 * @return Status
170 */
Manuel Bottini05069f02019-09-26 17:18:26 +0100171template <typename DepthwiseConvolutionLayer>
Georgios Pinitas28705162018-03-21 20:10:53 +0000172Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
173{
174 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
175 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
176 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
177
178 // Extract IO and info
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100179 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
180 arm_compute::ITensorInfo *weights = detail::get_backing_tensor_info(node.input(1));
181 arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2));
182 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
Georgios Pinitas28705162018-03-21 20:10:53 +0000183
Georgios Pinitas05045c12018-12-07 18:31:47 +0000184 const PadStrideInfo conv_info = node.convolution_info();
185 const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method();
186 const int depth_multiplier = node.depth_multiplier();
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100187
Georgios Pinitas28705162018-03-21 20:10:53 +0000188 // Validate function
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100189 Status status{};
190 switch(dwc_algorithm)
Georgios Pinitas28705162018-03-21 20:10:53 +0000191 {
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100192 case DepthwiseConvolutionMethod::Default:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100193 case DepthwiseConvolutionMethod::Optimized3x3:
Manuel Bottini05069f02019-09-26 17:18:26 +0100194 status = DepthwiseConvolutionLayer::validate(input, weights, biases, output, conv_info, depth_multiplier);
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100195 break;
196 default:
197 ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported depthwise convolution method");
Georgios Pinitas28705162018-03-21 20:10:53 +0000198 }
199
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100200 return status;
Georgios Pinitas28705162018-03-21 20:10:53 +0000201}
thecha010a05e6a2020-08-28 18:40:38 +0100202/** Validates a depth to space layer node
203 *
204 * @tparam DequantizationLayer Dequantize layer type
205 *
206 * @param[in] node Node to validate
207 *
208 * @return Status
209 */
210template <typename DepthToSpaceLayer>
211Status validate_depth_to_space_layer(DepthToSpaceLayerNode &node)
212{
213 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DetectionOutputLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
214 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
215 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
216
217 // Extract IO and info
218 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
219 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
220
221 return DepthToSpaceLayer::validate(input, output, node.block_shape());
222}
Isabella Gottardicd4e9ab2019-11-05 17:50:27 +0000223/** Validates a dequantize layer node
224 *
225 * @tparam DequantizationLayer Dequantize layer type
226 *
227 * @param[in] node Node to validate
228 *
229 * @return Status
230 */
231template <typename DequantizationLayer>
232Status validate_dequantization_layer(DequantizationLayerNode &node)
233{
234 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DetectionOutputLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
235 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
236 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
Isabella Gottardi0ae5de92019-03-14 10:32:11 +0000237
Isabella Gottardicd4e9ab2019-11-05 17:50:27 +0000238 // Extract IO and info
239 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
240 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
241
242 return DequantizationLayer::validate(input, output);
243}
Isabella Gottardi7234ed82018-11-27 08:51:10 +0000244/** Validates a detection output layer node
245 *
246 * @tparam DetectionOutputLayer DetectionOutput layer type
247 *
248 * @param[in] node Node to validate
249 *
250 * @return Status
251 */
252template <typename DetectionOutputLayer>
253Status validate_detection_output_layer(DetectionOutputLayerNode &node)
254{
255 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DetectionOutputLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
256 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
257 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
258
259 // Extract IO and info
260 arm_compute::ITensorInfo *input0 = get_backing_tensor_info(node.input(0));
261 arm_compute::ITensorInfo *input1 = get_backing_tensor_info(node.input(1));
262 arm_compute::ITensorInfo *input2 = get_backing_tensor_info(node.input(2));
263 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
264 const DetectionOutputLayerInfo detect_info = node.detection_output_info();
265
266 return DetectionOutputLayer::validate(input0, input1, input2, output, detect_info);
267}
Isabella Gottardia7acb3c2019-01-08 13:48:44 +0000268/** Validates a detection post process layer node
269 *
270 * @tparam DetectionPostProcessLayer DetectionOutput layer type
271 *
272 * @param[in] node Node to validate
273 *
274 * @return Status
275 */
276template <typename DetectionPostProcessLayer>
277Status validate_detection_post_process_layer(DetectionPostProcessLayerNode &node)
278{
279 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DetectionPostProcessLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
280 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
281 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 4);
282
283 // Extract IO and info
284 arm_compute::ITensorInfo *input0 = get_backing_tensor_info(node.input(0));
285 arm_compute::ITensorInfo *input1 = get_backing_tensor_info(node.input(1));
286 arm_compute::ITensorInfo *input2 = get_backing_tensor_info(node.input(2));
287 arm_compute::ITensorInfo *output0 = get_backing_tensor_info(node.output(0));
288 arm_compute::ITensorInfo *output1 = get_backing_tensor_info(node.output(1));
289 arm_compute::ITensorInfo *output2 = get_backing_tensor_info(node.output(2));
290 arm_compute::ITensorInfo *output3 = get_backing_tensor_info(node.output(3));
291 const DetectionPostProcessLayerInfo detect_info = node.detection_post_process_info();
292
293 return DetectionPostProcessLayer::validate(input0, input1, input2, output0, output1, output2, output3, detect_info);
294}
Georgios Pinitas57c48242018-08-02 13:41:49 +0100295
Manuel Bottini5209be52019-02-13 16:34:56 +0000296/** Validates a Generate Proposals layer node
297 *
298 * @tparam GenerateProposalsLayer Generate Proposals layer type
299 *
300 * @param[in] node Node to validate
301 *
302 * @return Status
303 */
304template <typename GenerateProposalsLayer>
305Status validate_generate_proposals_layer(GenerateProposalsLayerNode &node)
306{
307 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating GenerateProposalsLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
308 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
309 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 3);
310
311 // Extract IO and info
312 arm_compute::ITensorInfo *scores = detail::get_backing_tensor_info(node.input(0));
313 arm_compute::ITensorInfo *deltas = detail::get_backing_tensor_info(node.input(1));
314 arm_compute::ITensorInfo *anchors = detail::get_backing_tensor_info(node.input(2));
315 arm_compute::ITensorInfo *proposals = get_backing_tensor_info(node.output(0));
316 arm_compute::ITensorInfo *scores_out = get_backing_tensor_info(node.output(1));
317 arm_compute::ITensorInfo *num_valid_proposals = get_backing_tensor_info(node.output(2));
318 const GenerateProposalsInfo info = node.info();
319
320 return GenerateProposalsLayer::validate(scores, deltas, anchors, proposals, scores_out, num_valid_proposals, info);
321}
322
Michele Di Giorgio555d1102018-09-12 13:51:59 +0100323/** Validates a NormalizePlanarYUV layer node
324 *
325 * @tparam NormalizePlanarYUVLayer layer type
326 *
327 * @param[in] node Node to validate
328 *
329 * @return Status
330 */
331template <typename NormalizePlanarYUVLayer>
332Status validate_normalize_planar_yuv_layer(NormalizePlanarYUVLayerNode &node)
333{
334 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating NormalizePlanarYUVLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
335 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
336 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
337
338 // Extract IO and info
339 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
340 arm_compute::ITensorInfo *mean = detail::get_backing_tensor_info(node.input(1));
341 arm_compute::ITensorInfo *std = detail::get_backing_tensor_info(node.input(2));
342 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
343
344 // Validate function
345 return NormalizePlanarYUVLayer::validate(input, output, mean, std);
346}
Michele Di Giorgio4bb17332018-09-26 13:56:51 +0100347
348/** Validates a pad layer node
349 *
350 * @tparam PadLayer Pad layer type
351 *
352 * @param[in] node Node to validate
353 *
354 * @return Status
355 */
356template <typename PadLayer>
357Status validate_pad_layer(PadLayerNode &node)
358{
359 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PadLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
360 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
361 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
362
363 // Extract IO and info
364 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
365 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
366 const PaddingList &padding = node.padding();
367
368 return PadLayer::validate(input, output, padding);
369}
370
Georgios Pinitas57c48242018-08-02 13:41:49 +0100371/** Validates a permute layer node
372 *
373 * @tparam PermuteLayer Permute layer type
374 *
375 * @param[in] node Node to validate
376 *
377 * @return Status
378 */
379template <typename PermuteLayer>
380Status validate_permute_layer(PermuteLayerNode &node)
381{
382 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PermuteLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
383 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
384 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
385
386 // Extract IO and info
387 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
388 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
389 const PermutationVector &perm = node.permutation_vector();
390
391 return PermuteLayer::validate(input, output, perm);
392}
Georgios Pinitasf8c47492020-02-04 17:39:59 +0000393
394/** Validates a PRelu layer node
395 *
396 * @tparam PReluLayer PRelu layer type
397 *
398 * @param[in] node Node to validate
399 *
400 * @return Status
401 */
402template <typename PReluLayer>
403Status validate_prelu_layer(PReluLayerNode &node)
404{
405 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PRelu node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
406 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 2);
407 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
408
409 // Extract IO and info
410 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
411 arm_compute::ITensorInfo *alpha = get_backing_tensor_info(node.input(1));
412 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
413
414 return PReluLayer::validate(input, alpha, output);
415}
416
Pablo Tello32521432018-11-15 14:43:10 +0000417/** Validates a priorbox layer node
418 *
419 * @tparam PriorBoxLayer PriorBox layer type
420 *
421 * @param[in] node Node to validate
422 *
423 * @return Status
424 */
425template <typename PriorBoxLayer>
426Status validate_priorbox_layer(PriorBoxLayerNode &node)
427{
428 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PriorBoxLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
429 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 2);
430 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
431
432 // Extract IO and info
433 arm_compute::ITensorInfo *input0 = get_backing_tensor_info(node.input(0));
434 arm_compute::ITensorInfo *input1 = get_backing_tensor_info(node.input(1));
435 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
436 const PriorBoxLayerInfo prior_info = node.priorbox_info();
437
438 return PriorBoxLayer::validate(input0, input1, output, prior_info);
439}
Gian Marco Iodice23e24792018-09-07 15:32:14 +0100440
Isabella Gottardi3db1ba92019-05-17 12:35:20 +0100441/** Validates a Quantization layer node
442 *
443 * @tparam QuantizationLayer Quantization layer type
444 *
445 * @param[in] node Node to validate
446 *
447 * @return Status
448 */
449template <typename QuantizationLayer>
450Status validate_quantization_layer(QuantizationLayerNode &node)
451{
452 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating QuantizationLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
453 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
454 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
455
456 // Extract input and output
457 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
458 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
459
460 // Validate function
461 return QuantizationLayer::validate(input, output);
462}
463
Gian Marco Iodice23e24792018-09-07 15:32:14 +0100464/** Validates a Reorg layer node
465 *
466 * @tparam ReorgLayer Reorg layer type
467 *
468 * @param[in] node Node to validate
469 *
470 * @return Status
471 */
472template <typename ReorgLayer>
473Status validate_reorg_layer(ReorgLayerNode &node)
474{
475 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ReorgLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
476 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
477 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
478
479 // Extract input and output
480 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
481 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
482
483 // Validate function
484 return ReorgLayer::validate(input, output, node.stride());
485}
Michalis Spyrou96f67692018-09-13 11:39:28 +0100486
Isabella Gottardi0ae5de92019-03-14 10:32:11 +0000487/** Validates a Reshape layer node
488 *
489 * @tparam ReshapeLayer Reshape layer type
490 *
491 * @param[in] node Node to validate
492 *
493 * @return Status
494 */
495template <typename ReshapeLayer>
496Status validate_reshape_layer(ReshapeLayerNode &node)
497{
498 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ReshapeLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
499 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
500 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
501
502 // Extract input and output
503 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
504 arm_compute::ITensorInfo *output = detail::get_backing_tensor_info(node.output(0));
505
506 // Validate function
507 return ReshapeLayer::validate(input, output);
508}
509
Manuel Bottini3f9d4d72018-10-19 14:04:42 +0100510/** Validates a ROI Align layer node
511 *
512 * @tparam ROIAlignLayer ROIAlign layer type
513 *
514 * @param[in] node Node to validate
515 *
516 * @return Status
517 */
518template <typename ROIAlignLayer>
519Status validate_roi_align_layer(ROIAlignLayerNode &node)
520{
521 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ROIAlignLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
522 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 2);
523 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
524
525 // Extract input and output
526 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
527 arm_compute::ITensorInfo *rois = detail::get_backing_tensor_info(node.input(1));
528 arm_compute::ITensorInfo *output = detail::get_backing_tensor_info(node.output(0));
529 const ROIPoolingLayerInfo &pool_info = node.pooling_info();
530
531 // Validate function
532 return ROIAlignLayer::validate(input, rois, output, pool_info);
533}
534
Michele Di Giorgioc30b6682018-09-12 17:44:08 +0100535/** Validates a Slice layer node
536 *
537 * @tparam SliceLayer Slice layer function type
538 *
539 * @param[in] node Node to validate
540 *
541 * @return Status
542 */
543template <typename SliceLayer>
544Status validate_slice_layer(SliceLayerNode &node)
545{
546 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating Slice node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
547 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
548 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
549
550 // Extract IO and info
551 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
552 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
553 const Coordinates starts = node.starts();
554 const Coordinates ends = node.ends();
555
556 return SliceLayer::validate(input, output, starts, ends);
557}
558
thecha012bfadd92020-08-12 17:25:51 +0100559/** Validates a Strided Slice layer node
560 *
561 * @tparam StridedSliceLayer Strided Slice layer function type
562 *
563 * @param[in] node Node to validate
564 *
565 * @return Status
566 */
567template <typename StridedSliceLayer>
568Status validate_strided_slice_layer(StridedSliceLayerNode &node)
569{
570 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating StridedSlice node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
571 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
572 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
573
574 // Extract IO and info
575 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
576 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
577 const Coordinates starts = node.starts();
578 const Coordinates ends = node.ends();
579 const BiStrides strides = node.strides();
580 const StridedSliceLayerInfo info = node.strided_slice_info();
581
582 return StridedSliceLayer::validate(input, output, starts, ends, strides, info.begin_mask(), info.end_mask(), info.shrink_axis_mask());
583}
584
Michalis Spyrou4e1c3f32018-09-20 17:14:03 +0100585/** Validates a Upsample layer node
586 *
587 * @tparam UpsampleLayer Upsample layer type
588 *
589 * @param[in] node Node to validate
590 *
591 * @return Status
592 */
593template <typename UpsampleLayer>
594Status validate_upsample_layer(UpsampleLayerNode &node)
595{
596 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating UpsampleLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
597 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
598 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
599
600 // Extract input and output
601 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
602 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
603
604 // Validate function
605 return UpsampleLayer::validate(input, output, node.info(), node.upsampling_policy());
606}
Michalis Spyrou96f67692018-09-13 11:39:28 +0100607/** Validates a YOLO layer node
608 *
609 * @tparam YOLOLayer YOLO layer type
610 *
611 * @param[in] node Node to validate
612 *
613 * @return Status
614 */
615template <typename YOLOLayer>
616Status validate_yolo_layer(YOLOLayerNode &node)
617{
618 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating YOLOLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
619 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
620 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
621
622 // Extract input and output
623 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
624 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
625
626 // Validate function
627 return YOLOLayer::validate(input, output, node.activation_info(), node.num_classes());
628}
Sheri Zhang16dddd22020-05-27 15:03:48 +0100629/** Validates a element-wise layer node
630 *
631 * @param[in] node Node to validate
632 *
633 * @return Status
634 */
635template <typename EltwiseLayerFunctions>
636Status validate_eltwise_Layer(EltwiseLayerNode &node)
637{
638 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating EltwiseLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
639 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 2);
640 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
641
642 // Extract input and output
643 const arm_compute::ITensorInfo *input1 = detail::get_backing_tensor_info(node.input(0));
644 const arm_compute::ITensorInfo *input2 = detail::get_backing_tensor_info(node.input(1));
645 const arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
646 const EltwiseOperation eltwise_op = node.eltwise_operation();
647 const ConvertPolicy convert_policy = node.convert_policy();
648 const RoundingPolicy round_policy = node.rounding_policy();
649 const ActivationLayerInfo act_info = node.fused_activation();
650 const QuantizationInfo quant_info = node.output_quant_info();
Sheri Zhang16dddd22020-05-27 15:03:48 +0100651
652 // Validate function
653 if(eltwise_op == EltwiseOperation::Add)
654 {
655 return EltwiseLayerFunctions::ArithmeticAddition::validate(input1, input2, output, convert_policy, act_info);
656 }
657 else if(eltwise_op == EltwiseOperation::Sub)
658 {
659 return EltwiseLayerFunctions::ArithmeticSubtraction::validate(input1, input2, output, convert_policy, act_info);
660 }
661 else if(eltwise_op == EltwiseOperation::Mul)
662 {
Manuel Bottini7e725cf2020-08-12 16:05:16 +0100663 return EltwiseLayerFunctions::PixelWiseMultiplication::validate(input1, input2, output, 1.0f, convert_policy, round_policy, act_info);
Sheri Zhang16dddd22020-05-27 15:03:48 +0100664 }
thecha01f8e35842020-07-28 17:28:17 +0100665 else if(eltwise_op == EltwiseOperation::Max)
666 {
667 return EltwiseLayerFunctions::ElementwiseMax::validate(input1, input2, output, act_info);
668 }
Sheri Zhang16dddd22020-05-27 15:03:48 +0100669 else
670 {
671 ARM_COMPUTE_ERROR("Unsupported element-wise operation!");
672 }
673 return Status{};
674}
675/** Validates a unary element-wise layer node
676 *
677 * @param[in] node Node to validate
678 *
679 * @return Status
680 */
681template <typename UnaryEltwiseLayerFunctions>
682Status validate_unary_eltwise_layer(UnaryEltwiseLayerNode &node)
683{
684 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating EltwiseLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
685 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
686 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
687
688 // Extract input and output
689 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
690 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
691 const UnaryEltwiseOperation eltwise_op = node.eltwise_descriptor().op;
692
693 // Validate function
694 if(eltwise_op == UnaryEltwiseOperation::Exp)
695 {
696 return UnaryEltwiseLayerFunctions::ExpLayer::validate(input, output);
697 }
698 else
699 {
700 ARM_COMPUTE_ERROR("Unsupported unary element-wise operation!");
701 }
702
703 return Status{};
704}
Georgios Pinitas28705162018-03-21 20:10:53 +0000705} // namespace detail
706} // namespace backends
Georgios Pinitasd9eb2752018-04-03 13:44:29 +0100707} // namespace graph
Georgios Pinitas28705162018-03-21 20:10:53 +0000708} // namespace arm_compute
709
Michalis Spyrouf4643372019-11-29 16:17:13 +0000710#endif /* ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H */