/*
 * Copyright (c) 2017-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h"

#include "arm_compute/core/Size2D.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "support/ToolchainSupport.h"

#include <cmath>
#include <set>
#include <tuple>

using namespace arm_compute;
using namespace arm_compute::misc::shape_calculator;

NEConvolutionLayerReshapeWeights::NEConvolutionLayerReshapeWeights()
    : _weights_reshape_kernel()
{
}

void NEConvolutionLayerReshapeWeights::configure(const ITensor *weights, const ITensor *biases, ITensor *output)
{
    // Perform validation step
    ARM_COMPUTE_ERROR_ON_NULLPTR(weights, output);
    ARM_COMPUTE_ERROR_THROW_ON(NEConvolutionLayerReshapeWeights::validate(weights->info(),
                                                                          (biases != nullptr) ? biases->info() : nullptr,
                                                                          output->info()));

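    // For quantized asymmetric types the bias is instead applied by the GEMMLowp
    // output stage, so it is not folded into the reshaped weights here.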
    const bool     append_biases = (biases != nullptr) && !is_data_type_quantized_asymmetric(weights->info()->data_type());
    const ITensor *biases_to_use = (append_biases) ? biases : nullptr;

    _weights_reshape_kernel.configure(weights, biases_to_use, output);

    output->info()->set_quantization_info(weights->info()->quantization_info());
}

Status NEConvolutionLayerReshapeWeights::validate(const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(weights);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);

    if(biases != nullptr)
    {
        const int idx_kernels = get_data_layout_dimension_index(weights->data_layout(), DataLayoutDimension::BATCHES);
        ARM_COMPUTE_RETURN_ERROR_ON(is_data_type_quantized_asymmetric(weights->data_type()));
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, biases);
        ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(idx_kernels));
        ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
    }

    if((output != nullptr) && (output->total_size() != 0))
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(weights, output);

        ARM_COMPUTE_RETURN_ON_ERROR(NEWeightsReshapeKernel::validate(weights, biases, output));
    }

    return Status{};
}

void NEConvolutionLayerReshapeWeights::run()
{
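    // Split the reshape work along window dimension 3, i.e. across the kernels of the 4D weights tensor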
    NEScheduler::get().schedule(&_weights_reshape_kernel, 3);
}

NEGEMMConvolutionLayer::NEGEMMConvolutionLayer(const std::shared_ptr<IMemoryManager> &memory_manager)
    : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(), _add_bias_kernel(),
      _reshape_layer(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _append_bias(false), _skip_im2col(false),
      _skip_col2im(false), _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
{
}

void NEGEMMConvolutionLayer::configure_mm(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const ActivationLayerInfo &act_info, int gemm_3d_depth)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights);
    ARM_COMPUTE_ERROR_THROW_ON(validate_mm(input->info(), weights->info(), biases == nullptr ? nullptr : biases->info(), output == nullptr ? nullptr : output->info(), act_info, gemm_3d_depth,
                                           _skip_im2col));

    const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
                                         gemm_3d_depth, _skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);

    if(_is_quantized)
    {
        // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
        // Extract and negate input and weights offset
        const UniformQuantizationInfo iqinfo = input->info()->quantization_info().uniform();
        const UniformQuantizationInfo wqinfo = weights->info()->quantization_info().uniform();

        input->info()->set_quantization_info(QuantizationInfo(iqinfo.scale, -iqinfo.offset));
        weights->info()->set_quantization_info(QuantizationInfo(wqinfo.scale, -wqinfo.offset));

        const UniformQuantizationInfo oqinfo = (output->info()->total_size() == 0) ? iqinfo : output->info()->quantization_info().uniform();

        float multiplier = iqinfo.scale * wqinfo.scale / oqinfo.scale;
        int   output_multiplier;
        int   output_shift;
        quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
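        // calculate_quantized_multiplier_less_than_one() decomposes the real multiplier
        // (expected to be < 1) into a Q31 fixed-point multiplier and a right shift:
        //   multiplier ~= (output_multiplier / 2^31) * 2^-output_shift
        // e.g. a multiplier of 0.125 maps to output_multiplier = 2^30, output_shift = 2.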

        // Merge activation with output stage
        int min_activation = 0;
        int max_activation = 255;

        const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU,
                                                                                   ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
                                                                                   ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
                                                                                 };
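        // These activations amount to a [min, max] clamp on the quantized output, so
        // they can be folded into the bounds of the GEMMLowp output stage below and
        // the standalone activation layer can then be skipped.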
        if(_is_activationlayer_enabled && supported_acts.count(act_info.activation()) != 0)
        {
            const int a_const_int = quantize_qasymm8(act_info.a(), oqinfo);
            const int b_const_int = quantize_qasymm8(act_info.b(), oqinfo);

            min_activation = act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU ? oqinfo.offset : b_const_int;
            max_activation = act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU ? 255 : a_const_int;

            _is_activationlayer_enabled = false;
        }

        GEMMLowpOutputStageInfo output_info;
        output_info.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
        output_info.gemmlowp_offset     = oqinfo.offset;
        output_info.gemmlowp_multiplier = output_multiplier;
        output_info.gemmlowp_shift      = output_shift;
        output_info.gemmlowp_min_bound  = min_activation;
        output_info.gemmlowp_max_bound  = max_activation;

        _mm_gemmlowp.configure(input, weights, biases, output, GEMMInfo(false, false, true, gemm_3d_depth, _skip_im2col, false, output_info));

        // Restore the original QuantizationInfo as input and weights could be used in other convolution layers
        input->info()->set_quantization_info(QuantizationInfo(iqinfo.scale, iqinfo.offset));
        weights->info()->set_quantization_info(QuantizationInfo(wqinfo.scale, wqinfo.offset));
    }
    else
    {
        // Configure matrix multiply function
        _mm_gemm.configure(input, weights, nullptr, output, 1.0f, 0.0f, gemm_info);
    }
}

Status NEGEMMConvolutionLayer::validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const ActivationLayerInfo &act_info,
                                           int gemm_3d_depth, bool skip_im2col)
{
    const bool is_quantized          = is_data_type_quantized_asymmetric(input->data_type());
    const bool is_activation_enabled = act_info.enabled();

    const GEMMInfo &gemm_info = GEMMInfo(false, false, true /* Reshape weights only for the first run */,
                                         gemm_3d_depth, skip_im2col /* Reinterpret the input as 3D if im2col is skipped */);
    if(is_quantized)
    {
        // Since we need negative offsets for computing convolution, we need to change QuantizationInfo()
        // Extract and negate input and weights offset
        const UniformQuantizationInfo iqinfo = input->quantization_info().uniform();
        const UniformQuantizationInfo wqinfo = weights->quantization_info().uniform();

        std::unique_ptr<ITensorInfo> input_qa   = input->clone();
        std::unique_ptr<ITensorInfo> weights_qa = weights->clone();
        input_qa->set_quantization_info(QuantizationInfo(iqinfo.scale, -iqinfo.offset));
        weights_qa->set_quantization_info(QuantizationInfo(wqinfo.scale, -wqinfo.offset));

        const UniformQuantizationInfo oqinfo = (output->total_size() == 0) ? iqinfo : output->quantization_info().uniform();

        float multiplier = iqinfo.scale * wqinfo.scale / oqinfo.scale;
        int   output_multiplier;
        int   output_shift;
        ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift));

        // Merge activation with output stage
        int min_activation = 0;
        int max_activation = 255;

        const std::set<ActivationLayerInfo::ActivationFunction> supported_acts = { ActivationLayerInfo::ActivationFunction::RELU,
                                                                                   ActivationLayerInfo::ActivationFunction::BOUNDED_RELU,
                                                                                   ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU
                                                                                 };
        if(is_activation_enabled && supported_acts.count(act_info.activation()) != 0)
        {
            const int a_const_int = quantize_qasymm8(act_info.a(), oqinfo);
            const int b_const_int = quantize_qasymm8(act_info.b(), oqinfo);

            min_activation = act_info.activation() != ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU ? oqinfo.offset : b_const_int;
            max_activation = act_info.activation() == ActivationLayerInfo::ActivationFunction::RELU ? 255 : a_const_int;
        }

        GEMMLowpOutputStageInfo output_info;
        output_info.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
        output_info.gemmlowp_offset     = oqinfo.offset;
        output_info.gemmlowp_multiplier = output_multiplier;
        output_info.gemmlowp_shift      = output_shift;
        output_info.gemmlowp_min_bound  = min_activation;
        output_info.gemmlowp_max_bound  = max_activation;

        // Perform validation step on GEMMLowp
        return NEGEMMLowpMatrixMultiplyCore::validate(input_qa.get(), weights_qa.get(), biases, output, GEMMInfo(false, false, true, gemm_3d_depth, skip_im2col, false, output_info));
    }
    else
    {
        // Perform validation step on Matrix multiply function
        return NEGEMM::validate(input, weights, nullptr, output, 1.0f, 0.0f, gemm_info);
    }
}

Status NEGEMMConvolutionLayer::validate_gemm3d(const ITensorInfo *input_info, const ActivationLayerInfo &act_info, int gemm_3d_depth, bool skip_im2col)
{
    const DataType     data_type = input_info->data_type();
    const unsigned int mult_y    = skip_im2col ? 1U : gemm_3d_depth;
    const unsigned int mult_z    = skip_im2col ? gemm_3d_depth : 1U;

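    // Note: GEMM3D support depends on the data type, quantization info and how the
    // 3D reinterpretation is applied rather than on the actual tensor extents, so
    // small dummy shapes are enough to query support.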
    // Set dummy tensor shapes for the validation
    const TensorInfo dummy_input_info(TensorShape(4U, 4U * mult_y, 1U * mult_z), 1, data_type, input_info->quantization_info());
    const TensorInfo dummy_weights_info(TensorShape(4U, 4U), 1, data_type);
    const TensorInfo dummy_output_info(TensorShape(4U, 4U, gemm_3d_depth), 1, data_type, input_info->quantization_info());

    return validate_mm(&dummy_input_info, &dummy_weights_info, nullptr, &dummy_output_info, act_info, gemm_3d_depth, skip_im2col);
}

void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, const WeightsInfo &weights_info,
                                       const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
    ARM_COMPUTE_UNUSED(num_groups);
    ARM_COMPUTE_ERROR_THROW_ON(NEGEMMConvolutionLayer::validate(input->info(),
                                                                weights->info(),
                                                                biases != nullptr ? biases->info() : nullptr,
                                                                output->info(),
                                                                conv_info,
                                                                weights_info,
                                                                dilation,
                                                                act_info,
                                                                num_groups));

    const DataType   data_type   = input->info()->data_type();
    const DataLayout data_layout = input->info()->data_layout();
    const int        idx_width   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
    const int        idx_height  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
    const int        idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);

    const unsigned int kernel_width  = weights->info()->dimension(idx_width);
    const unsigned int kernel_height = weights->info()->dimension(idx_height);

    _is_prepared      = weights_info.retain_internal_weights();
    _original_weights = weights;
    _is_quantized     = is_data_type_quantized_asymmetric(input->info()->data_type());
    _data_layout      = data_layout;
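    // With a 1x1 kernel and unit strides in NHWC, each output pixel reads one input
    // pixel across all channels, so the GEMM can consume the input directly and the
    // im2col transform can be skipped.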
    _skip_im2col                = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
    _append_bias                = (biases != nullptr) && (!_is_quantized);
    _is_activationlayer_enabled = act_info.enabled();

    const ITensor *gemm_input_to_use  = input;
    ITensor       *gemm_output_to_use = output;

    // Get convolved dimensions
    unsigned int conv_w = 0;
    unsigned int conv_h = 0;
    std::tie(conv_w, conv_h) = scaled_dimensions(input->info()->dimension(idx_width),
                                                 input->info()->dimension(idx_height),
                                                 kernel_width,
                                                 kernel_height,
                                                 conv_info,
                                                 dilation);

    // Check if GEMM3D is supported
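    // In NHWC the GEMM can write its result directly with a 3D shape
    // (gemm_3d_depth = conv_h), which makes the col2im pass unnecessary when supported.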
    if(data_layout == DataLayout::NHWC)
    {
        _skip_col2im = bool(validate_gemm3d(input->info(), act_info, conv_h, true));
        // If not supported, we need to perform im2col and col2im (or reshape layer)
        if(!_skip_col2im)
        {
            _skip_im2col = false;
        }
    }
    else
    {
        _skip_col2im = false;
    }

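    // When im2col is skipped the bias cannot be appended to the reshaped weights; it is
    // applied separately instead (add-bias kernel, or the GEMMLowp output stage when quantized)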
    const ITensor *biases_to_use = (_append_bias && !_skip_im2col) ? biases : nullptr;

    // Get parameters from conv_info
    unsigned int stride_x = 0;
    unsigned int stride_y = 0;
    std::tie(stride_x, stride_y) = conv_info.stride();

    unsigned int mat_weights_cols = weights->info()->dimension(idx_kernels);

    // _weights_reshaped will be auto configured in the kernel.
    // Just append biases and do not transpose 1xW as it will be reshaped in NEGEMM
    _reshape_weights.configure(weights, biases_to_use, &_weights_reshaped);

    // Create tensor to store im2col reshaped inputs
    if(!_skip_im2col)
    {
        _memory_group.manage(&_im2col_output);

        // Configure im2col kernel
        _im2col_kernel.configure(input, &_im2col_output, Size2D(kernel_width, kernel_height), conv_info, _append_bias, dilation);

        // Update GEMM input
        gemm_input_to_use = &_im2col_output;
    }
    else if(_append_bias)
    {
        // Configure add bias kernel
        _add_bias_kernel.configure(output, biases, output, ConvertPolicy::SATURATE);
    }

    // Create temporary GEMM output tensor in case we cannot skip col2im
    if(!_skip_col2im)
    {
        TensorShape shape_gemm;

        // Calculate GEMM output shape
        shape_gemm = _im2col_output.info()->tensor_shape();
        shape_gemm.set(0, mat_weights_cols);
        shape_gemm.set(1, conv_w * conv_h);

        // FIXME: input->clone() doesn't work with subtensors for grouped convolutions.
        TensorInfo info_gemm(shape_gemm, 1, data_type);
        info_gemm.set_quantization_info(output->info()->quantization_info()).set_data_layout(input->info()->data_layout());
        _gemm_output.allocator()->init(info_gemm);
        _memory_group.manage(&_gemm_output);

        // Update GEMM output
        gemm_output_to_use = &_gemm_output;
    }

    // Configure GEMM
    // In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the output matrix
    const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0;
    configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, gemm_3d_depth);

    if(!_skip_im2col)
    {
        _im2col_output.allocator()->allocate();
    }

    if(!_skip_col2im)
    {
        if(_data_layout == DataLayout::NCHW)
        {
            // Configure col2im
            _col2im_kernel.configure(gemm_output_to_use, output, Size2D(conv_w, conv_h));
        }
        else
        {
            // Configure reshape layer
            _reshape_layer.configure(gemm_output_to_use, output);
        }
    }

    if(_is_quantized && !_skip_col2im)
    {
        _tmp_output.allocator()->allocate();
    }

    if(!_skip_col2im || _is_quantized)
    {
        _gemm_output.allocator()->allocate();
    }

    ARM_COMPUTE_ERROR_ON_MSG((output->info()->dimension(idx_width) != conv_w) || (output->info()->dimension(idx_height) != conv_h),
                             "Output shape does not match the expected one");

    // Configure Activation Layer
    if(_is_activationlayer_enabled)
    {
        _activationlayer_function.configure(output, nullptr, act_info);
    }
}

Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
                                        const WeightsInfo &weights_info, const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!");
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(num_groups > 1, "Grouping (num_groups != 1) is not supported on NEON");

    const DataLayout data_layout = input->data_layout();
    const DataType   data_type   = input->data_type();
    const int        idx_width   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
    const int        idx_height  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
    const int        idx_channel = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
    const int        idx_kernels = get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES);

    const unsigned int kernel_width  = weights->dimension(idx_width);
    const unsigned int kernel_height = weights->dimension(idx_height);

    TensorInfo         im2col_reshaped_info{};
    TensorInfo         info_gemm{};
    TensorInfo         tmp_info{};
    TensorInfo         weights_reshaped_info{};
    const ITensorInfo *gemm_input_to_use  = input;
    const ITensorInfo *gemm_output_to_use = output;
    const ITensorInfo *weights_to_use     = weights;

    const bool is_quantized          = is_data_type_quantized_asymmetric(data_type);
    const bool append_bias           = (biases != nullptr) && (!is_quantized);
    bool       skip_im2col           = (data_layout == DataLayout::NHWC && kernel_width == 1 && kernel_height == 1 && conv_info.stride().first == 1 && conv_info.stride().second == 1);
    const bool is_activation_enabled = act_info.enabled();

    // Get convolved dimensions
    unsigned int conv_w = 0;
    unsigned int conv_h = 0;

    std::tie(conv_w, conv_h) = scaled_dimensions(input->dimension(idx_width),
                                                 input->dimension(idx_height),
                                                 kernel_width,
                                                 kernel_height,
                                                 conv_info,
                                                 dilation);

    // Check if GEMM3D is supported
    bool skip_col2im = false;
    if(data_layout == DataLayout::NHWC)
    {
        skip_col2im = bool(validate_gemm3d(input, act_info, conv_h, true));
        // If not supported, we need to perform im2col and col2im (or reshape layer)
        if(!skip_col2im)
        {
            skip_im2col = false;
        }
    }

    if(skip_col2im)
    {
        // Re-check GEMM3D support with the actual skip_im2col setting; if not supported,
        // we need to perform im2col and col2im (or reshape layer)
        if(!bool(validate_gemm3d(input, act_info, conv_h, skip_im2col)))
        {
            skip_im2col = false;
            skip_col2im = false;
        }
    }

    const unsigned int bias_element  = (append_bias && !skip_im2col) ? 1 : 0;
    const ITensorInfo *biases_to_use = (append_bias && !skip_im2col) ? biases : nullptr;

    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_channel) != input->dimension(idx_channel));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);

    // Validate biases
    if(biases != nullptr)
    {
        if(is_quantized)
        {
            ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
        }
        else
        {
            ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
        }
        ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(idx_kernels));
        ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
    }

    if(act_info.enabled())
    {
        ARM_COMPUTE_RETURN_ERROR_ON(act_info.b() > act_info.a());
    }

    unsigned int mat_weights_cols = weights->dimension(idx_kernels);
    unsigned int mat_weights_rows = weights->dimension(idx_width) * weights->dimension(idx_height) * weights->dimension(idx_channel) + bias_element;
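    // In effect, im2col turns the convolution into a single GEMM with
    // M = conv_w * conv_h (output pixels), N = mat_weights_cols (number of kernels)
    // and K = mat_weights_rows (kernel volume, plus one row when the bias is appended).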

    // Output tensor auto initialization if not yet initialized
    ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases_to_use, nullptr));
    weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, (append_bias && !skip_im2col)), 1, data_type);
    weights_reshaped_info.set_quantization_info(weights->quantization_info());
    weights_to_use = &weights_reshaped_info;

    if(!skip_im2col)
    {
        // Create tensor info for im2col reshaped inputs
        // For NEON the batch size is on the fourth dimension
        // TODO (giaiod01): Auto-initialize the output shape of im2col COMPMID-1482
        TensorShape shape_im2col = input->tensor_shape();
        shape_im2col.set(0, mat_weights_rows);
        shape_im2col.set(1, conv_w * conv_h);
        shape_im2col.set(2, 1);

        im2col_reshaped_info = TensorInfo(shape_im2col, 1, data_type);
        im2col_reshaped_info.set_quantization_info(input->quantization_info());

        ARM_COMPUTE_RETURN_ON_ERROR(NEIm2ColKernel::validate(input, &im2col_reshaped_info, Size2D(kernel_width, kernel_height), conv_info, append_bias, dilation));
        gemm_input_to_use = &im2col_reshaped_info;
    }
    else if(append_bias)
    {
        // Validate add bias kernel
        ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(output, biases, output, ConvertPolicy::SATURATE));
    }

    // Create temporary GEMM output tensor in case we cannot skip col2im
    if(!skip_col2im)
    {
        TensorShape shape_gemm = gemm_input_to_use->tensor_shape();
        shape_gemm.set(0, mat_weights_cols);
        shape_gemm.set(1, conv_w * conv_h);
        info_gemm = TensorInfo(shape_gemm, 1, data_type);
    }
    else
    {
        info_gemm = TensorInfo(output->tensor_shape(), 1, data_type);
    }
    info_gemm.set_quantization_info(output->quantization_info()).set_data_layout(input->data_layout());
    gemm_output_to_use = &info_gemm;
    ARM_COMPUTE_RETURN_ON_ERROR(validate_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, skip_col2im ? conv_h : 0, skip_im2col));

    // Validate Col2Im/ReshapeLayer
    if(!skip_col2im && (data_layout == DataLayout::NCHW))
    {
        ARM_COMPUTE_RETURN_ON_ERROR(NECol2ImKernel::validate(gemm_output_to_use, output, Size2D(conv_w, conv_h)));
    }

    // Validate Activation Layer
    if(is_activation_enabled)
    {
        ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output, nullptr, act_info));
    }

    return Status{};
}

void NEGEMMConvolutionLayer::run()
{
    prepare();

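    // Acquire the memory managed by _memory_group for the duration of run();
    // it is released automatically when this scope ends.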
    MemoryGroupResourceScope scope_mg(_memory_group);

    if(!_skip_im2col)
    {
        // Run input reshaping
        unsigned int y_dim = get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT);
        NEScheduler::get().schedule(&_im2col_kernel, y_dim);
    }

    // Runs NEGEMM or NEGEMMLowpMatrixMultiplyCore functions
    if(_is_quantized)
    {
        // Run gemmlowp
        _mm_gemmlowp.run();
    }
    else
    {
        // Run gemm
        _mm_gemm.run();
    }

    if(_skip_im2col && _append_bias)
    {
        NEScheduler::get().schedule(&_add_bias_kernel, Window::DimY);
    }

    // Reshape output matrix
    if(!_skip_col2im)
    {
        if(_data_layout == DataLayout::NCHW)
        {
            NEScheduler::get().schedule(&_col2im_kernel, Window::DimY);
        }
        else
        {
            _reshape_layer.run();
        }
    }

    if(_is_activationlayer_enabled)
    {
        _activationlayer_function.run();
    }
}

void NEGEMMConvolutionLayer::prepare()
{
    if(!_is_prepared)
    {
        ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());

        // Run weights reshaping and mark original weights tensor as unused
        _weights_reshaped.allocator()->allocate();
        _reshape_weights.run();
        _original_weights->mark_as_unused();

        // Prepare GEMM
        _is_quantized ? _mm_gemmlowp.prepare() : _mm_gemm.prepare();
        if(!_weights_reshaped.is_used())
        {
            _weights_reshaped.allocator()->free();
        }

        _is_prepared = true;
    }
}