/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"
#include "arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"

#include <algorithm>
#include <arm_neon.h>

using namespace arm_compute;
using namespace arm_compute::detail;

namespace
{
template <unsigned int stridex>
qint16x8_t internal_vld1q(const qint16_t *in);

template <>
qint16x8_t internal_vld1q<1>(const qint16_t *in)
{
    return vld1q_qs16(in);
}

template <>
qint16x8_t internal_vld1q<2>(const qint16_t *in)
{
    const int16x8x2_t tmp = vld2q_s16(in);
    return tmp.val[0];
}

template <>
qint16x8_t internal_vld1q<3>(const qint16_t *in)
{
    const int16x8x3_t tmp = vld3q_s16(in);
    return tmp.val[0];
}

inline qint16x8_t internal_vdupq_n(qint16_t v)
{
    return vdupq_n_qs16(v);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <unsigned int stridex>
float16x8_t internal_vld1q(const float16_t *in);

template <>
float16x8_t internal_vld1q<1>(const float16_t *in)
{
    return vld1q_f16(in);
}

template <>
float16x8_t internal_vld1q<2>(const float16_t *in)
{
    const float16x8x2_t tmp = vld2q_f16(in);
    return tmp.val[0];
}

template <>
float16x8_t internal_vld1q<3>(const float16_t *in)
{
    const float16x8x3_t tmp = vld3q_f16(in);
    return tmp.val[0];
}

inline float16x8_t internal_vdupq_n(float16_t v)
{
    return vdupq_n_f16(v);
}

inline void internal_vst1q(float16_t *p, const float16x8_t &v)
{
    vst1q_f16(p, v);
}

float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f16(x, y);
}

inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vaddq_f16(x, vmulq_f16(y, z));
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

template <unsigned int stridex>
float32x4_t internal_vld1q(const float *in);

template <>
float32x4_t internal_vld1q<1>(const float *in)
{
    return vld1q_f32(in);
}

template <>
float32x4_t internal_vld1q<2>(const float *in)
{
    const float32x4x2_t tmp = vld2q_f32(in);
    return tmp.val[0];
}

template <>
float32x4_t internal_vld1q<3>(const float *in)
{
    const float32x4x3_t tmp = vld3q_f32(in);
    return tmp.val[0];
}

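// The internal_vld1q<stridex> specializations above implement strided loads with
// NEON de-interleaving loads: vld2q/vld3q read two or three interleaved vectors
// and val[0] keeps every second or third element. A minimal scalar sketch of the
// float case (illustrative only, not used by the kernel; the helper name is ours):
inline float32x4_t internal_vld1q_scalar_sketch(const float *in, unsigned int stridex)
{
    // out[i] = in[i * stridex] for i = 0..3, i.e. four consecutive strided samples
    const float tmp[4] = { in[0 * stridex], in[1 * stridex], in[2 * stridex], in[3 * stridex] };
    return vld1q_f32(tmp);
}
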
inline float32x4_t internal_vdupq_n(float v)
{
    return vdupq_n_f32(v);
}

inline void internal_vst1q(float *p, const float32x4_t &v)
{
    vst1q_f32(p, v);
}

float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f32(x, y);
}

inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmlaq_f32(x, y, z);
}

template <unsigned int stridex>
qint8x8_t internal_vld1q(const qint8_t *in);

template <>
qint8x8_t internal_vld1q<1>(const qint8_t *in)
{
    return vld1_qs8(in);
}

template <>
qint8x8_t internal_vld1q<2>(const qint8_t *in)
{
    const qint8x8x2_t tmp = vld2_s8(in);
    return tmp.val[0];
}

template <>
qint8x8_t internal_vld1q<3>(const qint8_t *in)
{
    const qint8x8x3_t tmp = vld3_s8(in);
    return tmp.val[0];
}

inline qint8x8_t internal_vdupq_n(qint8_t v)
{
    return vdup_n_qs8(v);
}

inline qint16x8_t internal_vmull(const qint8x8_t &x, const qint8x8_t &y, int fixed_point_position)
{
    return vmull_qs8(x, y, fixed_point_position);
}

inline qint16x8_t internal_vmlal(const qint16x8_t &x, const qint8x8_t &y, const qint8x8_t &z, int fixed_point_position)
{
    return vqmlal_qs8(x, y, z, fixed_point_position);
}
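
// Note on the fixed-point overloads: unlike the float versions, vmull_qs8 and
// vqmlal_qs8 widen their operands (QS8 -> QS16 here, QS16 -> QS32 below) and
// rescale the product by fixed_point_position, which is why the fixed-point
// accumulators use the wider type throughout.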

inline void internal_vst1q(qint16_t *p, const qint16x8_t &v)
{
    vst1q_qs16(p, v);
}

inline void internal_vst1q(int32_t *p, const qint32x4x2_t &v)
{
    vst1q_s32(p, v.val[0]);
    vst1q_s32(p + 4, v.val[1]);
}

template <unsigned int stridex>
qint32x4x2_t internal_vld1q(const qint32_t *in);

template <>
qint32x4x2_t internal_vld1q<1>(const qint32_t *in)
{
    const qint32x4x2_t r =
    {
        {
            vld1q_s32(in),
            vld1q_s32(in + 4)
        }
    };
    return r;
}

inline qint32x4x2_t internal_vmull(const qint16x8_t &x, const qint16x8_t &y, int fixed_point_position)
{
    const qint32x4x2_t r =
    {
        {
            vmull_qs16(vget_low_s16(x), vget_low_s16(y), fixed_point_position),
            vmull_qs16(vget_high_s16(x), vget_high_s16(y), fixed_point_position),
        }
    };
    return r;
}

inline qint32x4x2_t internal_vmlal(const qint32x4x2_t &x, const qint16x8_t &y, const qint16x8_t &z, int fixed_point_position)
{
    const qint32x4x2_t r =
    {
        {
            vqmlal_qs16(x.val[0], vget_low_s16(y), vget_low_s16(z), fixed_point_position),
            vqmlal_qs16(x.val[1], vget_high_s16(y), vget_high_s16(z), fixed_point_position)
        }
    };
    return r;
}

constexpr int small_tensor_size_optim = 8;
inline bool run_optim_small_tensor_info(const ITensorInfo *t)
{
    return t->dimension(Window::DimX) <= small_tensor_size_optim && t->dimension(Window::DimY) <= small_tensor_size_optim;
}

inline bool run_optim_small_tensor(const ITensor *t)
{
    return run_optim_small_tensor_info(t->info());
}

// Optimized convolver for 1x1 kernels, used only where the input width and height are both <= 8.
// For a large Z, as in an input of 7x7x832, this implementation is faster than the general code because it doesn't need to
// store intermediate results in memory: temporary results are kept in NEON registers and then written directly to the output buffer.
// A scalar sketch of the computation follows the class definition below.
template <unsigned int stridex>
class convolver_w1x1_i8x8_f32
{
public:
    static void convolve(const Window &window, const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimX) > small_tensor_size_optim);
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimY) > small_tensor_size_optim);

        const int input_stride_x = input->info()->strides_in_bytes().x();
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_h = output->info()->dimension(1);
        const int range_z = window.z().end() - window.z().start();
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left = conv_info.pad_left();
        const unsigned int conv_pad_top = conv_info.pad_top();

        // Set up the output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // Set up the input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            float32x4_t accum0[small_tensor_size_optim] = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            float32x4_t accum1[small_tensor_size_optim] = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            for(int oz = 0; oz < range_z; ++oz)
            {
                accum0[0] = accum0[1] = accum0[2] = accum0[3] = accum0[4] = accum0[5] = accum0[6] = accum0[7] = vdupq_n_f32(0.f);
                accum1[0] = accum1[1] = accum1[2] = accum1[3] = accum1[4] = accum1[5] = accum1[6] = accum1[7] = vdupq_n_f32(0.f);
                auto p_out_base = out_ptr + oz * output_stride_z;
                for(int p = 0; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const float *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk0 = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const float *>(input_ptr + p * input_stride_z + offset_xy);
                        auto v_in0 = internal_vld1q<stridex>(in_val);
                        auto v_in1 = internal_vld1q<stridex>(in_val + 4);
                        accum0[oh] = vmlaq_f32(accum0[oh], vk0, v_in0);
                        accum1[oh] = vmlaq_f32(accum1[oh], vk0, v_in1);
                    }
                }
                for(oh = 0; oh < output_h; ++oh)
                {
                    auto p_out = reinterpret_cast<float *>(p_out_base + oh * output_stride_y);
                    vst1q_f32(p_out, accum0[oh]);
                    vst1q_f32(p_out + 4, accum1[oh]);
                }
            }
        },
        in, out);
    }
};
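
// A scalar sketch of what convolver_w1x1_i8x8_f32 computes per output element
// (illustrative only; the helper below is ours and assumes unit stride, no
// padding and contiguous planes, whereas the kernel above keeps accum0/accum1
// in NEON registers instead of round-tripping partial sums through memory):
//   out[oz][y][x] = sum over p of in[p][y][x] * w[oz][p]
inline float w1x1_scalar_sketch(const float *in, const float *w, int kernel_depth, int plane_size, int xy)
{
    float acc = 0.f;
    for(int p = 0; p < kernel_depth; ++p)
    {
        // One multiply-accumulate per input plane at the same XY location
        acc += in[p * plane_size + xy] * w[p];
    }
    return acc;
}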

template <typename T1, typename T2, unsigned int stridex>
class convolver_1x1
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        const int input_stride_x = input->info()->strides_in_bytes().x();
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int range_z = window.z().end() - window.z().start();
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left = conv_info.pad_left();
        const unsigned int conv_pad_top = conv_info.pad_top();
        const int fixed_point_position = input->info()->fixed_point_position();

        // Set up the output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // Set up the input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            /*
                For a detailed explanation of how the algorithm works, refer to template <> class convolver_3x3<1>
            */
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            for(int oz = 0; oz < range_z; ++oz)
            {
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }

                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }
            }
        },
        in, out);
    }
};

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

template <unsigned int stridex>
void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

template <>
void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
    vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
}

#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

template <unsigned int stridex>
float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                           const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position);

inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
{
    const float32x4x3_t m00 =
    {
        {
            vld1q_dup_f32(m0),
            vld1q_dup_f32(m1),
            vld1q_dup_f32(m2)
        }
    };
    return m00;
}

inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4)
{
    const float32x4x2_t m00 =
    {
        {
            vld1q_dup_f32(m3),
            vld1q_dup_f32(m4)
        }
    };
    return m00;
}

inline float32x4x3_t load_input(const float *const in)
{
    const float32x4x3_t vin =
    {
        {
            vld1q_f32(in),
            vld1q_f32(in + 4),
            vld1q_f32(in + 8)
        }
    };
    return vin;
}

template <>
inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    const float32x4x3_t vin0 = load_input(in_0);
    const float32x4x3_t vin1 = load_input(in_1);
    const float32x4x3_t vin2 = load_input(in_2);
    const float32x4x3_t vin3 = load_input(in_3);
    const float32x4x3_t vin4 = load_input(in_4);
    const float32x4x3_t m00 = load_matrix_hi(m0, 1 + m0, 2 + m0);
    const float32x4x2_t m01 = load_matrix_lo(3 + m0, 4 + m0);
    const float32x4x3_t m10 = load_matrix_hi(m1, 1 + m1, 2 + m1);
    const float32x4x2_t m11 = load_matrix_lo(3 + m1, 4 + m1);
    const float32x4x3_t m20 = load_matrix_hi(m2, 1 + m2, 2 + m2);
    const float32x4x2_t m21 = load_matrix_lo(3 + m2, 4 + m2);
    const float32x4x3_t m30 = load_matrix_hi(m3, 1 + m3, 2 + m3);
    const float32x4x2_t m31 = load_matrix_lo(3 + m3, 4 + m3);
    const float32x4x3_t m40 = load_matrix_hi(m4, 1 + m4, 2 + m4);
    const float32x4x2_t m41 = load_matrix_lo(3 + m4, 4 + m4);

    float32x4x2_t out =
    {
        {
            vmulq_f32(vin0.val[0], m00.val[0]),
            vmulq_f32(vin0.val[1], m00.val[0])
        }
    };

    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 3), m11.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], m11.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]);

    return out;
}
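
// The vextq_f32(a, b, n) calls above splice two adjacent 4-lane vectors and
// extract lanes n..n+3 of the concatenation, i.e. the same input row shifted
// right by n elements. This yields the five horizontal taps of the 5x5 window
// without any extra loads from memory.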

template <>
inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}
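
// The stride-2 and stride-3 specializations above first compute the stride-1
// result and then compact the lanes a strided convolution actually keeps:
// stride 2 gathers outputs 0, 2, 4 and 6 into out.val[0]; stride 3 gathers
// outputs 0 and 3, matching the two elements written per iteration (16 >> 3 = 2).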

template <unsigned int stridex>
void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
    vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
}

template <>
void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}

template <unsigned int stridex>
void accumulate_results(qint16_t *buffer, const qint16x8x2_t &values);

template <>
void accumulate_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
    vst1q_qs16(buffer + 8, vqaddq_qs16(vld1q_qs16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1_qs16(buffer, vqadd_qs16(vld1_qs16(buffer), vget_low_s16(values.val[0])));
}

template <typename T1, typename T2, unsigned int stridex>
class convolver_3x3
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        const int input_stride_x = input->info()->strides_in_bytes().x();
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int num_planes_z = window.z().end() - window.z().start();
        const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left = conv_info.pad_left();
        const unsigned int conv_pad_top = conv_info.pad_top();
        const int fixed_point_position = input->info()->fixed_point_position();

        // Set up the output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // Set up the input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            /*
                Each thread executing this kernel computes one or more of the output volume's planes.

                Say the 3rd dimension of the output volume is 32: the first thread will compute the output for Z = [0,7], the second thread for Z = [8,15],
                the third thread for Z = [16,23] and the fourth thread for Z = [24,31].

                The outer loop of the algorithm iterates over Z, P, Y, X, where P is the depth/3rd dimension of each kernel. This order is not arbitrary: its main benefit
                is that we set up the NEON registers containing the kernel's values only once, and then compute each XY using the preloaded registers, as opposed to doing this for every XY value.

                The algorithm does not require allocating any additional memory and computes the results directly in place, in two stages:
                1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
                2) Convolve the remaining planes and accumulate the results in the output plane initialized in stage 1.
                A scalar sketch of this two-stage scheme follows the class below.
            */
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset = id.z() + oz;
                uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto vk_r0 = load_matrix_row(ptr_k_r0);
                    const auto vk_r1 = load_matrix_row(ptr_k_r1);
                    const auto vk_r2 = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w;
                    const uint8_t *input_base = input_ptr + p * input_stride_z;
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(ptr_k_base);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y * 2);
                    const auto vk_r0 = load_matrix_row(ptr_k_r0);
                    const auto vk_r1 = load_matrix_row(ptr_k_r1);
                    const auto vk_r2 = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_base + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_base + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_base + (ih + 2) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};
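
// A scalar sketch of the two-stage scheme described in convolver_3x3 above
// (illustrative only; this helper is ours, and padding/striding are omitted):
inline void convolve_3x3_scalar_sketch(const float *in, const float *k, float *out,
                                       int kernel_depth, int height, int width)
{
    const int out_w = width - 2;
    const int out_h = height - 2;
    for(int p = 0; p < kernel_depth; ++p)
    {
        for(int y = 0; y < out_h; ++y)
        {
            for(int x = 0; x < out_w; ++x)
            {
                float acc = 0.f;
                for(int ky = 0; ky < 3; ++ky)
                {
                    for(int kx = 0; kx < 3; ++kx)
                    {
                        acc += in[(p * height + y + ky) * width + (x + kx)] * k[(p * 3 + ky) * 3 + kx];
                    }
                }
                // Stage 1 (p == 0) initializes the output plane; stage 2 accumulates onto it
                if(p == 0)
                {
                    out[y * out_w + x] = acc;
                }
                else
                {
                    out[y * out_w + x] += acc;
                }
            }
        }
    }
}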

template <typename T1, typename T2, unsigned int stridex>
class convolver_5x5
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        const int input_stride_x = input->info()->strides_in_bytes().x();
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int num_planes_z = window.z().end() - window.z().start();
        const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left = conv_info.pad_left();
        const unsigned int conv_pad_top = conv_info.pad_top();
        const int fixed_point_position = input->info()->fixed_point_position();

        // Set up the output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // Set up the input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset = id.z() + oz;
                uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);

                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};

template <typename T1, typename T2>
inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <>
inline void convolve_1x1<float, float>(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                                       const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    if(run_optim_small_tensor(input))
    {
        switch(conv_stride_x)
        {
            case 1:
                convolver_w1x1_i8x8_f32<1>::convolve(window, input, weights, output, conv_info);
                break;
            case 2:
                convolver_w1x1_i8x8_f32<2>::convolve(window, input, weights, output, conv_info);
                break;
            case 3:
                convolver_w1x1_i8x8_f32<3>::convolve(window, input, weights, output, conv_info);
                break;
            default:
                ARM_COMPUTE_ERROR("Not implemented");
        }
    }
    else
    {
        switch(conv_stride_x)
        {
            case 1:
                convolver_1x1<float, float, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            case 2:
                convolver_1x1<float, float, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            case 3:
                convolver_1x1<float, float, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            default:
                ARM_COMPUTE_ERROR("Not implemented");
        }
    }
}

template <typename T1, typename T2>
inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <typename T1, typename T2>
inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_5x5<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_5x5<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_5x5<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

inline TensorShape get_convolved_dimensions(const ITensorInfo *input, const ITensorInfo *weights, const int kernel_size, const PadStrideInfo &conv_info)
{
    unsigned int output_width = 0;
    unsigned int output_height = 0;
    std::tie(output_width, output_height) = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_size, kernel_size, conv_info);

    TensorShape output_shape = input->tensor_shape();
    output_shape.set(0, output_width);
    output_shape.set(1, output_height);
    output_shape.set(2, weights->dimension(3));

    return output_shape;
}
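
// Worked example (a sketch): a 56x56x64 input convolved with 128 kernels of
// size 3x3, stride 1 and pad 1 gives output_width = output_height = 56 from
// scaled_dimensions(), so the returned shape is 56x56x128; dimension 2 is
// taken from weights->dimension(3), the number of kernels.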

Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(2) != input->dimension(2));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(0) != weights->dimension(1));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);

    // Checks performed when output is configured
    if(output->total_size() != 0)
    {
        TensorShape output_shape = get_convolved_dimensions(input, weights, weights->dimension(0), conv_info);

        DataType data_type = input->data_type();
        if(is_data_type_fixed_point(data_type))
        {
            // Promote data type in case of fixed point
            data_type = ((data_type == DataType::QS8) ? DataType::QS16 : DataType::QS32);
        }

        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
        ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != data_type);
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int &num_weight_elems_read_per_row,
                                                        unsigned int &num_elems_read_per_iteration, unsigned int &num_elems_written_per_iteration, BorderSize &border_size)
{
    // Calculate the right and bottom border
    unsigned int kernel_size = weights->dimension(0);
    const int conv_stride_x = std::get<0>(conv_info.stride());
    const int input_width = input->dimension(0);
    const int input_height = input->dimension(1);

    switch(kernel_size)
    {
        case 1:
        {
            switch(input->data_type())
            {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                case DataType::QS8:
                case DataType::QS16:
                    num_elems_written_per_iteration = 8;
                    break;
                case DataType::F32:
                    if(run_optim_small_tensor_info(input))
                    {
                        num_elems_written_per_iteration = 8;
                    }
                    else
                    {
                        num_elems_written_per_iteration = 4;
                    }
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported.");
                    break;
            }
            num_weight_elems_read_per_row = kernel_size;
            num_elems_read_per_iteration = conv_stride_x * num_elems_written_per_iteration;
            break;
        }
        case 3:
        case 5:
        {
            switch(input->data_type())
            {
                case DataType::F32:
                    num_weight_elems_read_per_row = 4 + kernel_size - 1;
                    num_elems_read_per_iteration = 12;
                    num_elems_written_per_iteration = 16 >> conv_stride_x;
                    break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                case DataType::QS8:
                case DataType::QS16:
                    num_weight_elems_read_per_row = 8 + kernel_size - 1;
                    num_elems_read_per_iteration = 24;
                    num_elems_written_per_iteration = 32 >> conv_stride_x;
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported.");
                    break;
            }
        }
        break;
        default:
        {
            ARM_COMPUTE_ERROR("Not implemented");
            break;
        }
    }

    // Calculate the right pad
    int start_x = kernel_size / 2 - static_cast<int>(conv_info.pad_left());
    int end_x = ceil_to_multiple(static_cast<int>(output->dimension(0)), num_elems_written_per_iteration) * conv_stride_x;
    int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width;
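
    // Worked example (a sketch): for F32 with a 3x3 kernel, stride 1, pad_left = 1
    // and an input/output width of 16: start_x = 3 / 2 - 1 = 0, end_x = ceil_to_multiple(16, 8) * 1 = 16,
    // and with 12 elements read per iteration upper_bound_w = ceil_to_multiple(16, 12) - 16 = 8,
    // so 8 elements of right border are requested below via conv_pad_right.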

    // Calculate the border
    const unsigned int conv_pad_left = conv_info.pad_left();
    const unsigned int conv_pad_top = conv_info.pad_top();
    const unsigned int conv_pad_right = std::max(upper_bound_w, 0);
    const unsigned int conv_pad_bottom = conv_info.pad_bottom();

    border_size.left = conv_pad_left;
    border_size.top = conv_pad_top;
    border_size.right = conv_pad_right;
    border_size.bottom = conv_pad_bottom;

    Window win = calculate_max_window(*output, Steps(num_elems_written_per_iteration));
    AccessWindowStatic input_access(input, -conv_pad_left, -conv_pad_top, input_width + conv_pad_right, input_height + conv_pad_bottom);
    AccessWindowStatic weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size);
    AccessWindowHorizontal output_access(output, 0, num_elems_written_per_iteration);
    bool window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
    output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
} // namespace

NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
    : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0),
      _num_elems_written_per_iteration(0)
{
}

BorderSize NEDirectConvolutionLayerKernel::border_size() const
{
    return _border_size;
}

void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);

    _input = input;
    _weights = weights;
    _output = output;
    _conv_info = conv_info;
    _kernel_size = weights->info()->dimension(0);

    const unsigned int conv_pad_left = conv_info.pad_left();
    const unsigned int conv_pad_top = conv_info.pad_top();
    const unsigned int conv_pad_right = conv_info.pad_right();
    const unsigned int conv_pad_bottom = conv_info.pad_bottom();
    _border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);

    // Get convolved dimensions
    TensorShape output_shape = get_convolved_dimensions(input->info(), weights->info(), _kernel_size, conv_info);

    DataType data_type = input->info()->data_type();

    if(is_data_type_fixed_point(data_type))
    {
        // Promote data type in case of fixed point
        data_type = ((data_type == DataType::QS8) ? DataType::QS16 : DataType::QS32);
    }

    // Output auto-initialization if not yet initialized
    auto_init_if_empty(*output->info(), output_shape, 1, data_type, input->info()->fixed_point_position());

    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), output->info(), conv_info));

    // Configure kernel window
    auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, _num_weight_elems_read_per_row,
                                                    _num_elems_read_per_iteration, _num_elems_written_per_iteration, _border_size);
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
    INEKernel::configure(win_config.second);
}

Status NEDirectConvolutionLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
{
    unsigned int num_weight_elems_read_per_row = 0;
    unsigned int num_elems_read_per_iteration = 0;
    unsigned int num_elems_written_per_iteration = 0;
    BorderSize border_size = {};
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, weights, output, conv_info));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
                                                              weights->clone().get(),
                                                              output->clone().get(),
                                                              conv_info,
                                                              num_weight_elems_read_per_row,
                                                              num_elems_read_per_iteration,
                                                              num_elems_written_per_iteration,
                                                              border_size)
                                .first);

    return Status{};
}

void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

    const int kernel_size = _weights->info()->dimension(0);

    switch(kernel_size)
    {
        case 1:
        {
            switch(_input->info()->data_type())
            {
                case DataType::QS8:
                    convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::QS16:
                    convolve_1x1<qint16_t, qint32_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::F32:
                    convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        case 3:
        {
            switch(_input->info()->data_type())
            {
                case DataType::QS8:
                    convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::F32:
                    convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        case 5:
        {
            switch(_input->info()->data_type())
            {
                case DataType::F32:
                    convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
            break;
        }
    }
}