/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"
#include "arm_compute/core/NEON/kernels/convolution/NEDirectConvolutionDetail.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"

#include <algorithm>
#include <arm_neon.h>

using namespace arm_compute;
using namespace arm_compute::detail;

namespace
{
template <unsigned int stridex>
qint16x8_t internal_vld1q(const qint16_t *in);

template <>
qint16x8_t internal_vld1q<1>(const qint16_t *in)
{
    return vld1q_qs16(in);
}

template <>
qint16x8_t internal_vld1q<2>(const qint16_t *in)
{
    const int16x8x2_t tmp = vld2q_s16(in);
    return tmp.val[0];
}

template <>
qint16x8_t internal_vld1q<3>(const qint16_t *in)
{
    const int16x8x3_t tmp = vld3q_s16(in);
    return tmp.val[0];
}
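
// The strided loads above implement subsampling with NEON de-interleaving
// loads: vld2q/vld3q split consecutive elements across vectors, so keeping
// val[0] yields every 2nd (or 3rd) element, i.e. exactly the inputs a
// stride-2 (or stride-3) convolution consumes for consecutive outputs.
// Illustrative sketch (hypothetical values, not part of the kernel):
//   const int16_t in[16] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
//   const int16x8x2_t d  = vld2q_s16(in);
//   // d.val[0] == { 0, 2, 4, 6, 8, 10, 12, 14 } -> the stride-2 samples
//   // d.val[1] == { 1, 3, 5, 7, 9, 11, 13, 15 }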

inline qint16x8_t internal_vdupq_n(qint16_t v)
{
    return vdupq_n_qs16(v);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <unsigned int stridex>
float16x8_t internal_vld1q(const float16_t *in);

template <>
float16x8_t internal_vld1q<1>(const float16_t *in)
{
    return vld1q_f16(in);
}

template <>
float16x8_t internal_vld1q<2>(const float16_t *in)
{
    const float16x8x2_t tmp = vld2q_f16(in);
    return tmp.val[0];
}

template <>
float16x8_t internal_vld1q<3>(const float16_t *in)
{
    const float16x8x3_t tmp = vld3q_f16(in);
    return tmp.val[0];
}

inline float16x8_t internal_vdupq_n(float16_t v)
{
    return vdupq_n_f16(v);
}

inline void internal_vst1q(float16_t *p, const float16x8_t &v)
{
    vst1q_f16(p, v);
}

float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f16(x, y);
}

inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vaddq_f16(x, vmulq_f16(y, z));
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

template <unsigned int stridex>
float32x4_t internal_vld1q(const float *in);

template <>
float32x4_t internal_vld1q<1>(const float *in)
{
    return vld1q_f32(in);
}

template <>
float32x4_t internal_vld1q<2>(const float *in)
{
    const float32x4x2_t tmp = vld2q_f32(in);
    return tmp.val[0];
}

template <>
float32x4_t internal_vld1q<3>(const float *in)
{
    const float32x4x3_t tmp = vld3q_f32(in);
    return tmp.val[0];
}

inline float32x4_t internal_vdupq_n(float v)
{
    return vdupq_n_f32(v);
}

inline void internal_vst1q(float *p, const float32x4_t &v)
{
    vst1q_f32(p, v);
}

float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f32(x, y);
}

inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmlaq_f32(x, y, z);
}

template <unsigned int stridex>
qint8x8_t internal_vld1q(const qint8_t *in);

template <>
qint8x8_t internal_vld1q<1>(const qint8_t *in)
{
    return vld1_qs8(in);
}

template <>
qint8x8_t internal_vld1q<2>(const qint8_t *in)
{
    const qint8x8x2_t tmp = vld2_s8(in);
    return tmp.val[0];
}

template <>
qint8x8_t internal_vld1q<3>(const qint8_t *in)
{
    const qint8x8x3_t tmp = vld3_s8(in);
    return tmp.val[0];
}

inline qint8x8_t internal_vdupq_n(qint8_t v)
{
    return vdup_n_qs8(v);
}

inline qint16x8_t internal_vmull(const qint8x8_t &x, const qint8x8_t &y, int fixed_point_position)
{
    return vmull_qs8(x, y, fixed_point_position);
}

inline qint16x8_t internal_vmlal(const qint16x8_t &x, const qint8x8_t &y, const qint8x8_t &z, int fixed_point_position)
{
    return vqmlal_qs8(x, y, z, fixed_point_position);
}

inline void internal_vst1q(qint16_t *p, const qint16x8_t &v)
{
    vst1q_qs16(p, v);
}

inline void internal_vst1q(int32_t *p, const qint32x4x2_t &v)
{
    vst1q_s32(p, v.val[0]);
    vst1q_s32(p + 4, v.val[1]);
}

template <unsigned int stridex>
qint32x4x2_t internal_vld1q(const qint32_t *in);

template <>
qint32x4x2_t internal_vld1q<1>(const qint32_t *in)
{
    const qint32x4x2_t r =
    {
        {
            vld1q_s32(in),
            vld1q_s32(in + 4)
        }
    };
    return r;
}

inline qint32x4x2_t internal_vmull(const qint16x8_t &x, const qint16x8_t &y, int fixed_point_position)
{
    const qint32x4x2_t r =
    {
        {
            vmull_qs16(vget_low_s16(x), vget_low_s16(y), fixed_point_position),
            vmull_qs16(vget_high_s16(x), vget_high_s16(y), fixed_point_position),
        }
    };
    return r;
}

inline qint32x4x2_t internal_vmlal(const qint32x4x2_t &x, const qint16x8_t &y, const qint16x8_t &z, int fixed_point_position)
{
    const qint32x4x2_t r =
    {
        {
            vqmlal_qs16(x.val[0], vget_low_s16(y), vget_low_s16(z), fixed_point_position),
            vqmlal_qs16(x.val[1], vget_high_s16(y), vget_high_s16(z), fixed_point_position)
        }
    };
    return r;
}
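
// Fixed-point note: fixed_point_position is the number of fractional bits of
// the QS8/QS16 values. Worked example of the intended arithmetic (a sketch,
// assuming the NEFixedPoint multipliers rescale the raw product by
// fixed_point_position, with saturation):
//   QS8, fixed_point_position = 4: 0.5 is stored as 8, 0.25 as 4
//   raw product : 8 * 4 = 32 (carries 2 * 4 = 8 fractional bits)
//   rescaled    : 32 >> 4 = 2 -> 2 / 2^4 = 0.125 == 0.5 * 0.25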

constexpr int small_tensor_size_optim = 8;
inline bool run_optim_small_tensor_info(const ITensorInfo *t)
{
    return t->dimension(Window::DimX) <= small_tensor_size_optim && t->dimension(Window::DimY) <= small_tensor_size_optim;
}

inline bool run_optim_small_tensor(const ITensor *t)
{
    return run_optim_small_tensor_info(t->info());
}

// Optimized convolver for 1x1 kernels, used only where the input width and height are both <= 8
// For a large Z, as in Input=7x7x832, this implementation is faster than the general code because it doesn't need to
// store intermediate results in memory. Temporary results are stored in NEON registers directly and then written to the output buffer.
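// Sketch of the idea (illustrative only): with stride 1 an output row is at
// most 8 floats wide, so it fits in the two float32x4_t accumulators
// accum0[oh] and accum1[oh]; the whole plane can live largely in NEON
// registers while the kernel_depth loop runs, and memory is touched once per
// plane at the end:
//   for p in [0, kernel_depth): accum[oh] += k(p, z) * in(p, row oh)
//   for oh in [0, output_h):    store accum0[oh], accum1[oh] to the output row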
template <unsigned int stridex>
class convolver_w1x1_i8x8_f32
{
public:
    static void convolve(const Window &window, const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimX) > small_tensor_size_optim);
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimY) > small_tensor_size_optim);

        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_h = output->info()->dimension(1);
        const int range_z = window.z().end() - window.z().start();
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr();
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            float32x4_t accum0[small_tensor_size_optim] = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            float32x4_t accum1[small_tensor_size_optim] = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            for(int oz = 0; oz < range_z; ++oz)
            {
                accum0[0] = accum0[1] = accum0[2] = accum0[3] = accum0[4] = accum0[5] = accum0[6] = accum0[7] = vdupq_n_f32(0.f);
                accum1[0] = accum1[1] = accum1[2] = accum1[3] = accum1[4] = accum1[5] = accum1[6] = accum1[7] = vdupq_n_f32(0.f);
                auto p_out_base = out_ptr + oz * output_stride_z;
                for(int p = 0; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const float *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk0 = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const float *>(input_ptr + p * input_stride_z + offset_xy);
                        auto v_in0 = internal_vld1q<stridex>(in_val);
                        auto v_in1 = internal_vld1q<stridex>(in_val + 4);
                        accum0[oh] = vmlaq_f32(accum0[oh], vk0, v_in0);
                        accum1[oh] = vmlaq_f32(accum1[oh], vk0, v_in1);
                    }
                }
                for(oh = 0; oh < output_h; ++oh)
                {
                    auto p_out = reinterpret_cast<float *>(p_out_base + oh * output_stride_y);
                    vst1q_f32(p_out, accum0[oh]);
                    vst1q_f32(p_out + 4, accum1[oh]);
                }
            }
        },
        in, out);
    }
};

template <typename T1, typename T2, unsigned int stridex>
class convolver_1x1
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int range_z = window.z().end() - window.z().start();
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const int fixed_point_position = input->info()->fixed_point_position();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            /*
             For a detailed explanation of how the algorithm works, refer to template <> class convolver_3x3<1>
             */
            const uint8_t *input_ptr = in.ptr();
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            for(int oz = 0; oz < range_z; ++oz)
            {
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }

                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }
            }
        },
        in, out);
    }
};

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

template <unsigned int stridex>
void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

template <>
void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
    vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
}

#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

template <unsigned int stridex>
float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                           const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position);

inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
{
    const float32x4x3_t m00 =
    {
        {
            vld1q_dup_f32(m0),
            vld1q_dup_f32(m1),
            vld1q_dup_f32(m2)
        }
    };
    return m00;
}

inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4)
{
    const float32x4x2_t m00 =
    {
        {
            vld1q_dup_f32(m3),
            vld1q_dup_f32(m4)
        }
    };
    return m00;
}

inline float32x4x3_t load_input(const float *const in)
{
    const float32x4x3_t vin =
    {
        {
            vld1q_f32(in),
            vld1q_f32(in + 4),
            vld1q_f32(in + 8)
        }
    };
    return vin;
}
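
// load_input reads 12 consecutive floats (three float32x4_t). A minimal
// counting argument for why 12 (a sketch, not extra functionality): with
// stride 1 the 5x5 row convolution below produces 8 outputs per iteration,
// and output i needs inputs [i, i + 4], so outputs [0, 7] together need
// inputs [0, 11] -- i.e. 8 + 5 - 1 = 12 elements. The vextq_f32 calls in
// convolve_5x5<1> then form the shifted views of this 12-element window.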

template <>
inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    const float32x4x3_t vin0 = load_input(in_0);
    const float32x4x3_t vin1 = load_input(in_1);
    const float32x4x3_t vin2 = load_input(in_2);
    const float32x4x3_t vin3 = load_input(in_3);
    const float32x4x3_t vin4 = load_input(in_4);
    const float32x4x3_t m00 = load_matrix_hi(m0, 1 + m0, 2 + m0);
    const float32x4x2_t m01 = load_matrix_lo(3 + m0, 4 + m0);
    const float32x4x3_t m10 = load_matrix_hi(m1, 1 + m1, 2 + m1);
    const float32x4x2_t m11 = load_matrix_lo(3 + m1, 4 + m1);
    const float32x4x3_t m20 = load_matrix_hi(m2, 1 + m2, 2 + m2);
    const float32x4x2_t m21 = load_matrix_lo(3 + m2, 4 + m2);
    const float32x4x3_t m30 = load_matrix_hi(m3, 1 + m3, 2 + m3);
    const float32x4x2_t m31 = load_matrix_lo(3 + m3, 4 + m3);
    const float32x4x3_t m40 = load_matrix_hi(m4, 1 + m4, 2 + m4);
    const float32x4x2_t m41 = load_matrix_lo(3 + m4, 4 + m4);

    float32x4x2_t out =
    {
        {
            vmulq_f32(vin0.val[0], m00.val[0]),
            vmulq_f32(vin0.val[1], m00.val[0])
        }
    };

    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 3), m11.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], m11.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]);

    return out;
}

template <>
inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}
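
// The stride-2 and stride-3 variants reuse the stride-1 result and then
// compact the valid lanes into the front of out.val[0]:
//   stride 2: keep outputs {0, 2, 4, 6} -> out.val[0] = {o0, o2, o4, o6}
//   stride 3: keep outputs {0, 3}       -> out.val[0] lanes 0 and 1 = {o0, o3}
// The matching store_results<stridex>/accumulate_results<stridex> overloads
// then store only that many elements per iteration.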

template <unsigned int stridex>
void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
    vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
}

template <>
void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}

template <unsigned int stridex>
void accumulate_results(qint16_t *buffer, const qint16x8x2_t &values);

template <>
void accumulate_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
    vst1q_qs16(buffer + 8, vqaddq_qs16(vld1q_qs16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1_qs16(buffer, vqadd_qs16(vld1_qs16(buffer), vget_low_s16(values.val[0])));
}
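
// The accumulate_results<stridex> overloads mirror the per-iteration store
// widths chosen in validate_and_configure_window below: for F32
// (float32x4x2_t) that is 16 >> stridex outputs, i.e. 8/4/2 for stride 1/2/3,
// and for QS8/QS16 (qint16x8x2_t) it is 32 >> stridex, i.e. 16/8/4. Each
// overload loads the partial result back, adds, and stores only the valid lanes.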

template <typename T1, typename T2, unsigned int stridex>
class convolver_3x3
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        const int input_stride_x = input->info()->strides_in_bytes().x();
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int num_planes_z = window.z().end() - window.z().start();
        const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_x = std::get<0>(conv_info.pad());
        const unsigned int conv_pad_y = std::get<1>(conv_info.pad());
        const int fixed_point_position = input->info()->fixed_point_position();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_x * input_stride_x - conv_pad_y * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            /*
             Each thread executing this kernel computes one or more planes of the output volume.

             Let's say the 3rd dimension of the output volume is 32: the first thread will compute the output for Z = [0,7], the second thread will compute the output for Z = [8,15],
             the third thread Z = [16,23] and the fourth thread Z = [24,31].

             The algorithm's outer loop iterates over Z, P, Y, X, where P is the depth/3rd dimension of each kernel. This order is not arbitrary: the main benefit of it
             is that we set up the NEON registers containing the kernel's values only once, and then compute each XY using the preloaded registers, as opposed to doing this for every XY value.

             The algorithm does not require allocating any additional memory and computes the results directly in-place in two stages:
             1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
             2) Convolve the remaining planes and accumulate the results in the output plane which was initialized in step 1.
             */
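            // In effect, for each output plane z computed by this thread:
            //   out(x, y, z) = sum over p in [0, kernel_depth) of
            //                  conv3x3(in(., ., p), kernel(., ., p, z))(x, y)
            // Step 1 performs the p == 0 term with a plain store (store_results);
            // step 2 adds the remaining terms with read-modify-write stores
            // (accumulate_results).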
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset = id.z() + oz;
                uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto vk_r0 = load_matrix_row(ptr_k_r0);
                    const auto vk_r1 = load_matrix_row(ptr_k_r1);
                    const auto vk_r2 = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w;
                    const uint8_t *input_base = input_ptr + p * input_stride_z;
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(ptr_k_base);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y * 2);
                    const auto vk_r0 = load_matrix_row(ptr_k_r0);
                    const auto vk_r1 = load_matrix_row(ptr_k_r1);
                    const auto vk_r2 = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_base + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_base + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_base + (ih + 2) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};

template <typename T1, typename T2, unsigned int stridex>
class convolver_5x5
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        const int input_stride_x = input->info()->strides_in_bytes().x();
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int num_planes_z = window.z().end() - window.z().start();
        const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_x = std::get<0>(conv_info.pad());
        const unsigned int conv_pad_y = std::get<1>(conv_info.pad());
        const int fixed_point_position = input->info()->fixed_point_position();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_x * input_stride_x - conv_pad_y * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset = id.z() + oz;
                uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);

                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};

template <typename T1, typename T2>
inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <>
inline void convolve_1x1<float, float>(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                                       const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    if(run_optim_small_tensor(input))
    {
        switch(conv_stride_x)
        {
            case 1:
                convolver_w1x1_i8x8_f32<1>::convolve(window, input, weights, output, conv_info);
                break;
            case 2:
                convolver_w1x1_i8x8_f32<2>::convolve(window, input, weights, output, conv_info);
                break;
            case 3:
                convolver_w1x1_i8x8_f32<3>::convolve(window, input, weights, output, conv_info);
                break;
            default:
                ARM_COMPUTE_ERROR("Not implemented");
        }
    }
    else
    {
        switch(conv_stride_x)
        {
            case 1:
                convolver_1x1<float, float, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            case 2:
                convolver_1x1<float, float, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            case 3:
                convolver_1x1<float, float, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            default:
                ARM_COMPUTE_ERROR("Not implemented");
        }
    }
}

template <typename T1, typename T2>
inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <typename T1, typename T2>
inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_5x5<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_5x5<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_5x5<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

inline TensorShape get_convolved_dimensions(const ITensorInfo *input, const ITensorInfo *weights, const int kernel_size, const PadStrideInfo &conv_info)
{
    unsigned int output_width = 0;
    unsigned int output_height = 0;
    std::tie(output_width, output_height) = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_size, kernel_size, conv_info);

    TensorShape output_shape = input->tensor_shape();
    output_shape.set(0, output_width);
    output_shape.set(1, output_height);
    output_shape.set(2, weights->dimension(3));

    return output_shape;
}
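
// Output-size arithmetic: scaled_dimensions applies the standard convolution
// formula; assuming its default FLOOR rounding (an assumption about the
// helper, not verified here) this amounts to
//   out = (in + pad_left + pad_right - kernel) / stride + 1
// e.g. a 7x7 input with a 3x3 kernel, stride 1 and pad 1 keeps 7x7:
//   (7 + 1 + 1 - 3) / 1 + 1 = 7
// Dimension 2 then becomes the number of kernels, weights->dimension(3).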

Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(0) == 1 && (std::get<0>(conv_info.pad()) || std::get<1>(conv_info.pad())),
                                    "Pad > 0 not supported for 1x1 weights");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(0) == 3 && (std::get<0>(conv_info.pad()) > 1 || std::get<1>(conv_info.pad()) > 1),
                                    "Pad > 1 not supported for 3x3 weights");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights->dimension(0) == 5 && (std::get<0>(conv_info.pad()) > 2 || std::get<1>(conv_info.pad()) > 2),
                                    "Pad > 2 not supported for 5x5 weights");

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(2) != input->dimension(2));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(0) != weights->dimension(1));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);

    // Checks performed when output is configured
    if(output->total_size() != 0)
    {
        TensorShape output_shape = get_convolved_dimensions(input, weights, weights->dimension(0), conv_info);

        DataType data_type = input->data_type();
        if(is_data_type_fixed_point(data_type))
        {
            // Promote data type in case of fixed point
            data_type = ((data_type == DataType::QS8) ? DataType::QS16 : DataType::QS32);
        }

        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
        ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != data_type);
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int &num_weight_elems_read_per_row,
                                                        unsigned int &num_elems_read_per_iteration, unsigned int &num_elems_written_per_iteration, BorderSize &border_size)
{
    // Calculate right and bottom border
    unsigned int kernel_size = weights->dimension(0);
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
    const int input_width = input->dimension(0);
    const int input_height = input->dimension(1);

    switch(kernel_size)
    {
        case 1:
        {
            switch(input->data_type())
            {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                case DataType::QS8:
                case DataType::QS16:
                    num_elems_written_per_iteration = 8;
                    break;
                case DataType::F32:
                    if(run_optim_small_tensor_info(input))
                    {
                        num_elems_written_per_iteration = 8;
                    }
                    else
                    {
                        num_elems_written_per_iteration = 4;
                    }
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported.");
                    break;
            }
            num_weight_elems_read_per_row = kernel_size;
            num_elems_read_per_iteration  = conv_stride_x * num_elems_written_per_iteration;
            break;
        }
        case 3:
        case 5:
        {
            switch(input->data_type())
            {
                case DataType::F32:
                    num_weight_elems_read_per_row   = 4 + kernel_size - 1;
                    num_elems_read_per_iteration    = 12;
                    num_elems_written_per_iteration = 16 >> conv_stride_x;
                    break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                case DataType::QS8:
                case DataType::QS16:
                    num_weight_elems_read_per_row   = 8 + kernel_size - 1;
                    num_elems_read_per_iteration    = 24;
                    num_elems_written_per_iteration = 32 >> conv_stride_x;
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported.");
                    break;
            }
        }
        break;
        default:
        {
            ARM_COMPUTE_ERROR("Not implemented");
            break;
        }
    }

    // Calculate border
    int upper_bound_w = ceil_to_multiple(((output->dimension(0) - 1) * conv_stride_x + kernel_size), num_elems_read_per_iteration) - conv_info.pad_left() - conv_info.pad_right() - input_width;
    int upper_bound_h = ((output->dimension(1) - 1) * conv_stride_y - conv_info.pad_top() - conv_info.pad_bottom() + kernel_size) - input_height;

    const unsigned int conv_pad_left   = std::max(upper_bound_w - static_cast<int>(conv_info.pad_right()), static_cast<int>(kernel_size) / 2);
    const unsigned int conv_pad_top    = std::max(upper_bound_h - static_cast<int>(conv_info.pad_bottom()), static_cast<int>(kernel_size) / 2);
    const unsigned int conv_pad_right  = std::max(upper_bound_w - static_cast<int>(conv_info.pad_left()), static_cast<int>(kernel_size) / 2);
    const unsigned int conv_pad_bottom = std::max(upper_bound_h - static_cast<int>(conv_info.pad_top()), static_cast<int>(kernel_size) / 2);

    border_size.right  = conv_pad_right;
    border_size.bottom = conv_pad_bottom;
    border_size.left   = conv_pad_left;
    border_size.top    = conv_pad_top;

    Window win = calculate_max_window(*output, Steps(num_elems_written_per_iteration));
    AccessWindowStatic input_access(input, -conv_pad_left, -conv_pad_top, input_width + conv_pad_right, input_height + conv_pad_bottom);
    AccessWindowStatic weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size);
    AccessWindowHorizontal output_access(output, 0, num_elems_written_per_iteration);
    bool window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
    output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
} // namespace

NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
    : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0),
      _num_elems_written_per_iteration(0)
{
}

BorderSize NEDirectConvolutionLayerKernel::border_size() const
{
    return _border_size;
}

void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);

    _input       = input;
    _weights     = weights;
    _output      = output;
    _conv_info   = conv_info;
    _kernel_size = weights->info()->dimension(0);

    const unsigned int conv_pad_left   = conv_info.pad_left();
    const unsigned int conv_pad_top    = conv_info.pad_top();
    const unsigned int conv_pad_right  = conv_info.pad_right();
    const unsigned int conv_pad_bottom = conv_info.pad_bottom();
    _border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);

    // Get convolved dimensions
    TensorShape output_shape = get_convolved_dimensions(input->info(), weights->info(), _kernel_size, conv_info);

    DataType data_type = input->info()->data_type();

    if(is_data_type_fixed_point(data_type))
    {
        // Promote data type in case of fixed point
        data_type = ((data_type == DataType::QS8) ? DataType::QS16 : DataType::QS32);
    }

    // Output auto-initialization if not yet initialized
    auto_init_if_empty(*output->info(), output_shape, 1, data_type, input->info()->fixed_point_position());

    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), output->info(), conv_info));

    // Configure kernel window
    auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, _num_weight_elems_read_per_row,
                                                    _num_elems_read_per_iteration, _num_elems_written_per_iteration, _border_size);
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
    INEKernel::configure(win_config.second);
}
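
// Typical call sequence for this kernel (an illustrative sketch only; tensor
// allocation/filling is omitted, and the scheduling call assumes the usual
// NEScheduler API rather than anything specific to this file):
//   NEDirectConvolutionLayerKernel kernel;
//   kernel.configure(&src, &weights, &dst, PadStrideInfo(1, 1, 1, 1)); // stride 1, pad 1
//   NEScheduler::get().schedule(&kernel, Window::DimZ);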

Status NEDirectConvolutionLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
{
    unsigned int num_weight_elems_read_per_row   = 0;
    unsigned int num_elems_read_per_iteration    = 0;
    unsigned int num_elems_written_per_iteration = 0;
    BorderSize   border_size(conv_info.pad().first, conv_info.pad().second);
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, weights, output, conv_info));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
                                                              weights->clone().get(),
                                                              output->clone().get(),
                                                              conv_info,
                                                              num_weight_elems_read_per_row,
                                                              num_elems_read_per_iteration,
                                                              num_elems_written_per_iteration,
                                                              border_size)
                                .first);

    return Status{};
}

void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

    const int kernel_size = _weights->info()->dimension(0);

    switch(kernel_size)
    {
        case 1:
        {
            switch(_input->info()->data_type())
            {
                case DataType::QS8:
                    convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::QS16:
                    convolve_1x1<qint16_t, qint32_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::F32:
                    convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        case 3:
        {
            switch(_input->info()->data_type())
            {
                case DataType::QS8:
                    convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::F32:
                    convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        case 5:
        {
            switch(_input->info()->data_type())
            {
                case DataType::F32:
                    convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }

        default:
        {
            ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
            break;
        }
    }
}