blob: 37a380428999c7b0abb12d8b26427b51658ff580 [file] [log] [blame]
/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
Michalis Spyroub91e34c2017-12-20 15:50:55 +000024#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerOutputStageKernel.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010025
26#include "arm_compute/core/AccessWindowStatic.h"
27#include "arm_compute/core/Error.h"
28#include "arm_compute/core/Helpers.h"
29#include "arm_compute/core/ITensor.h"
Georgios Pinitasf72f9362018-01-12 16:29:45 +000030#include "arm_compute/core/NEON/NEAsymm.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010031#include "arm_compute/core/NEON/NEFixedPoint.h"
32#include "arm_compute/core/Types.h"
33#include "arm_compute/core/Validate.h"
34#include "arm_compute/core/Window.h"
35
36#include <arm_neon.h>
37#include <cstddef>
38#include <cstdint>
39
40using namespace arm_compute;
41
42namespace
43{
Michalis Spyrouafa5d812017-11-30 14:25:57 +000044Status validate_arguments(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output)
45{
Michalis Spyroub91e34c2017-12-20 15:50:55 +000046 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010047 ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
Georgios Pinitasf72f9362018-01-12 16:29:45 +000048 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QASYMM8,
49 DataType::QS16, DataType::F16,
50 DataType::QS32, DataType::S32, DataType::F32);
Michalis Spyroub91e34c2017-12-20 15:50:55 +000051
52 if(bias != nullptr)
Michalis Spyrouafa5d812017-11-30 14:25:57 +000053 {
Georgios Pinitasf72f9362018-01-12 16:29:45 +000054 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::QS32, DataType::S32, DataType::F32);
Michalis Spyroub91e34c2017-12-20 15:50:55 +000055
Georgios Pinitasf72f9362018-01-12 16:29:45 +000056 if(is_data_type_fixed_point(input->data_type()))
Michalis Spyroub91e34c2017-12-20 15:50:55 +000057 {
58 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS8 && bias->data_type() != DataType::QS8, "Wrong data type for bias");
59 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS16 && bias->data_type() != DataType::QS8, "Wrong data type for bias");
60 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS32 && bias->data_type() != DataType::QS16, "Wrong data type for bias");
Georgios Pinitasf72f9362018-01-12 16:29:45 +000061 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, bias);
Michalis Spyroub91e34c2017-12-20 15:50:55 +000062 }
Georgios Pinitas19d05472018-02-01 16:44:12 +000063 else if(is_data_type_quantized_asymmetric(input->data_type()))
64 {
65 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::S32);
66 }
Michalis Spyroub91e34c2017-12-20 15:50:55 +000067 else
68 {
69 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, bias);
70 }
71
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010072 ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != input->dimension(get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL)));
Michalis Spyroub91e34c2017-12-20 15:50:55 +000073 ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1);
Michalis Spyrouafa5d812017-11-30 14:25:57 +000074 }
75 else
76 {
Georgios Pinitas19d05472018-02-01 16:44:12 +000077 ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_data_type_float(input->data_type()), "Calling output stage kernel with floating point arguments");
Michalis Spyrouafa5d812017-11-30 14:25:57 +000078 }
79
Michalis Spyrouafa5d812017-11-30 14:25:57 +000080 // Checks performed when output is configured
81 if((output != nullptr) && (output->total_size() != 0))
82 {
Georgios Pinitasf72f9362018-01-12 16:29:45 +000083 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::QASYMM8, DataType::QS16, DataType::F32);
Giorgio Arena1ed1fc62018-03-26 16:20:05 +010084 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
85
Georgios Pinitasf72f9362018-01-12 16:29:45 +000086 if(is_data_type_fixed_point(input->data_type()))
Michalis Spyroub91e34c2017-12-20 15:50:55 +000087 {
88 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS8 && output->data_type() != DataType::QS8, "Wrong data type for output");
89 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS16 && output->data_type() != DataType::QS8, "Wrong data type for output");
90 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::QS32 && output->data_type() != DataType::QS16, "Wrong data type for output");
Georgios Pinitasf72f9362018-01-12 16:29:45 +000091 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT_POSITION(input, output);
92 }
93 else if(is_data_type_quantized_asymmetric(output->data_type()))
94 {
95 ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_type() == DataType::S32 && output->data_type() != DataType::QASYMM8, "Wrong data type for bias");
Michalis Spyroub91e34c2017-12-20 15:50:55 +000096 }
97 else
98 {
99 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
100 }
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000101 }
102
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000103 return Status{};
104}
105
106std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *bias, ITensorInfo *output)
107{
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100108 ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
109
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000110 bool window_changed = false;
111 unsigned int num_elems_processed_per_iteration = 16 / element_size_from_data_type(input->data_type());
112
113 // Update processed elements when input is S32 (comes from quantization input)
114 if(input->data_type() == DataType::S32)
115 {
116 num_elems_processed_per_iteration = 16;
117 }
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000118
119 // Configure kernel window
120 Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration));
121 AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration);
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000122
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000123 if(output != nullptr && (output->total_size() != 0))
124 {
125 AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000126
127 if(bias == nullptr)
128 {
129 window_changed = update_window_and_padding(win, input_access, output_access);
130 }
131 else
132 {
133 AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
134 window_changed = update_window_and_padding(win, input_access, output_access, bias_access);
135 }
136
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000137 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
138 }
139 else
140 {
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000141 if(bias == nullptr)
142 {
143 window_changed = update_window_and_padding(win, input_access);
144 }
145 else
146 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100147 if(input->data_layout() == DataLayout::NCHW)
148 {
149 AccessWindowStatic bias_access(bias, 0, 0, bias->dimension(0), bias->dimension(1));
150 window_changed = update_window_and_padding(win, input_access, bias_access);
151 }
152 else
153 {
154 AccessWindowHorizontal bias_access(bias, 0, num_elems_processed_per_iteration);
155 window_changed = update_window_and_padding(win, input_access, bias_access);
156 }
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000157 }
158
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000159 input_access.set_valid_region(win, ValidRegion(Coordinates(), input->tensor_shape()));
160 }
161
162 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
163 return std::make_pair(err, win);
164}
165
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100166// Internal load
167inline float32x4_t internal_vld1q(const float *in)
168{
169 return vld1q_f32(in);
170}
171inline qint8x16_t internal_vld1q(const qint8_t *in)
172{
173 return vld1q_qs8(in);
174}
175inline qint16x8_t internal_vld1q(const qint16_t *in)
176{
177 return vld1q_qs16(in);
178}
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100179inline qint32x4_t internal_vld1q(const qint32_t *in)
180{
181 return vld1q_s32(in);
182}
183
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100184// Internal store
185inline void internal_vst1q(float *p, const float32x4_t &v)
186{
187 vst1q_f32(p, v);
188}
189inline void internal_vst1q(qint8_t *p, const qint8x16_t &v)
190{
191 vst1q_qs8(p, v);
192}
193inline void internal_vst1q(qint8_t *p, const qint16x8_t &v)
194{
195 vst1_qs8(p, vqmovn_s16(v));
196}
197inline void internal_vst1q(qint16_t *p, const qint16x8_t &v)
198{
199 vst1q_qs16(p, v);
200}
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100201inline void internal_vst1q(qint32_t *p, const qint32x4_t &v)
202{
203 vst1q_s32(p, v);
204}
205
206inline void internal_vst1q(qint16_t *p, const qint32x4_t &v)
207{
208 vst1_qs16(p, vqmovn_qs32(v));
209}
210
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100211// Internal vdup
212inline float32x4_t internal_vdupq_n(float v)
213{
214 return vdupq_n_f32(v);
215}
216inline qint8x16_t internal_vdupq_n(qint8_t v)
217{
218 return vdupq_n_qs8(v);
219}
220inline qint16x8_t internal_vdupq_n(qint16_t v)
221{
222 return vdupq_n_qs16(v);
223}
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100224inline qint32x4_t internal_vdupq_n(qint32_t v)
225{
226 return vdupq_n_qs32(v);
227}
228
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100229// Internal vadd
230inline float32x4_t internal_vqaddq(const float32x4_t &x, const float32x4_t &y)
231{
232 return vaddq_f32(x, y);
233}
234inline qint8x16_t internal_vqaddq(const qint8x16_t &x, const qint8x16_t &y)
235{
236 return vqaddq_qs8(x, y);
237}
238inline qint16x8_t internal_vqaddq(const qint16x8_t &x, const qint16x8_t &y)
239{
240 return vqaddq_qs16(x, y);
241}
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100242inline qint32x4_t internal_vqaddq(const qint32x4_t &x, const qint32x4_t &y)
243{
244 return vqaddq_qs32(x, y);
245}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100246
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// Half precision overloads of the internal helpers, only built when the target
// advertises FP16 vector arithmetic
/** Load a float16x8_t vector from @p src. */
inline float16x8_t internal_vld1q(const float16_t *src)
{
    return vld1q_f16(src);
}
/** Store a float16x8_t vector to @p dst. */
inline void internal_vst1q(float16_t *dst, const float16x8_t &value)
{
    vst1q_f16(dst, value);
}
/** Broadcast @p scalar into a float16x8_t vector. */
inline float16x8_t internal_vdupq_n(float16_t scalar)
{
    return vdupq_n_f16(scalar);
}
/** Add two float16x8_t vectors lane by lane (plain add, vaddq_f16). */
inline float16x8_t internal_vqaddq(const float16x8_t &lhs, const float16x8_t &rhs)
{
    return vaddq_f16(lhs, rhs);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tello0d176142017-07-06 16:43:14 +0100265
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000266template <typename T1, typename T2, bool in_place, bool has_bias>
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000267void output_stage(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
268 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100269{
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100270 ARM_COMPUTE_ERROR_ON(input->info()->data_layout() == DataLayout::UNKNOWN);
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000271 ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
272 ARM_COMPUTE_UNUSED(result_shift);
273 ARM_COMPUTE_UNUSED(result_offset_after_shift);
274
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100275 Iterator in(input, window);
276
277 if(in_place) // In place accumulate
278 {
279 execute_window_loop(window, [&](const Coordinates & id)
280 {
281 // Get bias and pointer to input
282 const auto in_ptr = reinterpret_cast<T1 *>(in.ptr());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100283
284 // Accumulate bias
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000285 if(has_bias)
286 {
287 const auto vb = internal_vdupq_n(static_cast<T1>(*reinterpret_cast<const T2 *>(bias->ptr_to_element(Coordinates(id.z())))));
288 internal_vst1q(in_ptr, internal_vqaddq(internal_vld1q(in_ptr), vb));
289 }
290 else
291 {
292 internal_vst1q(in_ptr, internal_vld1q(in_ptr));
293 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100294 },
295 in);
296 }
297 else // Out of place accumulate
298 {
299 Iterator out(output, window);
300 execute_window_loop(window, [&](const Coordinates & id)
301 {
302 // Get bias and pointer to input
303 const auto in_ptr = reinterpret_cast<const T1 *>(in.ptr());
304 const auto out_ptr = reinterpret_cast<T2 *>(out.ptr());
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100305
306 // Accumulate bias
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000307 if(has_bias)
308 {
309 const auto vb = internal_vdupq_n(static_cast<T1>(*reinterpret_cast<const T2 *>(bias->ptr_to_element(Coordinates(id.z())))));
310 internal_vst1q(out_ptr, internal_vqaddq(internal_vld1q(in_ptr), vb));
311 }
312 else
313 {
314 internal_vst1q(out_ptr, internal_vld1q(in_ptr));
315 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100316 },
317 in, out);
318 }
319}
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000320
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100321template <typename T1, typename T2, bool in_place, bool has_bias>
322void output_stage_nhwc(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
323 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
324{
325 ARM_COMPUTE_UNUSED(result_fixedpoint_multiplier);
326 ARM_COMPUTE_UNUSED(result_shift);
327 ARM_COMPUTE_UNUSED(result_offset_after_shift);
328
329 Window window_bias = window;
330 window_bias.set(Window::DimY, Window::Dimension(0, 0, 0));
331 window_bias.set(Window::DimZ, Window::Dimension(0, 0, 0));
332 window_bias.set(3, Window::Dimension(0, 0, 0));
333
334 Iterator in(input, window);
335 Iterator bi(bias, window_bias);
336
337 if(in_place) // In place accumulate
338 {
339 execute_window_loop(window, [&](const Coordinates & id)
340 {
341 // Get bias and pointer to input
342 const auto in_ptr = reinterpret_cast<T1 *>(in.ptr());
343 const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
344
345 // Accumulate bias
346 if(has_bias)
347 {
348 internal_vst1q(in_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
349 }
350 else
351 {
352 internal_vst1q(in_ptr, internal_vld1q(in_ptr));
353 }
354 },
355 in, bi);
356 }
357 else // Out of place accumulate
358 {
359 Iterator out(output, window);
360 execute_window_loop(window, [&](const Coordinates & id)
361 {
362 // Get bias and pointer to input
363 const auto in_ptr = reinterpret_cast<T1 *>(in.ptr());
364 const auto out_ptr = reinterpret_cast<T2 *>(out.ptr());
365 const auto bias_ptr = reinterpret_cast<T2 *>(bi.ptr());
366
367 // Accumulate bias
368 if(has_bias)
369 {
370 internal_vst1q(out_ptr, internal_vqaddq(internal_vld1q(in_ptr), internal_vld1q(bias_ptr)));
371 }
372 else
373 {
374 internal_vst1q(out_ptr, internal_vld1q(in_ptr));
375 }
376 },
377 in, bi);
378 }
379}
380
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000381// QASYMM8 specializations
382template <>
383void output_stage<int32_t, uint8_t, false, true>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
384 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
385{
386 const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
387 uint8x16_t min = vdupq_n_u8(0);
388 uint8x16_t max = vdupq_n_u8(255);
389
390 Iterator in(input, window);
391 Iterator out(output, window);
392
393 execute_window_loop(window, [&](const Coordinates & id)
394 {
395 // Get bias and pointer to input
396 const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr());
397 int32x4x4_t v_in =
398 {
399 {
400 vld1q_s32(in_ptr),
401 vld1q_s32(in_ptr + 4),
402 vld1q_s32(in_ptr + 8),
403 vld1q_s32(in_ptr + 12)
404 }
405 };
406
407 // Accumulate bias
408 const auto vb = vdupq_n_s32(*reinterpret_cast<const int32_t *>(bias->ptr_to_element(Coordinates(id.z()))));
409 v_in =
410 {
411 {
412 vaddq_s32(v_in.val[0], vb),
413 vaddq_s32(v_in.val[1], vb),
414 vaddq_s32(v_in.val[2], vb),
415 vaddq_s32(v_in.val[3], vb)
416 }
417 };
418
419 const auto out_ptr = reinterpret_cast<uint8_t *>(out.ptr());
420 vst1q_u8(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
421 },
422 in, out);
423}
424template <>
425void output_stage<int32_t, uint8_t, false, false>(ITensor *input, const ITensor *bias, const Window &window, ITensor *output,
426 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
427{
428 ARM_COMPUTE_UNUSED(bias);
429
430 const int32x4_t result_offset_after_shift_s32 = vdupq_n_s32(result_offset_after_shift);
431 uint8x16_t min = vdupq_n_u8(0);
432 uint8x16_t max = vdupq_n_u8(255);
433
434 Iterator in(input, window);
435 Iterator out(output, window);
436 execute_window_loop(window, [&](const Coordinates & id)
437 {
438 // Get bias and pointer to input
439 const auto in_ptr = reinterpret_cast<int32_t *>(in.ptr());
440 int32x4x4_t v_in =
441 {
442 {
443 vld1q_s32(in_ptr),
444 vld1q_s32(in_ptr + 4),
445 vld1q_s32(in_ptr + 8),
446 vld1q_s32(in_ptr + 12)
447 }
448 };
449
450 const auto out_ptr = reinterpret_cast<uint8_t *>(out.ptr());
451 vst1q_u8(out_ptr, finalize_quantization<false>(v_in, result_fixedpoint_multiplier, result_shift, result_offset_after_shift_s32, min, max));
452 },
453 in, out);
454}
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100455} // namespace
456
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000457NEDirectConvolutionLayerOutputStageKernel::NEDirectConvolutionLayerOutputStageKernel()
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000458 : _func(nullptr), _input(nullptr), _bias(nullptr), _output(nullptr), _result_fixedpoint_multiplier(0), _result_shift(0), _result_offset_after_shift(0)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100459{
460}
461
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000462void NEDirectConvolutionLayerOutputStageKernel::configure(ITensor *input, const ITensor *bias, ITensor *output,
463 int result_fixedpoint_multiplier, int result_shift, int result_offset_after_shift)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100464{
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000465 ARM_COMPUTE_ERROR_ON_NULLPTR(input);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000466
Georgios Pinitas0223a782017-12-12 11:44:44 +0000467 // Auto-initialize output output if required
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100468 if(output != nullptr)
469 {
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000470 // Work out expected output data type
471 const DataType output_dt = (input->info()->data_type() == DataType::S32) ? DataType::QASYMM8 : input->info()->data_type();
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000472 // Output tensor auto initialization if not yet initialized
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000473 auto_init_if_empty(*output->info(), input->info()->clone()->set_data_type(output_dt));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100474 }
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000475
476 // Perform validation step
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000477 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info()));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100478
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000479 _func = nullptr;
480 _bias = bias;
481 _input = input;
482 _output = output;
483 _result_fixedpoint_multiplier = result_fixedpoint_multiplier;
484 _result_shift = result_shift;
485 _result_offset_after_shift = result_offset_after_shift;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100486
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100487 // Configure kernel window
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000488 auto win_config = validate_and_configure_window(input->info(), (bias == nullptr) ? nullptr : bias->info(), (output == nullptr) ? nullptr : output->info());
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000489 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
490 INEKernel::configure(win_config.second);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100491
492 // Set appropriate function
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100493 if(input->info()->data_layout() == DataLayout::NCHW)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100494 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100495 switch(input->info()->data_type())
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100496 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100497 case DataType::QS8:
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000498 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100499 if(bias == nullptr)
500 {
501 _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, false> : &output_stage<qint8_t, qint8_t, false, false>;
502 }
503 else
504 {
505 _func = (output == nullptr) ? &output_stage<qint8_t, qint8_t, true, true> : &output_stage<qint8_t, qint8_t, false, true>;
506 }
507 break;
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000508 }
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100509 case DataType::QS16:
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000510 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100511 if(bias != nullptr && bias->info()->data_type() == DataType::QS8)
512 {
513 _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, true> : &output_stage<qint16_t, qint8_t, false, true>;
514 }
515 else if(bias == nullptr)
516 {
517 _func = (output == nullptr) ? &output_stage<qint16_t, qint8_t, true, false> : &output_stage<qint16_t, qint8_t, false, false>;
518 }
519 else
520 {
521 ARM_COMPUTE_ERROR("Not implemented");
522 }
523 break;
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000524 }
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100525 case DataType::QS32:
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100526 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100527 _func = (output == nullptr) ? &output_stage<qint32_t, qint16_t, true, true> : &output_stage<qint32_t, qint16_t, false, true>;
528 break;
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000529 }
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100530 case DataType::S32:
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000531 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100532 _func = (bias == nullptr) ? &output_stage<int32_t, uint8_t, false, false> : &output_stage<int32_t, uint8_t, false, true>;
533 break;
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100534 }
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000535#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100536 case DataType::F16:
537 {
538 _func = (output == nullptr) ? &output_stage<float16_t, float16_t, true, true> : &output_stage<float16_t, float16_t, false, true>;
539 break;
540 }
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000541#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100542 case DataType::F32:
543 {
544 _func = (output == nullptr) ? &output_stage<float, float, true, true> : &output_stage<float, float, false, true>;
545 break;
546 }
547 default:
548 {
549 ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
550 }
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100551 }
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100552 }
553 else
554 {
555 switch(input->info()->data_type())
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100556 {
Giorgio Arena1ed1fc62018-03-26 16:20:05 +0100557 case DataType::F32:
558 {
559 _func = (output == nullptr) ? &output_stage_nhwc<float, float, true, true> : &output_stage_nhwc<float, float, false, true>;
560 break;
561 }
562 default:
563 {
564 ARM_COMPUTE_ERROR("Unsupported combination of types among the inputs.");
565 }
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100566 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100567 }
568}
569
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000570Status NEDirectConvolutionLayerOutputStageKernel::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000571{
572 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, bias, output));
Anthony Barbierde014682018-07-03 15:10:48 +0100573 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(), bias == nullptr ? nullptr : bias->clone().get(), output == nullptr ? nullptr : output->clone().get()).first);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000574
575 return Status{};
576}
577
Michalis Spyroub91e34c2017-12-20 15:50:55 +0000578void NEDirectConvolutionLayerOutputStageKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100579{
Moritz Pflanzerc186b572017-09-07 09:48:04 +0100580 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100581 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
582 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
583 ARM_COMPUTE_ERROR_ON(_func == nullptr);
584
Georgios Pinitasf72f9362018-01-12 16:29:45 +0000585 (*_func)(_input, _bias, window, _output, _result_fixedpoint_multiplier, _result_shift, _result_offset_after_shift);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100586}