blob: e4cd4d04658f756cc210e582e1772b5d23c44628 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Georgios Pinitasf72f9362018-01-12 16:29:45 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyroub91e34c2017-12-20 15:50:55 +000024#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010025
26#include "arm_compute/core/AccessWindowStatic.h"
27#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/ITensor.h"
Georgios Pinitasf72f9362018-01-12 16:29:45 +000030#include "arm_compute/core/NEON/NEAsymm.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010031#include "arm_compute/core/NEON/NEFixedPoint.h"
32#include "arm_compute/core/Types.h"
33#include "arm_compute/core/Validate.h"
34#include "arm_compute/core/Window.h"
35
36#include <arm_neon.h>
37#include <cstddef>
38#include <cstdint>
39
40using namespace arm_compute;
41
42namespace
43{
Michalis Spyrouafa5d812017-11-30 14:25:57 +000044Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output)
45{
Michalis Spyroub91e34c2017-12-20 15:50:55 +000046 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010047 ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010048 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8,
49 DataType::F16,
Georgios Pinitasf72f9362018-01-12 16:29:45 +000050 DataType::QS32, DataType::S32, DataType::F32);
Michalis Spyroub91e34c2017-12-20 15:50:55 +000051
52 if(bias != nullptr)
Michalis Spyrouafa5d812017-11-30 14:25:57 +000053 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010054 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::F16, DataType::QS32, DataType::S32, DataType::F32);
Michalis Spyroub91e34c2017-12-20 15:50:55 +000055
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010056 if(is_data_type_quantized_asymmetric(input->data_type()))
Georgios Pinitas19d05472018-02-01 16:44:12 +000057 {
58 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::S32);
59 }
Michalis Spyroub91e34c2017-12-20 15:50:55 +000060 else
61 {
62 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
63 }
64
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010065 ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)));
Michalis Spyroub91e34c2017-12-20 15:50:55 +000066 ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1);
Michalis Spyrouafa5d812017-11-30 14:25:57 +000067 }
68 else
69 {
Georgios Pinitas19d05472018-02-01 16:44:12 +000070 ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_data_type_float(input->data_type()), "Calling output stage kernel with floating point arguments");
Michalis Spyrouafa5d812017-11-30 14:25:57 +000071 }
72
Michalis Spyrouafa5d812017-11-30 14:25:57 +000073 // Checks performed when output is configured
74 if((output != nullptr) && (output->total_size() != 0))
75 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010076 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QASYMM8, DataType::F32);
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010077 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
78
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010079 if(is_data_type_quantized_asymmetric(output->data_type()))
Georgios Pinitasf72f9362018-01-12 16:29:45 +000080 {
81 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S32 && output->data_type() != DataType::QASYMM8, "Wrong data type for bias");
Michalis Spyroub91e34c2017-12-20 15:50:55 +000082 }
83 else
84 {
85 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
86 }
Michalis Spyrouafa5d812017-11-30 14:25:57 +000087 }
88
Michalis Spyrouafa5d812017-11-30 14:25:57 +000089 return Status{};
90}
91
92std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *bias, ITensorInfo *output)
93{
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010094 ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
95
Georgios Pinitasf72f9362018-01-12 16:29:45 +000096 bool window_changed = false;
97 unsigned int num_elems_processed_per_iteration = 16 / element_size_from_data_type(input->data_type());
98
99 // Update processed elements when input is S32 (comes from quantization input)
100 if(input->data_type() == DataType::S32)
101 {
102 num_elems_processed_per_iteration = 16;
103 }
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000104
105 // Configure kernel window
106 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
107 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000108
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000109 if(output != nullptr && (output->total_size() != 0))
110 {
111 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000112
113 if(bias == nullptr)
114 {
115 window_changed = update_window_and_padding(win, input_access, output_access);
116 }
117 else
118 {
119 AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
120 window_changed = update_window_and_padding(win, input_access, output_access, bias_access);
121 }
122
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000123 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
124 }
125 else
126 {
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000127 if(bias == nullptr)
128 {
129 window_changed = update_window_and_padding(win, input_access);
130 }
131 else
132 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100133 if(input->data_layout() == DataLayout::NCHW)
134 {
135 AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
136 window_changed = update_window_and_padding(win, input_access, bias_access);
137 }
138 else
139 {
140 AccessWindowHorizontal bias_access(bias, 0, num_elems_processed_per_iteration);
141 window_changed = update_window_and_padding(win, input_access, bias_access);
142 }
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000143 }
144
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000145 input_access.set_valid_region(win, ValidRegion(Coordinates(), input->tensor_shape()));
146 }
147
148 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
149 return std::make_pair(err, win);
150}
151
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100152// Internal load
153inline float32x4_t internal_vld1q(const float *in)
154{
155 return vld1q_f32(in);
156}
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100157
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100158// Internal store
159inline void internal_vst1q(float *p, const float32x4_t &v)
160{
161 vst1q_f32(p, v);
162}
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100163
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100164// Internal vdup
165inline float32x4_t internal_vdupq_n(float v)
166{
167 return vdupq_n_f32(v);
168}
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100169
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100170// Internal vadd
171inline float32x4_t internal_vqaddq(const float32x4_t &x, const float32x4_t &y)
172{
173 return vaddq_f32(x, y);
174}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100175
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Load a 128-bit vector of half precision values (thin wrapper around vld1q_f16).
 *
 * @param[in] in Pointer to the first element to load.
 *
 * @return The loaded vector.
 */
