blob: 237d4ae2a4c70e19ca9bdaeadaa91b25325e722f [file] [log] [blame]
/*
 * Copyright (c) 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010024#ifndef __ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H__
25#define __ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H__
Georgios Pinitas28705162018-03-21 20:10:53 +000026
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010027#include "arm_compute/graph/Logger.h"
28#include "arm_compute/graph/Tensor.h"
29#include "arm_compute/graph/Types.h"
30#include "arm_compute/graph/nodes/Nodes.h"
Georgios Pinitas28705162018-03-21 20:10:53 +000031
32#include "arm_compute/core/Error.h"
Georgios Pinitascac13b12018-04-27 19:07:19 +010033#include "arm_compute/core/Helpers.h"
Georgios Pinitas28705162018-03-21 20:10:53 +000034#include "arm_compute/core/ITensorInfo.h"
35
36namespace arm_compute
37{
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010038namespace graph
Georgios Pinitas28705162018-03-21 20:10:53 +000039{
40namespace backends
41{
42namespace detail
43{
44/** Returns backing tensor info of a given tensor
45 *
46 * @param[in] tensor Tensor to extract the backing tensor from
47 *
48 * @return Backing tensor tensor info if present else nullptr
49 */
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010050inline arm_compute::ITensorInfo *get_backing_tensor_info(arm_compute::graph::Tensor *tensor)
Georgios Pinitas28705162018-03-21 20:10:53 +000051{
52 return ((tensor == nullptr) || (tensor->handle() == nullptr)) ? nullptr : tensor->handle()->tensor().info();
53}
54
55/** Validates a Convolution layer node
56 *
57 * @tparam ConvolutionLayer Default Convolution layer function type
58 * @tparam DirectConvolutionLayer Direct Convolution layer function type
59 * @tparam GEMMConvolutionLayer GEMM Convolution layer function type
60 * @tparam WinogradConvolutionLayer Winograd Convolution layer function type
61 *
62 * @param[in] node Node to validate
63 *
64 * @return Status
65 */
66template <typename ConvolutionLayer, typename DirectConvolutionLayer, typename GEMMConvolutionLayer, typename WinogradConvolutionLayer>
67Status validate_convolution_layer(ConvolutionLayerNode &node)
68{
69 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
70 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
71 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
72
73 // Extract IO and info
Giorgio Arenabb54e4e2018-04-05 17:20:34 +010074 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
75 arm_compute::ITensorInfo *weights = get_backing_tensor_info(node.input(1));
76 arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2));
77 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
78
79 if(is_data_type_quantized_asymmetric(input->data_type()))
80 {
81 biases->set_data_type(DataType::S32);
82 }
83
84 const PadStrideInfo conv_info = node.convolution_info();
85 const ConvolutionMethod conv_algorithm = node.convolution_method();
Georgios Pinitas28705162018-03-21 20:10:53 +000086
87 // Validate function
88 Status status{};
89 switch(conv_algorithm)
90 {
91 case ConvolutionMethod::DIRECT:
92 status = DirectConvolutionLayer::validate(input, weights, biases, output, conv_info);
93 break;
94 case ConvolutionMethod::GEMM:
95 status = GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info);
96 break;
97 case ConvolutionMethod::WINOGRAD:
98 status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info);
99 break;
100 default:
101 break;
102 }
103
104 // If validation fails try the Default approach
105 if(!bool(status) || (conv_algorithm == ConvolutionMethod::DEFAULT))
106 {
107 status = ConvolutionLayer::validate(input, weights, biases, output, conv_info);
108 if(bool(status))
109 {
110 ARM_COMPUTE_LOG_GRAPH_INFO("Switched ConvolutionLayer method of node with ID : "
111 << node.id() << " and Name: " << node.name() << std::endl);
112 node.set_convolution_method(ConvolutionMethod::DEFAULT);
113 }
114 }
115
116 return status;
117}
118
119/** Validates a Depthwise Convolution layer node
120 *
121 * @tparam DepthwiseConvolutionLayer Default Depthwise Convolution layer type
122 * @tparam DepthwiseConvolutionLayer3x3 Optimized 3x3 Depthwise Convolution layer type
123 *
124 * @param[in] node Node to validate
125 *
126 * @return Status
127 */
128template <typename DepthwiseConvolutionLayer, typename DepthwiseConvolutionLayer3x3>
129Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
130{
131 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
132 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
133 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
134
135 // Extract IO and info
136 arm_compute::ITensorInfo *weights = detail::get_backing_tensor_info(node.input(1));
137 const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method();
138 ARM_COMPUTE_ERROR_ON(weights == nullptr);
139
140 // TODO (geopin01) : Switch when validation is implemented
141 // Validate function
Georgios Pinitascac13b12018-04-27 19:07:19 +0100142 if((dwc_algorithm == DepthwiseConvolutionMethod::OPTIMIZED_3x3) && (weights->tensor_shape()[get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::WIDTH)] != 3))
Georgios Pinitas28705162018-03-21 20:10:53 +0000143 {
144 ARM_COMPUTE_LOG_GRAPH_INFO("Switched DepthwiseConvolutionLayer method of node with ID : "
145 << node.id() << " and Name: " << node.name() << std::endl);
146 node.set_depthwise_convolution_method(DepthwiseConvolutionMethod::DEFAULT);
147 }
148
149 return Status{};
150}
151} // namespace detail
152} // namespace backends
Georgios Pinitasd9eb2752018-04-03 13:44:29 +0100153} // namespace graph
Georgios Pinitas28705162018-03-21 20:10:53 +0000154} // namespace arm_compute
155
Georgios Pinitasd9eb2752018-04-03 13:44:29 +0100156#endif /* __ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H__ */