blob: c1b87ee0c084a687563f43d4547e3bd38ce65e60 [file] [log] [blame]
Georgios Pinitas28705162018-03-21 20:10:53 +00001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010024#ifndef __ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H__
25#define __ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H__
Georgios Pinitas28705162018-03-21 20:10:53 +000026
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010027#include "arm_compute/graph/Logger.h"
28#include "arm_compute/graph/Tensor.h"
29#include "arm_compute/graph/Types.h"
30#include "arm_compute/graph/nodes/Nodes.h"
Georgios Pinitas28705162018-03-21 20:10:53 +000031
32#include "arm_compute/core/Error.h"
33#include "arm_compute/core/ITensorInfo.h"
34
35namespace arm_compute
36{
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010037namespace graph
Georgios Pinitas28705162018-03-21 20:10:53 +000038{
39namespace backends
40{
41namespace detail
42{
43/** Returns backing tensor info of a given tensor
44 *
45 * @param[in] tensor Tensor to extract the backing tensor from
46 *
47 * @return Backing tensor tensor info if present else nullptr
48 */
Georgios Pinitasd9eb2752018-04-03 13:44:29 +010049inline arm_compute::ITensorInfo *get_backing_tensor_info(arm_compute::graph::Tensor *tensor)
Georgios Pinitas28705162018-03-21 20:10:53 +000050{
51 return ((tensor == nullptr) || (tensor->handle() == nullptr)) ? nullptr : tensor->handle()->tensor().info();
52}
53
54/** Validates a Convolution layer node
55 *
56 * @tparam ConvolutionLayer Default Convolution layer function type
57 * @tparam DirectConvolutionLayer Direct Convolution layer function type
58 * @tparam GEMMConvolutionLayer GEMM Convolution layer function type
59 * @tparam WinogradConvolutionLayer Winograd Convolution layer function type
60 *
61 * @param[in] node Node to validate
62 *
63 * @return Status
64 */
65template <typename ConvolutionLayer, typename DirectConvolutionLayer, typename GEMMConvolutionLayer, typename WinogradConvolutionLayer>
66Status validate_convolution_layer(ConvolutionLayerNode &node)
67{
68 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating ConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
69 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
70 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
71
72 // Extract IO and info
Giorgio Arenabb54e4e2018-04-05 17:20:34 +010073 arm_compute::ITensorInfo *input = get_backing_tensor_info(node.input(0));
74 arm_compute::ITensorInfo *weights = get_backing_tensor_info(node.input(1));
75 arm_compute::ITensorInfo *biases = get_backing_tensor_info(node.input(2));
76 arm_compute::ITensorInfo *output = get_backing_tensor_info(node.output(0));
77
78 if(is_data_type_quantized_asymmetric(input->data_type()))
79 {
80 biases->set_data_type(DataType::S32);
81 }
82
83 const PadStrideInfo conv_info = node.convolution_info();
84 const ConvolutionMethod conv_algorithm = node.convolution_method();
Georgios Pinitas28705162018-03-21 20:10:53 +000085
86 // Validate function
87 Status status{};
88 switch(conv_algorithm)
89 {
90 case ConvolutionMethod::DIRECT:
91 status = DirectConvolutionLayer::validate(input, weights, biases, output, conv_info);
92 break;
93 case ConvolutionMethod::GEMM:
94 status = GEMMConvolutionLayer::validate(input, weights, biases, output, conv_info);
95 break;
96 case ConvolutionMethod::WINOGRAD:
97 status = WinogradConvolutionLayer::validate(input, weights, biases, output, conv_info);
98 break;
99 default:
100 break;
101 }
102
103 // If validation fails try the Default approach
104 if(!bool(status) || (conv_algorithm == ConvolutionMethod::DEFAULT))
105 {
106 status = ConvolutionLayer::validate(input, weights, biases, output, conv_info);
107 if(bool(status))
108 {
109 ARM_COMPUTE_LOG_GRAPH_INFO("Switched ConvolutionLayer method of node with ID : "
110 << node.id() << " and Name: " << node.name() << std::endl);
111 node.set_convolution_method(ConvolutionMethod::DEFAULT);
112 }
113 }
114
115 return status;
116}
117
118/** Validates a Depthwise Convolution layer node
119 *
120 * @tparam DepthwiseConvolutionLayer Default Depthwise Convolution layer type
121 * @tparam DepthwiseConvolutionLayer3x3 Optimized 3x3 Depthwise Convolution layer type
122 *
123 * @param[in] node Node to validate
124 *
125 * @return Status
126 */
127template <typename DepthwiseConvolutionLayer, typename DepthwiseConvolutionLayer3x3>
128Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
129{
130 ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating DepthwiseConvolutionLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
131 ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
132 ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 1);
133
134 // Extract IO and info
135 arm_compute::ITensorInfo *weights = detail::get_backing_tensor_info(node.input(1));
136 const DepthwiseConvolutionMethod dwc_algorithm = node.depthwise_convolution_method();
137 ARM_COMPUTE_ERROR_ON(weights == nullptr);
138
139 // TODO (geopin01) : Switch when validation is implemented
140 // Validate function
141 if((dwc_algorithm == DepthwiseConvolutionMethod::OPTIMIZED_3x3) && (weights->tensor_shape().x() != 3))
142 {
143 ARM_COMPUTE_LOG_GRAPH_INFO("Switched DepthwiseConvolutionLayer method of node with ID : "
144 << node.id() << " and Name: " << node.name() << std::endl);
145 node.set_depthwise_convolution_method(DepthwiseConvolutionMethod::DEFAULT);
146 }
147
148 return Status{};
149}
150} // namespace detail
151} // namespace backends
Georgios Pinitasd9eb2752018-04-03 13:44:29 +0100152} // namespace graph
Georgios Pinitas28705162018-03-21 20:10:53 +0000153} // namespace arm_compute
154
Georgios Pinitasd9eb2752018-04-03 13:44:29 +0100155#endif /* __ARM_COMPUTE_GRAPH_BACKENDS_DETAIL_VALIDATE_HELPERS_H__ */