blob: 2dc349174dfa6eacbf9610f059bced1dd91a4287 [file] [log] [blame]
Georgios Pinitas28705162018-03-21 20:10:53 +00001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010024#ifndef __ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H__
25#define __ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H__
Georgios Pinitas28705162018-03-21 20:10:53 +000026
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010027#include "arm_compute/graph/Logger.h"
28#include "arm_compute/graph/Tensor.h"
29#include "arm_compute/graph/Types.h"
30#include "arm_compute/graph/nodes/Nodes.h"
Georgios Pinitas28705162018-03-21 20:10:53 +000031
32#include "arm_compute/core/Error.h"
Georgios Pinitascac13b12018-04-27 19:07:19 +010033#include "arm_compute/core/Helpers.h"
Georgios Pinitas28705162018-03-21 20:10:53 +000034#include "arm_compute/core/ITensorInfo.h"
35
36namespace arm_compute
37{
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010038namespace graph
Georgios Pinitas28705162018-03-21 20:10:53 +000039{
40namespace backends
41{
42namespace detail
43{
44/** Returns backing tensor info of a given tensor
45 *
46 * @param[in] tensor Tensor to extract the backing tensor from
47 *
48 * @return Backing tensor tensor info if present else nullptr
49 */
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010050inline arm_compute::ITensorInfo *get_backing_tensor_info(arm_compute::graph::Tensor *tensor)
Georgios Pinitas28705162018-03-21 20:10:53 +000051{
52 return ((tensor == nullptr) || (tensor->handle() == nullptr)) ? nullptr : tensor->handle()->tensor().info();
53}
54
Georgios Pinitas087eaf62018-05-16 15:52:35 +010055/** Validates a Channel Shuffle layer node
56 *
57 * @tparam ChannelShuffleLayer Channel Shuffle layer function type
58 *
59 * @param[in] node Node to validate
60 *
61 * @return Status
62 */
63template <typename ChannelShuffleLayer>
64Status validate_channel_shuffle_layer(ChannelShuffleLayerNode &node)
65{
66 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ChannelShuffle node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
67 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
68 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
69
70 // Extract IO and info
71 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
72 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
73 const unsigned int num_groups = node.num_groups();
74
75 return ChannelShuffleLayer::validate(input, output, num_groups);
76}
77
Georgios Pinitas28705162018-03-21 20:10:53 +000078/** Validates a Convolution layer node
79 *
80 * @tparam ConvolutionLayer Default Convolution layer function type
81 * @tparam DirectConvolutionLayer Direct Convolution layer function type
82 * @tparam GEMMConvolutionLayer GEMM Convolution layer function type
83 * @tparam WinogradConvolutionLayer Winograd Convolution layer function type
84 *
85 * @param[in] node Node to validate
86 *
87 * @return Status
88 */
89template <typename ConvolutionLayer, typename DirectConvolutionLayer, typename GEMMConvolutionLayer, typename WinogradConvolutionLayer>
90Status validate_convolution_layer(ConvolutionLayerNode &node)
91{
92 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
93 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
94 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
95
96 // Extract IO and info
Giorgio Arenabb54e4e2018-04-05 17:20:34 +010097 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
98 arm_compute::ITensorInfo *weights = get_backing_tensor_info(node.input(1));
99 arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2));
100 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
101
102 if(is_data_type_quantized_asymmetric(input->data_type()))
103 {
104 biases->set_data_type(DataType::S32);
105 }
106
107 const PadStrideInfo conv_info = node.convolution_info();
108 const ConvolutionMethod conv_algorithm = node.convolution_method();
Georgios Pinitase2220552018-07-20 13:23:44 +0100109 const bool fast_math = node.fast_math_hint() == FastMathHint::Enabled;
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100110 const unsigned int num_groups = node.num_groups();
Georgios Pinitas28705162018-03-21 20:10:53 +0000111
112 // Validate function
113 Status status{};
114 switch(conv_algorithm)
115 {
Georgios Pinitase2220552018-07-20 13:23:44 +0100116 case ConvolutionMethod::Direct:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100117 ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups != 1, "DirectConvolutionLayer does not support grouping!");
Georgios Pinitas28705162018-03-21 20:10:53 +0000118 status = DirectConvolutionLayer::validate(input, weights, biases, output, conv_info);
119 break;
120 case ConvolutionMethod::GEMM:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100121 status = GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info,
122 WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), num_groups);
Georgios Pinitas28705162018-03-21 20:10:53 +0000123 break;
Georgios Pinitase2220552018-07-20 13:23:44 +0100124 case ConvolutionMethod::Winograd:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100125 ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups != 1, "WinogradConvolutionLayer does not support grouping!");
Georgios Pinitase2220552018-07-20 13:23:44 +0100126 status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info, ActivationLayerInfo(), fast_math);
Georgios Pinitas28705162018-03-21 20:10:53 +0000127 break;
Georgios Pinitase2220552018-07-20 13:23:44 +0100128 case ConvolutionMethod::Default:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100129 status = ConvolutionLayer::validate(input, weights, biases, output, conv_info,
130 WeightsInfo(), Size2D(1, 1), ActivationLayerInfo(), fast_math, num_groups);
Georgios Pinitas54d6fae2018-05-10 15:50:14 +0100131 break;
Georgios Pinitas28705162018-03-21 20:10:53 +0000132 default:
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100133 ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported convolution method");
Georgios Pinitas28705162018-03-21 20:10:53 +0000134 }
135
136 return status;
137}
138
139/** Validates a Depthwise Convolution layer node
140 *
141 * @tparam DepthwiseConvolutionLayer Default Depthwise Convolution layer type
142 * @tparam DepthwiseConvolutionLayer3x3 Optimized 3x3 Depthwise Convolution layer type
143 *
144 * @param[in] node Node to validate
145 *
146 * @return Status
147 */
148template <typename DepthwiseConvolutionLayer, typename DepthwiseConvolutionLayer3x3>
149Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
150{
151 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
152 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
153 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
154
155 // Extract IO and info
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100156 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
157 arm_compute::ITensorInfo *weights = detail::get_backing_tensor_info(node.input(1));
158 arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2));
159 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
Georgios Pinitas28705162018-03-21 20:10:53 +0000160
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100161 const PadStrideInfo conv_info = node.convolution_info();
162 const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method();
163
Georgios Pinitas28705162018-03-21 20:10:53 +0000164 // Validate function
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100165 Status status{};
166 switch(dwc_algorithm)
Georgios Pinitas28705162018-03-21 20:10:53 +0000167 {
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100168 case DepthwiseConvolutionMethod::Default:
169 case DepthwiseConvolutionMethod::GEMV:
170 status = DepthwiseConvolutionLayer::validate(input, weights, biases, output, conv_info);
171 break;
172 case DepthwiseConvolutionMethod::Optimized3x3:
173 status = DepthwiseConvolutionLayer3x3::validate(input, weights, biases, output, conv_info);
174 break;
175 default:
176 ARM_COMPUTE_RETURN_ERROR_MSG("Unsupported depthwise convolution method");
Georgios Pinitas28705162018-03-21 20:10:53 +0000177 }
178
Georgios Pinitas2a2db592018-08-15 12:14:46 +0100179 return status;
Georgios Pinitas28705162018-03-21 20:10:53 +0000180}
Georgios Pinitas57c48242018-08-02 13:41:49 +0100181
Michele Di Giorgio555d1102018-09-12 13:51:59 +0100182/** Validates a NormalizePlanarYUV layer node
183 *
184 * @tparam NormalizePlanarYUVLayer layer type
185 *
186 * @param[in] node Node to validate
187 *
188 * @return Status
189 */
190template <typename NormalizePlanarYUVLayer>
191Status validate_normalize_planar_yuv_layer(NormalizePlanarYUVLayerNode &node)
192{
193 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating NormalizePlanarYUVLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
194 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
195 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
196
197 // Extract IO and info
198 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
199 arm_compute::ITensorInfo *mean = detail::get_backing_tensor_info(node.input(1));
200 arm_compute::ITensorInfo *std = detail::get_backing_tensor_info(node.input(2));
201 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
202
203 // Validate function
204 return NormalizePlanarYUVLayer::validate(input, output, mean, std);
205}
Georgios Pinitas57c48242018-08-02 13:41:49 +0100206/** Validates a permute layer node
207 *
208 * @tparam PermuteLayer Permute layer type
209 *
210 * @param[in] node Node to validate
211 *
212 * @return Status
213 */
214template <typename PermuteLayer>
215Status validate_permute_layer(PermuteLayerNode &node)
216{
217 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating PermuteLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
218 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
219 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
220
221 // Extract IO and info
222 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
223 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
224 const PermutationVector &perm = node.permutation_vector();
225
226 return PermuteLayer::validate(input, output, perm);
227}
Gian Marco Iodice23e24792018-09-07 15:32:14 +0100228
229/** Validates a Reorg layer node
230 *
231 * @tparam ReorgLayer Reorg layer type
232 *
233 * @param[in] node Node to validate
234 *
235 * @return Status
236 */
237template <typename ReorgLayer>
238Status validate_reorg_layer(ReorgLayerNode &node)
239{
240 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ReorgLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
241 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
242 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
243
244 // Extract input and output
245 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
246 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
247
248 // Validate function
249 return ReorgLayer::validate(input, output, node.stride());
250}
Michalis Spyrou96f67692018-09-13 11:39:28 +0100251
Michele Di Giorgioc30b6682018-09-12 17:44:08 +0100252/** Validates a Slice layer node
253 *
254 * @tparam SliceLayer Slice layer function type
255 *
256 * @param[in] node Node to validate
257 *
258 * @return Status
259 */
260template <typename SliceLayer>
261Status validate_slice_layer(SliceLayerNode &node)
262{
263 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating Slice node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
264 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
265 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
266
267 // Extract IO and info
268 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
269 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
270 const Coordinates starts = node.starts();
271 const Coordinates ends = node.ends();
272
273 return SliceLayer::validate(input, output, starts, ends);
274}
275
Michalis Spyrou96f67692018-09-13 11:39:28 +0100276/** Validates a YOLO layer node
277 *
278 * @tparam YOLOLayer YOLO layer type
279 *
280 * @param[in] node Node to validate
281 *
282 * @return Status
283 */
284template <typename YOLOLayer>
285Status validate_yolo_layer(YOLOLayerNode &node)
286{
287 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating YOLOLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
288 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 1);
289 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
290
291 // Extract input and output
292 arm_compute::ITensorInfo *input = detail::get_backing_tensor_info(node.input(0));
293 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
294
295 // Validate function
296 return YOLOLayer::validate(input, output, node.activation_info(), node.num_classes());
297}
Georgios Pinitas28705162018-03-21 20:10:53 +0000298} // namespace detail
299} // namespace backends
Georgios Pinitasd9eb2752018-04-03 13:44:29 +0100300} // namespace graph
Georgios Pinitas28705162018-03-21 20:10:53 +0000301} // namespace arm_compute
302
Georgios Pinitasd9eb2752018-04-03 13:44:29 +0100303#endif /* __ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H__ */