inline float16x8_t internal_vld1q(const float16_t *in)
{
    const float16x8_t loaded = vld1q_f16(in);
    return loaded;
}

/** Store a 128-bit vector of half precision values (thin wrapper around vst1q_f16).
 *
 * @param[out] p Destination pointer.
 * @param[in]  v Vector to store.
 */
inline void internal_vst1q(float16_t *p, const float16x8_t &v)
{
    float16_t *const destination = p;
    vst1q_f16(destination, v);
}

/** Broadcast a half precision scalar to every lane of a vector (thin wrapper around vdupq_n_f16).
 *
 * @param[in] v Scalar to broadcast.
 *
 * @return Vector with all lanes set to @p v.
 */
inline float16x8_t internal_vdupq_n(float16_t v)
{
    const float16x8_t broadcast = vdupq_n_f16(v);
    return broadcast;
}

/** Lane-wise addition of two half precision vectors (forwards to vaddq_f16).
 *
 * @param[in] x First operand.
 * @param[in] y Second operand.
 *
 * @return x + y, computed per lane.
 */
inline float16x8_t internal_vqaddq(const float16x8_t &x, const float16x8_t &y)
{
    const float16x8_t sum = vaddq_f16(x, y);
    return sum;
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tello0d176142017-07-06 16:43:14 +0100194
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000195template <typename T1, typename T2, bool in_place, bool has_bias>
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000196void output_stage(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
197 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100198{
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100199 ARM_COMPUTE_ERROR_ON(input->info()->data_layout() == DataLayout::UNKNOWN);
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000200 ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
201 ARM_COMPUTE_UNUSED(result_shift);
202 ARM_COMPUTE_UNUSED(result_offset_after_shift);
203
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100204 Iterator in(input, window);
205
206 if(in_place) // In place accumulate
207 {
208 execute_window_loop(window, [&](const Coordinates & id)
209 {
210 // Get bias and pointer to input
211 const auto in_ptr = reinterpret_cast<T1 *>(in.ptr());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100212
213 // Accumulate bias
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000214 if(has_bias)
215 {
216 const auto vb = internal_vdupq_n(static_cast<T1>(*reinterpret_cast<const T2 *>(bias->ptr_to_element(Coordinates(id.z())))));
217 internal_vst1q(in_ptr, internal_vqaddq(internal_vld1q(in_ptr), vb));
218 }
219 else
220 {
221 internal_vst1q(in_ptr, internal_vld1q(in_ptr));
222 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100223 },
224 in);
225 }
226 else // Out of place accumulate
227 {
228 Iterator out(output, window);
229 execute_window_loop(window, [&](const Coordinates & id)
230 {
231 // Get bias and pointer to input
232 const auto in_ptr = reinterpret_cast<const T1 *>(in.ptr());
233 const auto out_ptr = reinterpret_cast<T2 *>(out.ptr());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100234
235 // Accumulate bias
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000236 if(has_bias)
237 {
238 const auto vb = internal_vdupq_n(static_cast<T1>(*reinterpret_cast<const T2 *>(bias->ptr_to_element(Coordinates(id.z())))));
239 internal_vst1q(out_ptr, internal_vqaddq(internal_vld1q(in_ptr), vb));
240 }
241 else
242 {
243 internal_vst1q(out_ptr, internal_vld1q(in_ptr));
244 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100245 },
246 in, out);
247 }
248}
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000249
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100250template <typename T1, typename T2, bool in_place, bool has_bias>
251void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
252 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
253{
254 ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
255 ARM_COMPUTE_UNUSED(result_shift);
256 ARM_COMPUTE_UNUSED(result_offset_after_shift);
257
258 Window window_bias = window;
259 window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
260 window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
261 window_bias.set(3, Window::Dimension(0, 0, 0));
262
263 Iterator in(input, window);
264 Iterator bi(bias, window_bias);
265
266 if(in_place) // In place accumulate
267 {
268 execute_window_loop(window, [&](const Coordinates & id)
269 {
270 // Get bias and pointer to input
271 const auto in_ptr = reinterpret_cast<T1 *>(in.ptr());
272 const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
273
274 // Accumulate bias
275 if(has_bias)
276 {
277 internal_vst1q(in_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
278 }
279 else
280 {
281 internal_vst1q(in_ptr, internal_vld1q(in_ptr));
282 }
283 },
284 in, bi);
285 }
286 else // Out of place accumulate
287 {
288 Iterator out(output, window);
289 execute_window_loop(window, [&](const Coordinates & id)
290 {
291 // Get bias and pointer to input
292 const auto in_ptr = reinterpret_cast<T1 *>(in.ptr());
293 const auto out_ptr = reinterpret_cast<T2 *>(out.ptr());
294 const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
295
296 // Accumulate bias
297 if(has_bias)
298 {
299 internal_vst1q(out_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
300 }
301 else
302 {
303 internal_vst1q(out_ptr, internal_vld1q(in_ptr));
304 }
305 },
306 in, bi);
307 }
308}
309
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000310// QASYMM8 specializations
311template <>
312void output_stage<int32_t, uint8_t, false, true>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
313 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
314{
315 const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
316 uint8x16_t min = vdupq_n_u8(0);
317 uint8x16_t max = vdupq_n_u8(255);
318
319 Iterator in(input, window);
320 Iterator out(output, window);
321
322 execute_window_loop(window, [&](const Coordinates & id)
323 {
324 // Get bias and pointer to input
325 const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr());
326 int32x4x4_t v_in =
327 {
328 {
329 vld1q_s32(in_ptr),
330 vld1q_s32(in_ptr + 4),
331 vld1q_s32(in_ptr + 8),
332 vld1q_s32(in_ptr + 12)
333 }
334 };
335
336 // Accumulate bias
337 const auto vb = vdupq_n_s32(*reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z()))));
338 v_in =
339 {
340 {
341 vaddq_s32(v_in.val[0], vb),
342 vaddq_s32(v_in.val[1], vb),
343 vaddq_s32(v_in.val[2], vb),
344 vaddq_s32(v_in.val[3], vb)
345 }
346 };
347
348 const auto out_ptr = reinterpret_cast<uint8_t *>(out.ptr());
349 vst1q_u8(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
350 },
351 in, out);
352}
353template <>
354void output_stage<int32_t, uint8_t, false, false>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
355 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
356{
357 ARM_COMPUTE_UNUSED(bias);
358
359 const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
360 uint8x16_t min = vdupq_n_u8(0);
361 uint8x16_t max = vdupq_n_u8(255);
362
363 Iterator in(input, window);
364 Iterator out(output, window);
365 execute_window_loop(window, [&](const Coordinates & id)
366 {
367 // Get bias and pointer to input
368 const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr());
369 int32x4x4_t v_in =
370 {
371 {
372 vld1q_s32(in_ptr),
373 vld1q_s32(in_ptr + 4),
374 vld1q_s32(in_ptr + 8),
375 vld1q_s32(in_ptr + 12)
376 }
377 };
378
379 const auto out_ptr = reinterpret_cast<uint8_t *>(out.ptr());
380 vst1q_u8(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
381 },
382 in, out);
383}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100384} // namespace
385
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000386NEDirectConvolutionLayerOutputStageKernel::NEDirectConvolutionLayerOutputStageKernel()
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000387 : _func(nullptr), _input(nullptr), _bias(nullptr), _output(nullptr), _result_fixedpoint_multiplier(0), _result_shift(0), _result_offset_after_shift(0)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100388{
389}
390
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000391void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const ITensor *bias, ITensor *output,
392 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100393{
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000394 ARM_COMPUTE_ERROR_ON_NULLPTR(input);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000395
Georgios Pinitas0223a782017-12-12 11:44:44 +0000396 // Auto-initialize output output if required
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100397 if(output != nullptr)
398 {
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000399 // Work out expected output data type
400 const DataType output_dt = (input->info()->data_type() == DataType::S32) ? DataType::QASYMM8 : input->info()->data_type();
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000401 // Output tensor auto initialization if not yet initialized
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000402 auto_init_if_empty(*output->info(), input->info()->clone()->set_data_type(output_dt));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100403 }
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000404
405 // Perform validation step
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000406 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100407
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000408 _func = nullptr;
409 _bias = bias;
410 _input = input;
411 _output = output;
412 _result_fixedpoint_multiplier = result_fixedpoint_multiplier;
413 _result_shift = result_shift;
414 _result_offset_after_shift = result_offset_after_shift;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100415
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100416 // Configure kernel window
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000417 auto win_config = validate_and_configure_window(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info());
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000418 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
419 INEKernel::configure(win_config.second);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100420
421 // Set appropriate function
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100422 if(input->info()->data_layout() == DataLayout::NCHW)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100423 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100424 switch(input->info()->data_type())
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100425 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100426 case DataType::S32:
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000427 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100428 _func = (bias == nullptr) ? &output_stage<int32_t, uint8_t, false, false> : &output_stage<int32_t, uint8_t, false, true>;
429 break;
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100430 }
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000431#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100432 case DataType::F16:
433 {
434 _func = (output == nullptr) ? &output_stage<float16_t, float16_t, true, true> : &output_stage<float16_t, float16_t, false, true>;
435 break;
436 }
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000437#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100438 case DataType::F32:
439 {
440 _func = (output == nullptr) ? &output_stage<float, float, true, true> : &output_stage<float, float, false, true>;
441 break;
442 }
443 default:
444 {
445 ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
446 }
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100447 }
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100448 }
449 else
450 {
451 switch(input->info()->data_type())
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100452 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100453 case DataType::F32:
454 {
455 _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
456 break;
457 }
458 default:
459 {
460 ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
461 }
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100462 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100463 }
464}
465
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000466Status NEDirectConvolutionLayerOutputStageKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000467{
468 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output));
Anthony Barbierde014682018-07-03 15:10:48 +0100469 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), bias == nullptr ? nullptr : bias->clone().get(), output == nullptr ? nullptr : output->clone().get()).first);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000470
471 return Status{};
472}
473
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000474void NEDirectConvolutionLayerOutputStageKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100475{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100476 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100477 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
478 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
479 ARM_COMPUTE_ERROR_ON(_func == nullptr);
480
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000481 (*_func)(_input, _bias, window, _output, _result_fixedpoint_multiplier, _result_shift, _result_offset_after_shift);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100482}