blob: eefbd98dd8c5816c5cf0dd9005a9b17b0a3ff9a8 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Georgios Pinitasf72f9362018-01-12 16:29:45 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyroub91e34c2017-12-20 15:50:55 +000024#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010025
26#include "arm_compute/core/AccessWindowStatic.h"
Anthony Barbiereaefd002018-07-20 17:49:35 +010027#include "arm_compute/core/CPP/Validate.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028#include "arm_compute/core/Error.h"
29#include "arm_compute/core/Helpers.h"
30#include "arm_compute/core/ITensor.h"
Georgios Pinitasf72f9362018-01-12 16:29:45 +000031#include "arm_compute/core/NEON/NEAsymm.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010032#include "arm_compute/core/NEON/NEFixedPoint.h"
33#include "arm_compute/core/Types.h"
34#include "arm_compute/core/Validate.h"
35#include "arm_compute/core/Window.h"
36
37#include <arm_neon.h>
38#include <cstddef>
39#include <cstdint>
40
41using namespace arm_compute;
42
43namespace
44{
Michalis Spyrouafa5d812017-11-30 14:25:57 +000045Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output)
46{
Anthony Barbiereaefd002018-07-20 17:49:35 +010047 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010048 ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010049 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8,
50 DataType::F16,
Vidhya Sudhan Loganathanf4cb81b2018-07-04 15:13:14 +010051 DataType::S32, DataType::F32);
Michalis Spyroub91e34c2017-12-20 15:50:55 +000052
53 if(bias != nullptr)
Michalis Spyrouafa5d812017-11-30 14:25:57 +000054 {
Vidhya Sudhan Loganathanf4cb81b2018-07-04 15:13:14 +010055 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::F16, DataType::S32, DataType::F32);
Michalis Spyroub91e34c2017-12-20 15:50:55 +000056
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010057 if(is_data_type_quantized_asymmetric(input->data_type()))
Georgios Pinitas19d05472018-02-01 16:44:12 +000058 {
59 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::S32);
60 }
Michalis Spyroub91e34c2017-12-20 15:50:55 +000061 else
62 {
63 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
64 }
65
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010066 ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)));
Michalis Spyroub91e34c2017-12-20 15:50:55 +000067 ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1);
Michalis Spyrouafa5d812017-11-30 14:25:57 +000068 }
69 else
70 {
Georgios Pinitas19d05472018-02-01 16:44:12 +000071 ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_data_type_float(input->data_type()), "Calling output stage kernel with floating point arguments");
Michalis Spyrouafa5d812017-11-30 14:25:57 +000072 }
73
Michalis Spyrouafa5d812017-11-30 14:25:57 +000074 // Checks performed when output is configured
75 if((output != nullptr) && (output->total_size() != 0))
76 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010077 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::F32);
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010078 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
79
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010080 if(is_data_type_quantized_asymmetric(output->data_type()))
Georgios Pinitasf72f9362018-01-12 16:29:45 +000081 {
82 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S32 && output->data_type() != DataType::QASYMM8, "Wrong data type for bias");
Michalis Spyroub91e34c2017-12-20 15:50:55 +000083 }
84 else
85 {
86 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
87 }
Michalis Spyrouafa5d812017-11-30 14:25:57 +000088 }
89
Michalis Spyrouafa5d812017-11-30 14:25:57 +000090 return Status{};
91}
92
93std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *bias, ITensorInfo *output)
94{
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010095 ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
96
Georgios Pinitasf72f9362018-01-12 16:29:45 +000097 bool window_changed = false;
98 unsigned int num_elems_processed_per_iteration = 16 / element_size_from_data_type(input->data_type());
99
100 // Update processed elements when input is S32 (comes from quantization input)
101 if(input->data_type() == DataType::S32)
102 {
103 num_elems_processed_per_iteration = 16;
104 }
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000105
106 // Configure kernel window
107 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
108 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000109
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000110 if(output != nullptr && (output->total_size() != 0))
111 {
112 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000113
114 if(bias == nullptr)
115 {
116 window_changed = update_window_and_padding(win, input_access, output_access);
117 }
118 else
119 {
120 AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
121 window_changed = update_window_and_padding(win, input_access, output_access, bias_access);
122 }
123
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000124 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
125 }
126 else
127 {
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000128 if(bias == nullptr)
129 {
130 window_changed = update_window_and_padding(win, input_access);
131 }
132 else
133 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100134 if(input->data_layout() == DataLayout::NCHW)
135 {
136 AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
137 window_changed = update_window_and_padding(win, input_access, bias_access);
138 }
139 else
140 {
141 AccessWindowHorizontal bias_access(bias, 0, num_elems_processed_per_iteration);
142 window_changed = update_window_and_padding(win, input_access, bias_access);
143 }
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000144 }
145
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000146 input_access.set_valid_region(win, ValidRegion(Coordinates(), input->tensor_shape()));
147 }
148
149 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
150 return std::make_pair(err, win);
151}
152
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100153// Internal load
154inline float32x4_t internal_vld1q(const float *in)
155{
156 return vld1q_f32(in);
157}
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100158
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100159// Internal store
160inline void internal_vst1q(float *p, const float32x4_t &v)
161{
162 vst1q_f32(p, v);
163}
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100164
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100165// Internal vdup
166inline float32x4_t internal_vdupq_n(float v)
167{
168 return vdupq_n_f32(v);
169}
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100170
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100171// Internal vadd
172inline float32x4_t internal_vqaddq(const float32x4_t &x, const float32x4_t &y)
173{
174 return vaddq_f32(x, y);
175}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100176
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// Half-precision overloads of the internal load/store/dup/add helpers; overload resolution on
// float16_t / float16x8_t picks these for the F16 kernel instantiations.
inline auto internal_vld1q(const float16_t *in) -> float16x8_t
{
    return vld1q_f16(in);
}
inline auto internal_vst1q(float16_t *p, const float16x8_t &v) -> void
{
    vst1q_f16(p, v);
}
inline auto internal_vdupq_n(float16_t v) -> float16x8_t
{
    return vdupq_n_f16(v);
}
inline auto internal_vqaddq(const float16x8_t &x, const float16x8_t &y) -> float16x8_t
{
    return vaddq_f16(x, y);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tello0d176142017-07-06 16:43:14 +0100195
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000196template <typename T1, typename T2, bool in_place, bool has_bias>
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000197void output_stage(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
198 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100199{
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100200 ARM_COMPUTE_ERROR_ON(input->info()->data_layout() == DataLayout::UNKNOWN);
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000201 ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
202 ARM_COMPUTE_UNUSED(result_shift);
203 ARM_COMPUTE_UNUSED(result_offset_after_shift);
204
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100205 Iterator in(input, window);
206
207 if(in_place) // In place accumulate
208 {
209 execute_window_loop(window, [&](const Coordinates & id)
210 {
211 // Get bias and pointer to input
212 const auto in_ptr = reinterpret_cast<T1 *>(in.ptr());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100213
214 // Accumulate bias
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000215 if(has_bias)
216 {
217 const auto vb = internal_vdupq_n(static_cast<T1>(*reinterpret_cast<const T2 *>(bias->ptr_to_element(Coordinates(id.z())))));
218 internal_vst1q(in_ptr, internal_vqaddq(internal_vld1q(in_ptr), vb));
219 }
220 else
221 {
222 internal_vst1q(in_ptr, internal_vld1q(in_ptr));
223 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100224 },
225 in);
226 }
227 else // Out of place accumulate
228 {
229 Iterator out(output, window);
230 execute_window_loop(window, [&](const Coordinates & id)
231 {
232 // Get bias and pointer to input
233 const auto in_ptr = reinterpret_cast<const T1 *>(in.ptr());
234 const auto out_ptr = reinterpret_cast<T2 *>(out.ptr());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100235
236 // Accumulate bias
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000237 if(has_bias)
238 {
239 const auto vb = internal_vdupq_n(static_cast<T1>(*reinterpret_cast<const T2 *>(bias->ptr_to_element(Coordinates(id.z())))));
240 internal_vst1q(out_ptr, internal_vqaddq(internal_vld1q(in_ptr), vb));
241 }
242 else
243 {
244 internal_vst1q(out_ptr, internal_vld1q(in_ptr));
245 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100246 },
247 in, out);
248 }
249}
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000250
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100251template <typename T1, typename T2, bool in_place, bool has_bias>
252void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
253 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
254{
255 ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
256 ARM_COMPUTE_UNUSED(result_shift);
257 ARM_COMPUTE_UNUSED(result_offset_after_shift);
258
259 Window window_bias = window;
260 window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
261 window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
262 window_bias.set(3, Window::Dimension(0, 0, 0));
263
264 Iterator in(input, window);
265 Iterator bi(bias, window_bias);
266
267 if(in_place) // In place accumulate
268 {
269 execute_window_loop(window, [&](const Coordinates & id)
270 {
271 // Get bias and pointer to input
272 const auto in_ptr = reinterpret_cast<T1 *>(in.ptr());
273 const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
274
275 // Accumulate bias
276 if(has_bias)
277 {
278 internal_vst1q(in_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
279 }
280 else
281 {
282 internal_vst1q(in_ptr, internal_vld1q(in_ptr));
283 }
284 },
285 in, bi);
286 }
287 else // Out of place accumulate
288 {
289 Iterator out(output, window);
290 execute_window_loop(window, [&](const Coordinates & id)
291 {
292 // Get bias and pointer to input
293 const auto in_ptr = reinterpret_cast<T1 *>(in.ptr());
294 const auto out_ptr = reinterpret_cast<T2 *>(out.ptr());
295 const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
296
297 // Accumulate bias
298 if(has_bias)
299 {
300 internal_vst1q(out_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
301 }
302 else
303 {
304 internal_vst1q(out_ptr, internal_vld1q(in_ptr));
305 }
306 },
307 in, bi);
308 }
309}
310
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000311// QASYMM8 specializations
312template <>
313void output_stage<int32_t, uint8_t, false, true>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
314 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
315{
316 const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
317 uint8x16_t min = vdupq_n_u8(0);
318 uint8x16_t max = vdupq_n_u8(255);
319
320 Iterator in(input, window);
321 Iterator out(output, window);
322
323 execute_window_loop(window, [&](const Coordinates & id)
324 {
325 // Get bias and pointer to input
326 const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr());
327 int32x4x4_t v_in =
328 {
329 {
330 vld1q_s32(in_ptr),
331 vld1q_s32(in_ptr + 4),
332 vld1q_s32(in_ptr + 8),
333 vld1q_s32(in_ptr + 12)
334 }
335 };
336
337 // Accumulate bias
338 const auto vb = vdupq_n_s32(*reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z()))));
339 v_in =
340 {
341 {
342 vaddq_s32(v_in.val[0], vb),
343 vaddq_s32(v_in.val[1], vb),
344 vaddq_s32(v_in.val[2], vb),
345 vaddq_s32(v_in.val[3], vb)
346 }
347 };
348
349 const auto out_ptr = reinterpret_cast<uint8_t *>(out.ptr());
350 vst1q_u8(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
351 },
352 in, out);
353}
354template <>
355void output_stage<int32_t, uint8_t, false, false>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
356 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
357{
358 ARM_COMPUTE_UNUSED(bias);
359
360 const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
361 uint8x16_t min = vdupq_n_u8(0);
362 uint8x16_t max = vdupq_n_u8(255);
363
364 Iterator in(input, window);
365 Iterator out(output, window);
366 execute_window_loop(window, [&](const Coordinates & id)
367 {
368 // Get bias and pointer to input
369 const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr());
370 int32x4x4_t v_in =
371 {
372 {
373 vld1q_s32(in_ptr),
374 vld1q_s32(in_ptr + 4),
375 vld1q_s32(in_ptr + 8),
376 vld1q_s32(in_ptr + 12)
377 }
378 };
379
380 const auto out_ptr = reinterpret_cast<uint8_t *>(out.ptr());
381 vst1q_u8(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
382 },
383 in, out);
384}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100385} // namespace
386
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000387NEDirectConvolutionLayerOutputStageKernel::NEDirectConvolutionLayerOutputStageKernel()
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000388 : _func(nullptr), _input(nullptr), _bias(nullptr), _output(nullptr), _result_fixedpoint_multiplier(0), _result_shift(0), _result_offset_after_shift(0)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100389{
390}
391
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000392void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const ITensor *bias, ITensor *output,
393 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100394{
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000395 ARM_COMPUTE_ERROR_ON_NULLPTR(input);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000396
Georgios Pinitas0223a782017-12-12 11:44:44 +0000397 // Auto-initialize output output if required
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100398 if(output != nullptr)
399 {
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000400 // Work out expected output data type
401 const DataType output_dt = (input->info()->data_type() == DataType::S32) ? DataType::QASYMM8 : input->info()->data_type();
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000402 // Output tensor auto initialization if not yet initialized
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000403 auto_init_if_empty(*output->info(), input->info()->clone()->set_data_type(output_dt));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100404 }
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000405
406 // Perform validation step
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000407 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100408
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000409 _func = nullptr;
410 _bias = bias;
411 _input = input;
412 _output = output;
413 _result_fixedpoint_multiplier = result_fixedpoint_multiplier;
414 _result_shift = result_shift;
415 _result_offset_after_shift = result_offset_after_shift;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100416
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100417 // Configure kernel window
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000418 auto win_config = validate_and_configure_window(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info());
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000419 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
420 INEKernel::configure(win_config.second);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100421
422 // Set appropriate function
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100423 if(input->info()->data_layout() == DataLayout::NCHW)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100424 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100425 switch(input->info()->data_type())
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100426 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100427 case DataType::S32:
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000428 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100429 _func = (bias == nullptr) ? &output_stage<int32_t, uint8_t, false, false> : &output_stage<int32_t, uint8_t, false, true>;
430 break;
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100431 }
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000432#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100433 case DataType::F16:
434 {
435 _func = (output == nullptr) ? &output_stage<float16_t, float16_t, true, true> : &output_stage<float16_t, float16_t, false, true>;
436 break;
437 }
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000438#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100439 case DataType::F32:
440 {
441 _func = (output == nullptr) ? &output_stage<float, float, true, true> : &output_stage<float, float, false, true>;
442 break;
443 }
444 default:
445 {
446 ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
447 }
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100448 }
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100449 }
450 else
451 {
452 switch(input->info()->data_type())
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100453 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100454 case DataType::F32:
455 {
456 _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
457 break;
458 }
459 default:
460 {
461 ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
462 }
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100463 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100464 }
465}
466
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000467Status NEDirectConvolutionLayerOutputStageKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000468{
469 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output));
Anthony Barbierde014682018-07-03 15:10:48 +0100470 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), bias == nullptr ? nullptr : bias->clone().get(), output == nullptr ? nullptr : output->clone().get()).first);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000471
472 return Status{};
473}
474
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000475void NEDirectConvolutionLayerOutputStageKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100476{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100477 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100478 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
479 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
480 ARM_COMPUTE_ERROR_ON(_func == nullptr);
481
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000482 (*_func)(_input, _bias, window, _output, _result_fixedpoint_multiplier, _result_shift, _result_offset_after_shift);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100483}