blob: 66de7ad90407435ce86a696397f75b42aa250773 [file] [log] [blame]
Georgios Pinitas28705162018-03-21 20:10:53 +00001/*
Giuseppe Rossinibb365de2019-02-15 10:24:47 +00002 * Copyright (c) 2018-2019 ARM Limited.
Georgios Pinitas28705162018-03-21 20:10:53 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrouf4643372019-11-29 16:17:13 +000024#ifndef ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H
25#define ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H
Georgios Pinitas28705162018-03-21 20:10:53 +000026
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010027#include "arm_compute/graph/Logger.h"
28#include "arm_compute/graph/Tensor.h"
29#include "arm_compute/graph/Types.h"
30#include "arm_compute/graph/nodes/Nodes.h"
Georgios Pinitas28705162018-03-21 20:10:53 +000031
32#include "arm_compute/core/Error.h"
Georgios Pinitascac13b12018-04-27 19:07:19 +010033#include "arm_compute/core/Helpers.h"
Georgios Pinitas28705162018-03-21 20:10:53 +000034#include "arm_compute/core/ITensorInfo.h"
35
36namespace arm_compute
37{
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010038namespace graph
Georgios Pinitas28705162018-03-21 20:10:53 +000039{
40namespace backends
41{
42namespace detail
43{
44/** Returns backing tensor info of a given tensor
45 *
46 * @param[in] tensor Tensor to extract the backing tensor from
47 *
48 * @return Backing tensor tensor info if present else nullptr
49 */
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010050inline arm_compute::ITensorInfo *get_backing_tensor_info(arm_compute::graph::Tensor *tensor)
Georgios Pinitas28705162018-03-21 20:10:53 +000051{
52 return ((tensor == nullptr) || (tensor->handle() == nullptr)) ? nullptr : tensor->handle()->tensor().info();
53}
54
Manuel Bottinid2048ce2018-10-23 17:00:42 +010055/** Validates a Bounding Box Transform layer node
56 *
57 * @tparam BoundingBoxTransformLayer Bounding Box Transform layer function type
58 *
59 * @param[in] node Node to validate
60 *
61 * @return Status
62 */
63template <typename BoundingBoxTransformLayer>
64Status validate_bounding_box_transform_layer(BoundingBoxTransformLayerNode &node)
65{
66 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating BoundingBoxTransformLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
67 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 2);
68 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
69
70 // Extract IO and info
71 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
72 arm_compute::ITensorInfo *deltas = get_backing_tensor_info(node.input(1));
73 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
74 const BoundingBoxTransformInfo bbox_info = node.info();
75
76 return BoundingBoxTransformLayer::validate(input, output, deltas, bbox_info);
77}
78
Georgios Pinitas087eaf62018-05-16 15:52:35 +010079/** Validates a Channel Shuffle layer node
80 *
81 * @tparam ChannelShuffleLayer Channel Shuffle layer function type
82 *
83 * @param[in] node Node to validate
84 *
85 * @return Status
86 */
87template <typename ChannelShuffleLayer>
88Status validate_channel_shuffle_layer(ChannelShuffleLayerNode &node)
89{
90 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ChannelShuffle node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
91 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
92 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
93
94 // Extract IO and info
95 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
96 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
97 const unsigned int num_groups = node.num_groups();
98
99 return ChannelShuffleLayer::validate(input, output, num_groups);
100}
101
Georgios Pinitas28705162018-03-21 20:10:53 +0000102/** Validates a Convolution layer node
103 *
104 * @tparam ConvolutionLayer Default Convolution layer function type
105 * @tparam DirectConvolutionLayer Direct Convolution layer function type
106 * @tparam GEMMConvolutionLayer GEMM Convolution layer function type
107 * @tparam WinogradConvolutionLayer Winograd Convolution layer function type
108 *
109 * @param[in] node Node to validate
110 *
111 * @return Status
112 */
113template <typename ConvolutionLayer, typename DirectConvolutionLayer, typename GEMMConvolutionLayer, typename WinogradConvolutionLayer>
114Status validate_convolution_layer(ConvolutionLayerNode &node)
115{
116 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
117 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
118 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
119
120 // Extract IO and info
Giorgio Arenabb54e4e2018-04-05 17:20:34 +0100121 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
122 arm_compute::ITensorInfo *weights = get_backing_tensor_info(node.input(1));
123 arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2));
124 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
125
126 if(is_data_type_quantized_asymmetric(input->data_type()))
127 {
128 biases->set_data_type(DataType::S32);
129 }
130
131 const PadStrideInfo conv_info = node.convolution_info();
132 const ConvolutionMethod conv_algorithm = node.convolution_method();
Georgios Pinitase2220552018-07-20 13:23:44 +0100133 const bool fast_math = node.fast_math_hint() == FastMathHint::Enabled;
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100134 const unsigned int num_groups = node.num_groups();
Georgios Pinitas28705162018-03-21 20:10:53 +0000135
136 // Validate function
137 Status status{};
138 switch(conv_algorithm)
139 {
Georgios Pinitase2220552018-07-20 13:23:44 +0100140 case ConvolutionMethod::Direct:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100141 ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups != 1, "DirectConvolutionLayer does not support grouping!");
Georgios Pinitas28705162018-03-21 20:10:53 +0000142 status = DirectConvolutionLayer::validate(input, weights, biases, output, conv_info);
143 break;
144 case ConvolutionMethod::GEMM:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100145 status = GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info,
146 WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), num_groups);
Georgios Pinitas28705162018-03-21 20:10:53 +0000147 break;
Georgios Pinitase2220552018-07-20 13:23:44 +0100148 case ConvolutionMethod::Winograd:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100149 ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups != 1, "WinogradConvolutionLayer does not support grouping!");
Georgios Pinitase2220552018-07-20 13:23:44 +0100150 status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info, ActivationLayerInfo(), fast_math);
Georgios Pinitas28705162018-03-21 20:10:53 +0000151 break;
Georgios Pinitase2220552018-07-20 13:23:44 +0100152 case ConvolutionMethod::Default:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100153 status = ConvolutionLayer::validate(input, weights, biases, output, conv_info,
154 WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), fast_math, num_groups);
Georgios Pinitas54d6fae2018-05-10 15:50:14 +0100155 break;
Georgios Pinitas28705162018-03-21 20:10:53 +0000156 default:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100157 ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported convolution method");
Georgios Pinitas28705162018-03-21 20:10:53 +0000158 }
159
160 return status;
161}
162
163/** Validates a Depthwise Convolution layer node
164 *
165 * @tparam DepthwiseConvolutionLayer Default Depthwise Convolution layer type
Georgios Pinitas28705162018-03-21 20:10:53 +0000166 *
167 * @param[in] node Node to validate
168 *
169 * @return Status
170 */
Manuel Bottini05069f02019-09-26 17:18:26 +0100171template <typename DepthwiseConvolutionLayer>
Georgios Pinitas28705162018-03-21 20:10:53 +0000172Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
173{
174 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
175 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
176 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
177
178 // Extract IO and info
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100179 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
180 arm_compute::ITensorInfo *weights = detail::get_backing_tensor_info(node.input(1));
181 arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2));
182 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
Georgios Pinitas28705162018-03-21 20:10:53 +0000183
Georgios Pinitas05045c12018-12-07 18:31:47 +0000184 const PadStrideInfo conv_info = node.convolution_info();
185 const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method();
186 const int depth_multiplier = node.depth_multiplier();
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100187
Georgios Pinitas28705162018-03-21 20:10:53 +0000188 // Validate function
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100189 Status status{};
190 switch(dwc_algorithm)
Georgios Pinitas28705162018-03-21 20:10:53 +0000191 {
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100192 case DepthwiseConvolutionMethod::Default:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100193 case DepthwiseConvolutionMethod::Optimized3x3:
Manuel Bottini05069f02019-09-26 17:18:26 +0100194 status = DepthwiseConvolutionLayer::validate(input, weights, biases, output, conv_info, depth_multiplier);
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100195 break;
196 default:
197 ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported depthwise convolution method");
Georgios Pinitas28705162018-03-21 20:10:53 +0000198 }
199
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100200 return status;
Georgios Pinitas28705162018-03-21 20:10:53 +0000201}
Isabella Gottardicd4e9ab2019-11-05 17:50:27 +0000202/** Validates a dequantize layer node
203 *
204 * @tparam DequantizationLayer Dequantize layer type
205 *
206 * @param[in] node Node to validate
207 *
208 * @return Status
209 */
210template <typename DequantizationLayer>
211Status validate_dequantization_layer(DequantizationLayerNode &node)
212{
213 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DetectionOutputLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
214 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
215 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
Isabella Gottardi0ae5de92019-03-14 10:32:11 +0000216
Isabella Gottardicd4e9ab2019-11-05 17:50:27 +0000217 // Extract IO and info
218 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
219 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
220
221 return DequantizationLayer::validate(input, output);
222}
Isabella Gottardi7234ed82018-11-27 08:51:10 +0000223/** Validates a detection output layer node
224 *
225 * @tparam DetectionOutputLayer DetectionOutput layer type
226 *
227 * @param[in] node Node to validate
228 *
229 * @return Status
230 */
231template <typename DetectionOutputLayer>
232Status validate_detection_output_layer(DetectionOutputLayerNode &node)
233{
234 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DetectionOutputLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
235 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
236 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
237
238 // Extract IO and info
239 arm_compute::ITensorInfo *input0 = get_backing_tensor_info(node.input(0));
240 arm_compute::ITensorInfo *input1 = get_backing_tensor_info(node.input(1));
241 arm_compute::ITensorInfo *input2 = get_backing_tensor_info(node.input(2));
242 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
243 const DetectionOutputLayerInfo detect_info = node.detection_output_info();
244
245 return DetectionOutputLayer::validate(input0, input1, input2, output, detect_info);
246}
Isabella Gottardia7acb3c2019-01-08 13:48:44 +0000247/** Validates a detection post process layer node
248 *
249 * @tparam DetectionPostProcessLayer DetectionOutput layer type
250 *
251 * @param[in] node Node to validate
252 *
253 * @return Status
254 */
255template <typename DetectionPostProcessLayer>
256Status validate_detection_post_process_layer(DetectionPostProcessLayerNode &node)
257{
258 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DetectionPostProcessLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
259 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
260 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 4);
261
262 // Extract IO and info
263 arm_compute::ITensorInfo *input0 = get_backing_tensor_info(node.input(0));
264 arm_compute::ITensorInfo *input1 = get_backing_tensor_info(node.input(1));
265 arm_compute::ITensorInfo *input2 = get_backing_tensor_info(node.input(2));
266 arm_compute::ITensorInfo *output0 = get_backing_tensor_info(node.output(0));
267 arm_compute::ITensorInfo *output1 = get_backing_tensor_info(node.output(1));
268 arm_compute::ITensorInfo *output2 = get_backing_tensor_info(node.output(2));
269 arm_compute::ITensorInfo *output3 = get_backing_tensor_info(node.output(3));
270 const DetectionPostProcessLayerInfo detect_info = node.detection_post_process_info();
271
272 return DetectionPostProcessLayer::validate(input0, input1, input2, output0, output1, output2, output3, detect_info);
273}
Georgios Pinitas57c48242018-08-02 13:41:49 +0100274
Manuel Bottini5209be52019-02-13 16:34:56 +0000275/** Validates a Generate Proposals layer node
276 *
277 * @tparam GenerateProposalsLayer Generate Proposals layer type
278 *
279 * @param[in] node Node to validate
280 *
281 * @return Status
282 */
283template <typename GenerateProposalsLayer>
284Status validate_generate_proposals_layer(GenerateProposalsLayerNode &node)
285{
286 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating GenerateProposalsLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
287 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
288 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 3);
289
290 // Extract IO and info
291 arm_compute::ITensorInfo *scores = detail::get_backing_tensor_info(node.input(0));
292 arm_compute::ITensorInfo *deltas = detail::get_backing_tensor_info(node.input(1));
293 arm_compute::ITensorInfo *anchors = detail::get_backing_tensor_info(node.input(2));
294 arm_compute::ITensorInfo *proposals = get_backing_tensor_info(node.output(0));
295 arm_compute::ITensorInfo *scores_out = get_backing_tensor_info(node.output(1));
296 arm_compute::ITensorInfo *num_valid_proposals = get_backing_tensor_info(node.output(2));
297 const GenerateProposalsInfo info = node.info();
298
299 return GenerateProposalsLayer::validate(scores, deltas, anchors, proposals, scores_out, num_valid_proposals, info);
300}
301
Michele Di Giorgio555d1102018-09-12 13:51:59 +0100302/** Validates a NormalizePlanarYUV layer node
303 *
304 * @tparam NormalizePlanarYUVLayer layer type
305 *
306 * @param[in] node Node to validate
307 *
308 * @return Status
309 */
310template <typename NormalizePlanarYUVLayer>
311Status validate_normalize_planar_yuv_layer(NormalizePlanarYUVLayerNode &node)
312{
313 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating NormalizePlanarYUVLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
314 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
315 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
316
317 // Extract IO and info
318 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
319 arm_compute::ITensorInfo *mean = detail::get_backing_tensor_info(node.input(1));
320 arm_compute::ITensorInfo *std = detail::get_backing_tensor_info(node.input(2));
321 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
322
323 // Validate function
324 return NormalizePlanarYUVLayer::validate(input, output, mean, std);
325}
Michele Di Giorgio4bb17332018-09-26 13:56:51 +0100326
327/** Validates a pad layer node
328 *
329 * @tparam PadLayer Pad layer type
330 *
331 * @param[in] node Node to validate
332 *
333 * @return Status
334 */
335template <typename PadLayer>
336Status validate_pad_layer(PadLayerNode &node)
337{
338 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PadLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
339 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
340 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
341
342 // Extract IO and info
343 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
344 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
345 const PaddingList &padding = node.padding();
346
347 return PadLayer::validate(input, output, padding);
348}
349
Georgios Pinitas57c48242018-08-02 13:41:49 +0100350/** Validates a permute layer node
351 *
352 * @tparam PermuteLayer Permute layer type
353 *
354 * @param[in] node Node to validate
355 *
356 * @return Status
357 */
358template <typename PermuteLayer>
359Status validate_permute_layer(PermuteLayerNode &node)
360{
361 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PermuteLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
362 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
363 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
364
365 // Extract IO and info
366 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
367 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
368 const PermutationVector &perm = node.permutation_vector();
369
370 return PermuteLayer::validate(input, output, perm);
371}
Pablo Tello32521432018-11-15 14:43:10 +0000372/** Validates a priorbox layer node
373 *
374 * @tparam PriorBoxLayer PriorBox layer type
375 *
376 * @param[in] node Node to validate
377 *
378 * @return Status
379 */
380template <typename PriorBoxLayer>
381Status validate_priorbox_layer(PriorBoxLayerNode &node)
382{
383 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PriorBoxLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
384 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 2);
385 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
386
387 // Extract IO and info
388 arm_compute::ITensorInfo *input0 = get_backing_tensor_info(node.input(0));
389 arm_compute::ITensorInfo *input1 = get_backing_tensor_info(node.input(1));
390 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
391 const PriorBoxLayerInfo prior_info = node.priorbox_info();
392
393 return PriorBoxLayer::validate(input0, input1, output, prior_info);
394}
Gian Marco Iodice23e24792018-09-07 15:32:14 +0100395
Isabella Gottardi3db1ba92019-05-17 12:35:20 +0100396/** Validates a Quantization layer node
397 *
398 * @tparam QuantizationLayer Quantization layer type
399 *
400 * @param[in] node Node to validate
401 *
402 * @return Status
403 */
404template <typename QuantizationLayer>
405Status validate_quantization_layer(QuantizationLayerNode &node)
406{
407 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating QuantizationLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
408 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
409 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
410
411 // Extract input and output
412 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
413 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
414
415 // Validate function
416 return QuantizationLayer::validate(input, output);
417}
418
Gian Marco Iodice23e24792018-09-07 15:32:14 +0100419/** Validates a Reorg layer node
420 *
421 * @tparam ReorgLayer Reorg layer type
422 *
423 * @param[in] node Node to validate
424 *
425 * @return Status
426 */
427template <typename ReorgLayer>
428Status validate_reorg_layer(ReorgLayerNode &node)
429{
430 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ReorgLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
431 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
432 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
433
434 // Extract input and output
435 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
436 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
437
438 // Validate function
439 return ReorgLayer::validate(input, output, node.stride());
440}
Michalis Spyrou96f67692018-09-13 11:39:28 +0100441
Isabella Gottardi0ae5de92019-03-14 10:32:11 +0000442/** Validates a Reshape layer node
443 *
444 * @tparam ReshapeLayer Reshape layer type
445 *
446 * @param[in] node Node to validate
447 *
448 * @return Status
449 */
450template <typename ReshapeLayer>
451Status validate_reshape_layer(ReshapeLayerNode &node)
452{
453 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ReshapeLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
454 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
455 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
456
457 // Extract input and output
458 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
459 arm_compute::ITensorInfo *output = detail::get_backing_tensor_info(node.output(0));
460
461 // Validate function
462 return ReshapeLayer::validate(input, output);
463}
464
Manuel Bottini3f9d4d72018-10-19 14:04:42 +0100465/** Validates a ROI Align layer node
466 *
467 * @tparam ROIAlignLayer ROIAlign layer type
468 *
469 * @param[in] node Node to validate
470 *
471 * @return Status
472 */
473template <typename ROIAlignLayer>
474Status validate_roi_align_layer(ROIAlignLayerNode &node)
475{
476 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ROIAlignLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
477 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 2);
478 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
479
480 // Extract input and output
481 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
482 arm_compute::ITensorInfo *rois = detail::get_backing_tensor_info(node.input(1));
483 arm_compute::ITensorInfo *output = detail::get_backing_tensor_info(node.output(0));
484 const ROIPoolingLayerInfo &pool_info = node.pooling_info();
485
486 // Validate function
487 return ROIAlignLayer::validate(input, rois, output, pool_info);
488}
489
Michele Di Giorgioc30b6682018-09-12 17:44:08 +0100490/** Validates a Slice layer node
491 *
492 * @tparam SliceLayer Slice layer function type
493 *
494 * @param[in] node Node to validate
495 *
496 * @return Status
497 */
498template <typename SliceLayer>
499Status validate_slice_layer(SliceLayerNode &node)
500{
501 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating Slice node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
502 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
503 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
504
505 // Extract IO and info
506 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
507 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
508 const Coordinates starts = node.starts();
509 const Coordinates ends = node.ends();
510
511 return SliceLayer::validate(input, output, starts, ends);
512}
513
Michalis Spyrou4e1c3f32018-09-20 17:14:03 +0100514/** Validates a Upsample layer node
515 *
516 * @tparam UpsampleLayer Upsample layer type
517 *
518 * @param[in] node Node to validate
519 *
520 * @return Status
521 */
522template <typename UpsampleLayer>
523Status validate_upsample_layer(UpsampleLayerNode &node)
524{
525 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating UpsampleLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
526 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
527 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
528
529 // Extract input and output
530 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
531 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
532
533 // Validate function
534 return UpsampleLayer::validate(input, output, node.info(), node.upsampling_policy());
535}
Michalis Spyrou96f67692018-09-13 11:39:28 +0100536/** Validates a YOLO layer node
537 *
538 * @tparam YOLOLayer YOLO layer type
539 *
540 * @param[in] node Node to validate
541 *
542 * @return Status
543 */
544template <typename YOLOLayer>
545Status validate_yolo_layer(YOLOLayerNode &node)
546{
547 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating YOLOLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
548 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
549 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
550
551 // Extract input and output
552 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
553 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
554
555 // Validate function
556 return YOLOLayer::validate(input, output, node.activation_info(), node.num_classes());
557}
Georgios Pinitas28705162018-03-21 20:10:53 +0000558} // namespace detail
559} // namespace backends
Georgios Pinitasd9eb2752018-04-03 13:44:29 +0100560} // namespace graph
Georgios Pinitas28705162018-03-21 20:10:53 +0000561} // namespace arm_compute
562
Michalis Spyrouf4643372019-11-29 16:17:13 +0000563#endif /* ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H */