blob: 2814c67f70542492a78bba98e5b8afb0bc90f287 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2017-2020 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyroub91e34c2017-12-20 15:50:55 +000024#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010025
26#include "arm_compute/core/AccessWindowStatic.h"
Anthony Barbiereaefd002018-07-20 17:49:35 +010027#include "arm_compute/core/CPP/Validate.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028#include "arm_compute/core/Error.h"
29#include "arm_compute/core/Helpers.h"
30#include "arm_compute/core/ITensor.h"
Georgios Pinitasf72f9362018-01-12 16:29:45 +000031#include "arm_compute/core/NEON/NEAsymm.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010032#include "arm_compute/core/NEON/NEFixedPoint.h"
Michele Di Giorgio45361932019-12-19 13:53:44 +000033#include "arm_compute/core/NEON/wrapper/wrapper.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010034#include "arm_compute/core/Types.h"
35#include "arm_compute/core/Validate.h"
36#include "arm_compute/core/Window.h"
Michele Di Giorgio45361932019-12-19 13:53:44 +000037#include "arm_compute/core/utils/misc/Traits.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010038
39#include <arm_neon.h>
40#include <cstddef>
41#include <cstdint>
42
Michele Di Giorgiof29d1b72019-10-29 10:58:13 +000043namespace arm_compute
44{
Anthony Barbier6ff3b192017-09-04 18:44:23 +010045namespace
46{
Michele Di Giorgioff271922019-07-17 15:59:32 +010047Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output,
Michele Di Giorgio45361932019-12-19 13:53:44 +000048 const DirectConvolutionLayerOutputStageKernelInfo &info)
Michalis Spyrouafa5d812017-11-30 14:25:57 +000049{
Michele Di Giorgio45361932019-12-19 13:53:44 +000050 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
Anthony Barbiereaefd002018-07-20 17:49:35 +010051 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010052 ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
Michele Di Giorgio45361932019-12-19 13:53:44 +000053 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::S32, DataType::F32);
Michalis Spyroub91e34c2017-12-20 15:50:55 +000054
55 if(bias != nullptr)
Michalis Spyrouafa5d812017-11-30 14:25:57 +000056 {
Michele Di Giorgio45361932019-12-19 13:53:44 +000057 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010058 ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)));
Michalis Spyroub91e34c2017-12-20 15:50:55 +000059 ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1);
Michalis Spyrouafa5d812017-11-30 14:25:57 +000060 }
Michalis Spyrouafa5d812017-11-30 14:25:57 +000061
Michele Di Giorgio45361932019-12-19 13:53:44 +000062 if(input->data_type() == DataType::S32)
63 {
64 ARM_COMPUTE_RETURN_ERROR_ON_MSG(output == nullptr, "In-place computation not allowed for quantized output");
65 }
66
Michalis Spyrouafa5d812017-11-30 14:25:57 +000067 // Checks performed when output is configured
68 if((output != nullptr) && (output->total_size() != 0))
69 {
Michele Di Giorgio45361932019-12-19 13:53:44 +000070 if(is_data_type_float(input->data_type()))
Michalis Spyroub91e34c2017-12-20 15:50:55 +000071 {
72 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
73 }
Michele Di Giorgio45361932019-12-19 13:53:44 +000074 else
75 {
76 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
77 }
78 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
79 }
80 else if(input->data_type() == DataType::S32)
81 {
82 // In case of quantized computation and unconfigured output, the output data type must be provided through DirectConvolutionLayerOutputStageKernelInfo
83 ARM_COMPUTE_RETURN_ERROR_ON((info.output_data_type != DataType::QASYMM8) && (info.output_data_type != DataType::QASYMM8_SIGNED));
Michalis Spyrouafa5d812017-11-30 14:25:57 +000084 }
85
Michalis Spyrouafa5d812017-11-30 14:25:57 +000086 return Status{};
87}
88
Michele Di Giorgio45361932019-12-19 13:53:44 +000089std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *bias, ITensorInfo *output,
90 const DirectConvolutionLayerOutputStageKernelInfo &info)
Michalis Spyrouafa5d812017-11-30 14:25:57 +000091{
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010092 ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
93
Michele Di Giorgio45361932019-12-19 13:53:44 +000094 const DataType data_type = input->data_type();
95
96 // Auto-initialize output output if required
97 if(output != nullptr)
98 {
99 // Work out expected output data type
100 const DataType output_dt = (data_type == DataType::S32) ? info.output_data_type : data_type;
101 // Output tensor auto initialization if not yet initialized
102 auto_init_if_empty(*output, input->clone()->set_data_type(output_dt));
103 }
104
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000105 bool window_changed = false;
Michele Di Giorgio45361932019-12-19 13:53:44 +0000106 unsigned int num_elems_processed_per_iteration = 16 / element_size_from_data_type(data_type);
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000107
108 // Update processed elements when input is S32 (comes from quantization input)
Michele Di Giorgio45361932019-12-19 13:53:44 +0000109 if(data_type == DataType::S32)
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000110 {
111 num_elems_processed_per_iteration = 16;
112 }
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000113
114 // Configure kernel window
115 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
116 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000117
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000118 if(output != nullptr && (output->total_size() != 0))
119 {
120 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000121
122 if(bias == nullptr)
123 {
124 window_changed = update_window_and_padding(win, input_access, output_access);
125 }
126 else
127 {
128 AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
129 window_changed = update_window_and_padding(win, input_access, output_access, bias_access);
130 }
131
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000132 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
133 }
134 else
135 {
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000136 if(bias == nullptr)
137 {
138 window_changed = update_window_and_padding(win, input_access);
139 }
140 else
141 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100142 if(input->data_layout() == DataLayout::NCHW)
143 {
144 AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
145 window_changed = update_window_and_padding(win, input_access, bias_access);
146 }
147 else
148 {
149 AccessWindowHorizontal bias_access(bias, 0, num_elems_processed_per_iteration);
150 window_changed = update_window_and_padding(win, input_access, bias_access);
151 }
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000152 }
153
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000154 input_access.set_valid_region(win, ValidRegion(Coordinates(), input->tensor_shape()));
155 }
156
157 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
158 return std::make_pair(err, win);
159}
160
Michele Di Giorgio45361932019-12-19 13:53:44 +0000161template <typename T, bool has_bias>
162typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, void>::type
163output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
164 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100165{
Michele Di Giorgio45361932019-12-19 13:53:44 +0000166 /** NEON vector tag type. */
167 using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t<T, wrapper::traits::BitWidth::W128>;
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100168
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100169 ARM_COMPUTE_ERROR_ON(input->info()->data_layout() == DataLayout::UNKNOWN);
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000170 ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
171 ARM_COMPUTE_UNUSED(result_shift);
172 ARM_COMPUTE_UNUSED(result_offset_after_shift);
173
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100174 Iterator in(input, window);
Michele Di Giorgio45361932019-12-19 13:53:44 +0000175 Iterator out(output, window);
176 execute_window_loop(window, [&](const Coordinates & id)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100177 {
Michele Di Giorgio45361932019-12-19 13:53:44 +0000178 // Get bias and pointer to input
179 const auto in_ptr = reinterpret_cast<const T *>(in.ptr());
180 auto v_in = wrapper::vloadq(in_ptr);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100181
Michele Di Giorgio45361932019-12-19 13:53:44 +0000182 // Accumulate bias
183 if(has_bias)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100184 {
Michele Di Giorgio45361932019-12-19 13:53:44 +0000185 const auto vb = wrapper::vdup_n(*reinterpret_cast<const T *>(bias->ptr_to_element(Coordinates(id.z()))), ExactTagType{});
186 v_in = wrapper::vadd(v_in, vb);
187 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100188
Michele Di Giorgio45361932019-12-19 13:53:44 +0000189 const auto out_ptr = reinterpret_cast<T *>(out.ptr());
190 wrapper::vstore(out_ptr, v_in);
191 },
192 in, out);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100193}
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000194
Michele Di Giorgio45361932019-12-19 13:53:44 +0000195template <typename T, bool has_bias>
196typename std::enable_if<arm_compute::utils::traits::is_floating_point<T>::value, void>::type
197output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
198 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100199{
200 ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
201 ARM_COMPUTE_UNUSED(result_shift);
202 ARM_COMPUTE_UNUSED(result_offset_after_shift);
203
204 Window window_bias = window;
205 window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
206 window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
207 window_bias.set(3, Window::Dimension(0, 0, 0));
208
209 Iterator in(input, window);
210 Iterator bi(bias, window_bias);
Michele Di Giorgio45361932019-12-19 13:53:44 +0000211 Iterator out(output, window);
212 execute_window_loop(window, [&](const Coordinates &)
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100213 {
Michele Di Giorgio45361932019-12-19 13:53:44 +0000214 // Get bias and pointer to input
215 const auto in_ptr = reinterpret_cast<const T *>(in.ptr());
216 auto v_in = wrapper::vloadq(in_ptr);
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100217
Michele Di Giorgio45361932019-12-19 13:53:44 +0000218 // Accumulate bias
219 if(has_bias)
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100220 {
Michele Di Giorgio45361932019-12-19 13:53:44 +0000221 const auto bias_ptr = reinterpret_cast<T *>(bi.ptr());
222 v_in = wrapper::vadd(v_in, wrapper::vloadq(bias_ptr));
223 }
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100224
Michele Di Giorgio45361932019-12-19 13:53:44 +0000225 const auto out_ptr = reinterpret_cast<T *>(out.ptr());
226 wrapper::vstore(out_ptr, v_in);
227
228 },
229 in, bi, out);
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100230}
231
Michele Di Giorgio45361932019-12-19 13:53:44 +0000232// Quantized case
233template < typename TOut, bool has_bias, typename std::enable_if < std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int >::type = 0 >
234void output_stage_nchw(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
235 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000236{
Michele Di Giorgio45361932019-12-19 13:53:44 +0000237 using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
238 using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
239
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000240 const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
Michele Di Giorgio45361932019-12-19 13:53:44 +0000241
242 const VectorType min = wrapper::vdup_n(std::numeric_limits<TOut>::lowest(), TagType{});
243 const VectorType max = wrapper::vdup_n(std::numeric_limits<TOut>::max(), TagType{});
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000244
245 Iterator in(input, window);
246 Iterator out(output, window);
247
248 execute_window_loop(window, [&](const Coordinates & id)
249 {
250 // Get bias and pointer to input
251 const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr());
252 int32x4x4_t v_in =
253 {
254 {
Michele Di Giorgio45361932019-12-19 13:53:44 +0000255 wrapper::vloadq(in_ptr),
256 wrapper::vloadq(in_ptr + 4),
257 wrapper::vloadq(in_ptr + 8),
258 wrapper::vloadq(in_ptr + 12)
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000259 }
260 };
261
262 // Accumulate bias
Michele Di Giorgio45361932019-12-19 13:53:44 +0000263 if(has_bias)
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000264 {
Michele Di Giorgio45361932019-12-19 13:53:44 +0000265 const auto vb = wrapper::vdup_n(*reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z()))), TagType{});
266 v_in =
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000267 {
Michele Di Giorgio45361932019-12-19 13:53:44 +0000268 {
269 wrapper::vadd(v_in.val[0], vb),
270 wrapper::vadd(v_in.val[1], vb),
271 wrapper::vadd(v_in.val[2], vb),
272 wrapper::vadd(v_in.val[3], vb)
273 }
274 };
275 }
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000276
Michele Di Giorgio45361932019-12-19 13:53:44 +0000277 const auto out_ptr = reinterpret_cast<TOut *>(out.ptr());
Michalis Spyrou70d43a32020-06-22 17:05:43 +0100278 wrapper::vstore(out_ptr, finalize_quantization(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max, false));
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000279 },
280 in, out);
281}
Michele Di Giorgio45361932019-12-19 13:53:44 +0000282template < typename TOut, bool has_bias, typename std::enable_if < std::is_same<TOut, uint8_t>::value || std::is_same<TOut, int8_t>::value, int >::type = 0 >
283void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
284 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000285{
Michele Di Giorgio45361932019-12-19 13:53:44 +0000286 using VectorType = typename wrapper::traits::neon_bitvector_t<TOut, wrapper::traits::BitWidth::W128>;
287 using TagType = typename wrapper::traits::neon_bitvector_tag_t<TOut, wrapper::traits::BitWidth::W128>;
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000288
289 const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000290
Michele Di Giorgio45361932019-12-19 13:53:44 +0000291 const VectorType min = wrapper::vdup_n(std::numeric_limits<TOut>::lowest(), TagType{});
292 const VectorType max = wrapper::vdup_n(std::numeric_limits<TOut>::max(), TagType{});
Georgios Pinitasa799ce02018-09-12 20:11:34 +0100293
294 Window window_bias = window;
295 window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
296 window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
297 window_bias.set(3, Window::Dimension(0, 0, 0));
298
299 Iterator in(input, window);
300 Iterator bi(bias, window_bias);
301
302 Iterator out(output, window);
Michalis Spyroua4f378d2019-04-26 14:54:54 +0100303 execute_window_loop(window, [&](const Coordinates &)
Georgios Pinitasa799ce02018-09-12 20:11:34 +0100304 {
305 // Get bias and pointer to input
Michele Di Giorgio45361932019-12-19 13:53:44 +0000306 const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr());
307 int32x4x4_t v_in =
308 {
309 {
310 wrapper::vloadq(in_ptr),
311 wrapper::vloadq(in_ptr + 4),
312 wrapper::vloadq(in_ptr + 8),
313 wrapper::vloadq(in_ptr + 12),
314 }
315 };
Georgios Pinitasa799ce02018-09-12 20:11:34 +0100316
317 // Accumulate bias
Michele Di Giorgio45361932019-12-19 13:53:44 +0000318 if(has_bias)
Georgios Pinitasa799ce02018-09-12 20:11:34 +0100319 {
Michele Di Giorgio45361932019-12-19 13:53:44 +0000320 const auto bias_ptr = reinterpret_cast<int32_t *>(bi.ptr());
Georgios Pinitasa799ce02018-09-12 20:11:34 +0100321
Michele Di Giorgio45361932019-12-19 13:53:44 +0000322 wrapper::vadd(v_in.val[0], wrapper::vloadq(bias_ptr));
323 wrapper::vadd(v_in.val[1], wrapper::vloadq(bias_ptr + 4));
324 wrapper::vadd(v_in.val[2], wrapper::vloadq(bias_ptr + 8));
325 wrapper::vadd(v_in.val[3], wrapper::vloadq(bias_ptr + 12));
326 }
327
328 const auto out_ptr = reinterpret_cast<TOut *>(out.ptr());
Michalis Spyrou70d43a32020-06-22 17:05:43 +0100329 wrapper::vstore(out_ptr, finalize_quantization(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max, false));
Georgios Pinitasa799ce02018-09-12 20:11:34 +0100330 },
331 in, bi, out);
332}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100333} // namespace
334
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000335NEDirectConvolutionLayerOutputStageKernel::NEDirectConvolutionLayerOutputStageKernel()
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000336 : _func(nullptr), _input(nullptr), _bias(nullptr), _output(nullptr), _result_fixedpoint_multiplier(0), _result_shift(0), _result_offset_after_shift(0)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100337{
338}
339
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000340void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const ITensor *bias, ITensor *output,
Michele Di Giorgio45361932019-12-19 13:53:44 +0000341 const DirectConvolutionLayerOutputStageKernelInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100342{
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000343 // Perform validation step
Michele Di Giorgio45361932019-12-19 13:53:44 +0000344 ARM_COMPUTE_ERROR_ON_NULLPTR(input);
345 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info(), info));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100346
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000347 _func = nullptr;
348 _bias = bias;
349 _input = input;
Michele Di Giorgio45361932019-12-19 13:53:44 +0000350 _output = (output != nullptr) ? output : input;
351 _result_fixedpoint_multiplier = info.result_fixedpoint_multiplier;
352 _result_shift = info.result_shift;
353 _result_offset_after_shift = info.result_offset_after_shift;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100354
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100355 // Configure kernel window
Michele Di Giorgio45361932019-12-19 13:53:44 +0000356 auto win_config = validate_and_configure_window(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info(), info);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000357 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
358 INEKernel::configure(win_config.second);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100359
Michele Di Giorgio45361932019-12-19 13:53:44 +0000360 const bool has_bias = bias != nullptr;
361 const bool is_qasymm8_signed = (output != nullptr) ? is_data_type_quantized_asymmetric_signed(output->info()->data_type()) : false;
Gian Marco Iodice618493d2018-11-27 16:38:33 +0000362
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100363 // Set appropriate function
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100364 if(input->info()->data_layout() == DataLayout::NCHW)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100365 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100366 switch(input->info()->data_type())
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100367 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100368 case DataType::S32:
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000369 {
Michele Di Giorgio45361932019-12-19 13:53:44 +0000370 if(is_qasymm8_signed)
371 {
372 _func = (has_bias) ? &output_stage_nchw<int8_t, true> : &output_stage_nchw<int8_t, false>;
373 }
374 else
375 {
376 _func = (has_bias) ? &output_stage_nchw<uint8_t, true> : &output_stage_nchw<uint8_t, false>;
377 }
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100378 break;
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100379 }
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000380#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100381 case DataType::F16:
382 {
Michele Di Giorgio45361932019-12-19 13:53:44 +0000383 _func = (has_bias) ? &output_stage_nchw<float16_t, true> : &output_stage_nchw<float16_t, false>;
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100384 break;
385 }
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000386#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100387 case DataType::F32:
388 {
Michele Di Giorgio45361932019-12-19 13:53:44 +0000389 _func = (has_bias) ? &output_stage_nchw<float, true> : &output_stage_nchw<float, false>;
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100390 break;
391 }
392 default:
393 {
394 ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
395 }
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100396 }
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100397 }
398 else
399 {
400 switch(input->info()->data_type())
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100401 {
Georgios Pinitasa799ce02018-09-12 20:11:34 +0100402 case DataType::S32:
403 {
Michele Di Giorgio45361932019-12-19 13:53:44 +0000404 if(is_qasymm8_signed)
405 {
406 _func = (has_bias) ? &output_stage_nhwc<int8_t, true> : &output_stage_nhwc<int8_t, false>;
407 }
408 else
409 {
410 _func = (has_bias) ? &output_stage_nhwc<uint8_t, true> : &output_stage_nhwc<uint8_t, false>;
411 }
Georgios Pinitasa799ce02018-09-12 20:11:34 +0100412 break;
413 }
Georgios Pinitas20c246a2018-09-12 16:45:53 +0100414#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
415 case DataType::F16:
416 {
Michele Di Giorgio45361932019-12-19 13:53:44 +0000417 _func = (has_bias) ? &output_stage_nhwc<float16_t, true> : &output_stage_nhwc<float16_t, false>;
Georgios Pinitas20c246a2018-09-12 16:45:53 +0100418 break;
419 }
420#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100421 case DataType::F32:
422 {
Michele Di Giorgio45361932019-12-19 13:53:44 +0000423 _func = (has_bias) ? &output_stage_nhwc<float, true> : &output_stage_nhwc<float, false>;
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100424 break;
425 }
426 default:
427 {
428 ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
429 }
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100430 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100431 }
432}
433
Michele Di Giorgioff271922019-07-17 15:59:32 +0100434Status NEDirectConvolutionLayerOutputStageKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output,
Michele Di Giorgio45361932019-12-19 13:53:44 +0000435 const DirectConvolutionLayerOutputStageKernelInfo &info)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000436{
Michele Di Giorgio45361932019-12-19 13:53:44 +0000437 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output, info));
438 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
439 bias == nullptr ? nullptr : bias->clone().get(),
440 output == nullptr ? nullptr : output->clone().get(),
441 info)
442 .first);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000443
444 return Status{};
445}
446
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000447void NEDirectConvolutionLayerOutputStageKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100448{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100449 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100450 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
451 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
452 ARM_COMPUTE_ERROR_ON(_func == nullptr);
453
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000454 (*_func)(_input, _bias, window, _output, _result_fixedpoint_multiplier, _result_shift, _result_offset_after_shift);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100455}
Michele Di Giorgiof29d1b72019-10-29 10:58:13 +0000456} // namespace arm_compute