/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"

#include <algorithm>
#include <arm_neon.h>

using namespace arm_compute;

namespace
{
template <unsigned int stridex>
qint16x8_t internal_vld1q(const qint16_t *in);

template <>
qint16x8_t internal_vld1q<1>(const qint16_t *in)
{
    return vld1q_qs16(in);
}

template <>
qint16x8_t internal_vld1q<2>(const qint16_t *in)
{
    const int16x8x2_t tmp = vld2q_s16(in);
    return tmp.val[0];
}

template <>
qint16x8_t internal_vld1q<3>(const qint16_t *in)
{
    const int16x8x3_t tmp = vld3q_s16(in);
    return tmp.val[0];
}

inline qint16x8_t internal_vdupq_n(qint16_t v)
{
    return vdupq_n_qs16(v);
}

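// Note on the strided loads above: vld2q/vld3q de-interleave by a factor of 2/3,
// so val[0] collects elements 0, 2, 4, ... (stride 2) or 0, 3, 6, ... (stride 3).
// For example (illustrative only), with in = {a0, a1, a2, ...} and stridex == 2,
// internal_vld1q<2>(in) yields {a0, a2, a4, a6, a8, a10, a12, a14}: exactly the
// samples a stride-2 convolution needs for 8 consecutive output elements.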
#ifdef ARM_COMPUTE_AARCH64_V8_2
template <unsigned int stridex>
float16x8_t internal_vld1q(const float16_t *in);

template <>
float16x8_t internal_vld1q<1>(const float16_t *in)
{
    return vld1q_f16(in);
}

template <>
float16x8_t internal_vld1q<2>(const float16_t *in)
{
    const float16x8x2_t tmp = vld2q_f16(in);
    return tmp.val[0];
}

template <>
float16x8_t internal_vld1q<3>(const float16_t *in)
{
    const float16x8x3_t tmp = vld3q_f16(in);
    return tmp.val[0];
}

inline float16x8_t internal_vdupq_n(float16_t v)
{
    return vdupq_n_f16(v);
}

inline void internal_vst1q(float16_t *p, const float16x8_t &v)
{
    vst1q_f16(p, v);
}

float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f16(x, y);
}

inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vaddq_f16(x, vmulq_f16(y, z));
}
#endif /* ARM_COMPUTE_AARCH64_V8_2 */

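// The float16_t paths are compiled only when ARM_COMPUTE_AARCH64_V8_2 is defined,
// since vmulq_f16/vaddq_f16 and friends require the ARMv8.2-A half-precision
// arithmetic extension (e.g. building with something like -march=armv8.2-a+fp16;
// the exact flag is toolchain-dependent and an assumption here, not part of this file).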
template <unsigned int stridex>
float32x4_t internal_vld1q(const float *in);

template <>
float32x4_t internal_vld1q<1>(const float *in)
{
    return vld1q_f32(in);
}

template <>
float32x4_t internal_vld1q<2>(const float *in)
{
    const float32x4x2_t tmp = vld2q_f32(in);
    return tmp.val[0];
}

template <>
float32x4_t internal_vld1q<3>(const float *in)
{
    const float32x4x3_t tmp = vld3q_f32(in);
    return tmp.val[0];
}

inline float32x4_t internal_vdupq_n(float v)
{
    return vdupq_n_f32(v);
}

inline void internal_vst1q(float *p, const float32x4_t &v)
{
    vst1q_f32(p, v);
}

float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f32(x, y);
}

inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmlaq_f32(x, y, z);
}

template <unsigned int stridex>
qint8x8_t internal_vld1q(const qint8_t *in);

template <>
qint8x8_t internal_vld1q<1>(const qint8_t *in)
{
    return vld1_qs8(in);
}

template <>
qint8x8_t internal_vld1q<2>(const qint8_t *in)
{
    const qint8x8x2_t tmp = vld2_s8(in);
    return tmp.val[0];
}

template <>
qint8x8_t internal_vld1q<3>(const qint8_t *in)
{
    const qint8x8x3_t tmp = vld3_s8(in);
    return tmp.val[0];
}

inline qint8x8_t internal_vdupq_n(qint8_t v)
{
    return vdup_n_qs8(v);
}

inline qint16x8_t internal_vmull(const qint8x8_t &x, const qint8x8_t &y, int fixed_point_position)
{
    return vmull_qs8(x, y, fixed_point_position);
}

inline qint16x8_t internal_vmlal(const qint16x8_t &x, const qint8x8_t &y, const qint8x8_t &z, int fixed_point_position)
{
    return vqmlal_qs8(x, y, z, fixed_point_position);
}

inline void internal_vst1q(qint16_t *p, const qint16x8_t &v)
{
    vst1q_qs16(p, v);
}

inline void internal_vst1q(int32_t *p, const qint32x4x2_t &v)
{
    vst1q_s32(p, v.val[0]);
    vst1q_s32(p + 4, v.val[1]);
}

template <unsigned int stridex>
qint32x4x2_t internal_vld1q(const qint32_t *in);

template <>
qint32x4x2_t internal_vld1q<1>(const qint32_t *in)
{
    const qint32x4x2_t r =
    {
        {
            vld1q_s32(in),
            vld1q_s32(in + 4)
        }
    };
    return r;
}

inline qint32x4x2_t internal_vmull(const qint16x8_t &x, const qint16x8_t &y, int fixed_point_position)
{
    const qint32x4x2_t r =
    {
        {
            vmull_qs16(vget_low_s16(x), vget_low_s16(y), fixed_point_position),
            vmull_qs16(vget_high_s16(x), vget_high_s16(y), fixed_point_position),
        }
    };
    return r;
}

inline qint32x4x2_t internal_vmlal(const qint32x4x2_t &x, const qint16x8_t &y, const qint16x8_t &z, int fixed_point_position)
{
    const qint32x4x2_t r =
    {
        {
            vqmlal_qs16(x.val[0], vget_low_s16(y), vget_low_s16(z), fixed_point_position),
            vqmlal_qs16(x.val[1], vget_high_s16(y), vget_high_s16(z), fixed_point_position)
        }
    };
    return r;
}

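// The internal_vld1q/internal_vdupq_n/internal_vmull/internal_vmlal/internal_vst1q
// overloads above form a small type-dispatch shim: the templated convolvers below are
// written once against these names, and overload resolution picks the right vector
// width and arithmetic (F32, F16, QS8/QS16 fixed-point) from the pointer and vector
// types. The fixed_point_position argument is only meaningful for the fixed-point
// overloads; the float ones discard it via ARM_COMPUTE_UNUSED. A hypothetical
// instantiation, for illustration only:
//
//   const float *in  = ...;                                  // some input row
//   float32x4_t  w   = internal_vdupq_n(1.5f);               // splat one weight
//   float32x4_t  acc = internal_vmull(w, internal_vld1q<1>(in), 0);
//   acc              = internal_vmlal(acc, w, internal_vld1q<1>(in + 4), 0);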
constexpr int SmallTensorSizeOptim = 8;
inline bool run_optim_small_tensor(const ITensor *t)
{
    return t->info()->dimension(Window::DimX) <= SmallTensorSizeOptim && t->info()->dimension(Window::DimY) <= SmallTensorSizeOptim;
}

// Optimized convolver for 1x1 kernels, used only where input width and height are both <= 8.
// For a big Z, as in Input=7x7x832, this implementation is faster than the general code because it doesn't need to
// store intermediate results in memory. Temporary results are stored directly in NEON registers and then written to the output buffer.
template <unsigned int stridex>
class convolver_w1x1_i8x8_f32
{
public:
    static void convolve(const Window &window, const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimX) > SmallTensorSizeOptim);
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimY) > SmallTensorSizeOptim);

        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_y = output->info()->strides_in_bytes().y();
        const int          output_stride_z = output->info()->strides_in_bytes().z();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int          output_h        = output->info()->dimension(1);
        const int          range_z         = window.z().end() - window.z().start();
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr();
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            float32x4_t    accum0[SmallTensorSizeOptim] = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            float32x4_t    accum1[SmallTensorSizeOptim] = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            for(int oz = 0; oz < range_z; ++oz)
            {
                accum0[0] = accum0[1] = accum0[2] = accum0[3] = accum0[4] = accum0[5] = accum0[6] = accum0[7] = vdupq_n_f32(0.f);
                accum1[0] = accum1[1] = accum1[2] = accum1[3] = accum1[4] = accum1[5] = accum1[6] = accum1[7] = vdupq_n_f32(0.f);
                auto p_out_base = out_ptr + oz * output_stride_z;
                for(int p = 0; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const float *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk0   = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto      in_val    = reinterpret_cast<const float *>(input_ptr + p * input_stride_z + offset_xy);
                        auto      v_in0     = internal_vld1q<stridex>(in_val);
                        auto      v_in1     = internal_vld1q<stridex>(in_val + 4);
                        accum0[oh] = vmlaq_f32(accum0[oh], vk0, v_in0);
                        accum1[oh] = vmlaq_f32(accum1[oh], vk0, v_in1);
                    }
                }
                for(oh = 0; oh < output_h; ++oh)
                {
                    auto p_out = reinterpret_cast<float *>(p_out_base + oh * output_stride_y);
                    vst1q_f32(p_out, accum0[oh]);
                    vst1q_f32(p_out + 4, accum1[oh]);
                }
            }
        },
        in, out);
    }
};

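// Why the accumulator arrays above fit: with width/height <= 8, one output row is at
// most 8 floats, i.e. the two float32x4_t vectors accum0[oh]/accum1[oh], and a whole
// 8-row plane lives in 16 NEON registers. Each kernel-depth slice p then contributes a
// single fused multiply-accumulate per row, and memory is only touched for the final
// stores. For, say, a 7x7x832 input the p-loop runs 832 times with no intermediate
// spills, which is where the speed-up over the generic 1x1 path comes from.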
template <typename T1, typename T2, unsigned int stridex>
class convolver_1x1
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_y = output->info()->strides_in_bytes().y();
        const int          output_stride_z = output->info()->strides_in_bytes().z();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int          output_w        = output->info()->dimension(0);
        const int          output_h        = output->info()->dimension(1);
        const int          range_z         = window.z().end() - window.z().start();
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const int          fixed_point_position = input->info()->fixed_point_position();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            /*
                For a detailed explanation of how the algorithm works, refer to template <> class convolver_3x3<1>.
            */
            const uint8_t *input_ptr = in.ptr();
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            for(int oz = 0; oz < range_z; ++oz)
            {
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk    = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto      in_val    = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
                        auto      p_out     = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }

                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk    = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto      in_val    = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
                        auto      p_out     = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }
            }
        },
        in, out);
    }
};

#ifdef ARM_COMPUTE_AARCH64_V8_2
inline float16x8x3_t load_matrix_row(const float16_t *ptr)
{
    /* ptr is a pointer to a row in a 3x3 matrix; the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    const float16x8x3_t r =
    {
        {
            vld1q_dup_f16(ptr),
            vld1q_dup_f16(1 + ptr),
            vld1q_dup_f16(2 + ptr)
        }
    };
    return r;
}

template <unsigned int stridex>
float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                           int fixed_point_position);

template <>
float16x8x2_t convolve_3x3<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                              int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);

    const float16x8x3_t vtop =
    {
        {
            vld1q_f16(in_top),
            vld1q_f16(in_top + 8),
            vld1q_f16(in_top + 16)
        }
    };
    const float16x8x3_t vmid =
    {
        {
            vld1q_f16(in_mid),
            vld1q_f16(in_mid + 8),
            vld1q_f16(in_mid + 16)
        }
    };
    const float16x8x3_t vlow =
    {
        {
            vld1q_f16(in_low),
            vld1q_f16(in_low + 8),
            vld1q_f16(in_low + 16)
        }
    };
    float16x8x2_t out =
    {
        {
            vmulq_f16(vtop.val[0], m0.val[0]),
            vmulq_f16(vtop.val[1], m0.val[0])
        }
    };
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));
    return out;
}

template <>
inline float16x8x2_t convolve_3x3<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                     int fixed_point_position)
{
    float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float16x8x2_t convolve_3x3<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                     int fixed_point_position)
{
    float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
    return out;
}

template <unsigned int stridex>
void store_results(float16_t *buffer, const float16x8x2_t &values);

template <>
void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
    vst1q_f16(buffer + 8, values.val[1]);
}

template <>
void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
}

template <>
void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vget_low_f16(values.val[0]));
}

template <unsigned int stridex>
void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

template <>
void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
    vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
}

#endif /* ARM_COMPUTE_AARCH64_V8_2 */

inline float32x4x3_t load_matrix_row(const float *ptr)
{
    const float32x4x3_t r =
    {
        {
            vld1q_dup_f32(ptr),
            vld1q_dup_f32(1 + ptr),
            vld1q_dup_f32(2 + ptr)
        }
    };
    return r;
}
inline qint8x8x3_t load_matrix_row(const qint8_t *ptr)
{
    /* ptr is a pointer to a row in a 3x3 matrix; the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    const qint8x8x3_t r =
    {
        {
            vld1_dup_qs8(ptr),
            vld1_dup_qs8(1 + ptr),
            vld1_dup_qs8(2 + ptr)
        }
    };
    return r;
}

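// load_matrix_row relies on vld1q_dup_f32/vld1_dup_qs8-style broadcast loads: each of
// the three kernel-row weights is splatted across every lane of its own vector. That
// lets a single multiply-accumulate apply one weight to 4 (or 8) horizontally adjacent
// input pixels at once, which is the core of the convolve_3x3/convolve_5x5 routines below.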
template <unsigned int stridex>
float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                           const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position);

inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
{
    const float32x4x3_t m00 =
    {
        {
            vld1q_dup_f32(m0),
            vld1q_dup_f32(m1),
            vld1q_dup_f32(m2)
        }
    };
    return m00;
}

inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4)
{
    const float32x4x2_t m00 =
    {
        {
            vld1q_dup_f32(m3),
            vld1q_dup_f32(m4)
        }
    };
    return m00;
}

inline float32x4x3_t load_input(const float *const in)
{
    const float32x4x3_t vin =
    {
        {
            vld1q_f32(in),
            vld1q_f32(in + 4),
            vld1q_f32(in + 8)
        }
    };
    return vin;
}

template <>
inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    const float32x4x3_t vin0 = load_input(in_0);
    const float32x4x3_t vin1 = load_input(in_1);
    const float32x4x3_t vin2 = load_input(in_2);
    const float32x4x3_t vin3 = load_input(in_3);
    const float32x4x3_t vin4 = load_input(in_4);
    const float32x4x3_t m00  = load_matrix_hi(m0, 1 + m0, 2 + m0);
    const float32x4x2_t m01  = load_matrix_lo(3 + m0, 4 + m0);
    const float32x4x3_t m10  = load_matrix_hi(m1, 1 + m1, 2 + m1);
    const float32x4x2_t m11  = load_matrix_lo(3 + m1, 4 + m1);
    const float32x4x3_t m20  = load_matrix_hi(m2, 1 + m2, 2 + m2);
    const float32x4x2_t m21  = load_matrix_lo(3 + m2, 4 + m2);
    const float32x4x3_t m30  = load_matrix_hi(m3, 1 + m3, 2 + m3);
    const float32x4x2_t m31  = load_matrix_lo(3 + m3, 4 + m3);
    const float32x4x3_t m40  = load_matrix_hi(m4, 1 + m4, 2 + m4);
    const float32x4x2_t m41  = load_matrix_lo(3 + m4, 4 + m4);

    float32x4x2_t out =
    {
        {
            vmulq_f32(vin0.val[0], m00.val[0]),
            vmulq_f32(vin0.val[1], m00.val[0])
        }
    };

    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 3), m11.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], m11.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]);

    return out;
}

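// The vextq_f32(a, b, n) calls above synthesise the horizontally shifted input windows:
// they concatenate two adjacent 4-lane vectors and take four lanes starting at lane n,
// so the column offsets 1..3 of the 5x5 (and 3x3) window come from register shuffles
// instead of extra unaligned loads. With a = {x0, x1, x2, x3} and b = {x4, x5, x6, x7},
// vextq_f32(a, b, 1) == {x1, x2, x3, x4}.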
template <>
inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}

template <unsigned int stridex>
float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position);

template <>
inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);

    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + 4),
            vld1q_f32(in_top + 8)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + 4),
            vld1q_f32(in_mid + 8)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + 4),
            vld1q_f32(in_low + 8)
        }
    };
    float32x4x2_t out =
    {
        {
            vmulq_f32(vtop.val[0], m0.val[0]),
            vmulq_f32(vtop.val[1], m0.val[0])
        }
    };
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);

    out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);

    out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
    return out;
}

template <>
inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}

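// For stridex == 2 and stridex == 3 the convolve_3x3/convolve_5x5 specialisations reuse
// the stride-1 result and then compact it: every output lane of the strided variant is
// one of the stride-1 outputs, so vgetq_lane/vsetq_lane simply gather lanes 0, 2, 4, ...
// (or 0, 3, 6, ...) into the front of out.val[0]. Computing a few throw-away lanes and
// shuffling is cheaper here than a dedicated strided inner loop.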
template <unsigned int stridex>
qint16x8x2_t convolve_3x3(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position);

template <>
inline qint16x8x2_t convolve_3x3<1>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);

    const qint8x8x3_t vtop =
    {
        {
            vld1_qs8(in_top),
            vld1_qs8(in_top + 8),
            vld1_qs8(in_top + 16)
        }
    };
    const qint8x8x3_t vmid =
    {
        {
            vld1_qs8(in_mid),
            vld1_qs8(in_mid + 8),
            vld1_qs8(in_mid + 16)
        }
    };
    const qint8x8x3_t vlow =
    {
        {
            vld1_qs8(in_low),
            vld1_qs8(in_low + 8),
            vld1_qs8(in_low + 16)
        }
    };
    qint16x8x2_t out =
    {
        {
            vmull_qs8(vtop.val[0], m0.val[0], fixed_point_position),
            vmull_qs8(vtop.val[1], m0.val[0], fixed_point_position)
        }
    };
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vtop.val[0], vtop.val[1], 1), m0.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vtop.val[0], vtop.val[1], 2), m0.val[2], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vmid.val[0], m1.val[0], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vmid.val[0], vmid.val[1], 1), m1.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vmid.val[0], vmid.val[1], 2), m1.val[2], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vlow.val[0], m2.val[0], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vlow.val[0], vlow.val[1], 1), m2.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vlow.val[0], vlow.val[1], 2), m2.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vtop.val[1], vtop.val[2], 1), m0.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vtop.val[1], vtop.val[2], 2), m0.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vmid.val[1], m1.val[0], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vmid.val[1], vmid.val[2], 1), m1.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vmid.val[1], vmid.val[2], 2), m1.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vlow.val[1], m2.val[0], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vlow.val[1], vlow.val[2], 1), m2.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vlow.val[1], vlow.val[2], 2), m2.val[2], fixed_point_position);
    return out;
}

template <>
inline qint16x8x2_t convolve_3x3<2>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    qint16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 4), out.val[0], 2);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 6), out.val[0], 3);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 0), out.val[0], 4);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 2), out.val[0], 5);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 4), out.val[0], 6);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 6), out.val[0], 7);
    return out;
}

template <>
inline qint16x8x2_t convolve_3x3<3>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    qint16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 3), out.val[0], 1);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 6), out.val[0], 2);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 1), out.val[0], 3);
    return out;
}

template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);

template <>
void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
    vst1q_f32(buffer + 4, values.val[1]);
}

template <>
void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
}

template <>
void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vget_low_f32(values.val[0]));
}

template <unsigned int stridex>
void store_results(qint16_t *buffer, const qint16x8x2_t &values);

template <>
void store_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, values.val[0]);
    vst1q_qs16(buffer + 8, values.val[1]);
}

template <>
void store_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, values.val[0]);
}

template <>
void store_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1_qs16(buffer, vget_low_s16(values.val[0]));
}

template <unsigned int stridex>
void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
    vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
}

template <>
void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}

template <unsigned int stridex>
void accumulate_results(qint16_t *buffer, const qint16x8x2_t &values);

template <>
void accumulate_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
    vst1q_qs16(buffer + 8, vqaddq_qs16(vld1q_qs16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1_qs16(buffer, vqadd_qs16(vld1_qs16(buffer), vget_low_s16(values.val[0])));
}

template <unsigned int stridex>
int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration);

template <>
int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration;
}

template <>
int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration << 1;
}

template <>
int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration * 3;
}

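// get_input_num_elems_processed converts "output elements written per iteration" into
// the matching input advance: a stride-s convolution consumes s input columns per output
// column. E.g. if 4 outputs are written per iteration, the input pointers move forward
// by 4, 8 or 12 elements for stridex 1, 2 or 3 respectively (hence the << 1 for stride 2).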
template <typename T1, typename T2, unsigned int stridex>
class convolver_3x3
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        const int          input_stride_x  = input->info()->strides_in_bytes().x();
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_y = output->info()->strides_in_bytes().y();
        const int          output_stride_z = output->info()->strides_in_bytes().z();
        const int          kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int          kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int          output_w        = output->info()->dimension(0);
        const int          output_h        = output->info()->dimension(1);
        const int          num_planes_z    = window.z().end() - window.z().start();
        const int          delta_input     = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_x      = std::get<0>(conv_info.pad());
        const unsigned int conv_pad_y      = std::get<1>(conv_info.pad());
        const int          fixed_point_position = input->info()->fixed_point_position();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_x * input_stride_x - conv_pad_y * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            /*
                Each thread executing this kernel computes one or more planes of the output volume.

                Let's say the 3rd dimension of the output volume is 32; the first thread will compute the output for Z = [0,7], the second thread will compute the output for Z = [8,15],
                the third thread [16,23] and the fourth thread [24,31].

                The algorithm's outer loop iterates over Z, P, Y, X, where P is the depth/3rd dimension of each kernel. This order is not arbitrary: its main benefit is that we set up
                the NEON registers containing the kernel's values only once and then compute each XY using the preloaded registers, as opposed to doing this for every XY value.

                The algorithm does not require allocating any additional memory and computes the results directly in-place in two stages:
                    1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
                    2) Convolve the remaining planes and accumulate the results in the output's plane which has been initialized in step 1.
            */
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset    = id.z() + oz;
                uint8_t  *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto vk_r0    = load_matrix_row(ptr_k_r0);
                    const auto vk_r1    = load_matrix_row(ptr_k_r1);
                    const auto vk_r2    = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w;
                    const uint8_t *input_base = input_ptr + p * input_stride_z;
                    const auto     ptr_k_r0   = reinterpret_cast<const T1 *>(ptr_k_base);
                    const auto     ptr_k_r1   = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y);
                    const auto     ptr_k_r2   = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y * 2);
                    const auto     vk_r0      = load_matrix_row(ptr_k_r0);
                    const auto     vk_r1      = load_matrix_row(ptr_k_r1);
                    const auto     vk_r2      = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_base + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_base + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_base + (ih + 2) * input_stride_y);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};

Pablo Tello06da39d2017-08-10 15:10:40 +01001148template <typename T1, typename T2, unsigned int stridex>
1149class convolver_5x5
1150{
1151public:
1152 static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
1153 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
1154 {
1155 ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
1156 const int input_stride_x = input->info()->strides_in_bytes().x();
1157 const int input_stride_y = input->info()->strides_in_bytes().y();
1158 const int input_stride_z = input->info()->strides_in_bytes().z();
1159 const int output_stride_y = output->info()->strides_in_bytes().y();
1160 const int output_stride_z = output->info()->strides_in_bytes().z();
1161 const int kernel_stride_x = weights->info()->strides_in_bytes().x();
1162 const int kernel_stride_y = weights->info()->strides_in_bytes().y();
1163 const int kernel_stride_z = weights->info()->strides_in_bytes().z();
1164 const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
1165 const int output_w = output->info()->dimension(0);
1166 const int output_h = output->info()->dimension(1);
1167 const int num_planes_z = window.z().end() - window.z().start();
1168 const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
1169 const int kernel_depth = weights->info()->dimension(Window::DimZ);
1170 const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
1171 const unsigned int conv_pad_x = std::get<0>(conv_info.pad());
1172 const unsigned int conv_pad_y = std::get<1>(conv_info.pad());
1173 const int fixed_point_position = input->info()->fixed_point_position();
1174
1175 // setup output window for the iterator
1176 Window window_out = window;
1177 window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
1178 window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
1179 window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));
1180
1181 // setup input window for the iterator
1182 Window window_in = window;
1183 // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
1184 window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
1185 window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
1186 window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
1187
1188 Window window_k = calculate_max_window(*weights->info(), Steps(1u));
1189
1190 Iterator out(output, window_out);
1191 Iterator in(input, window_in);
1192 Iterator k(weights, window_k);
1193
1194 const uint8_t *k_ptr = k.ptr();
1195
1196 execute_window_loop(window_out, [&](const Coordinates & id)
1197 {
1198 const uint8_t *input_ptr = in.ptr() - conv_pad_x * input_stride_x - conv_pad_y * input_stride_y;
1199 uint8_t *out_ptr = out.ptr();
1200 int ih = 0;
1201 int oh = 0;
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset   = id.z() + oz;
                uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1: convolve the first input plane with its five kernel rows and
                // store the result, initialising this output plane.
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2: convolve the remaining kernel_depth - 1 input planes and
                // accumulate onto the partial results written in Step 1.
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);

                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};

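// The convolve_* helpers below turn the runtime stride from PadStrideInfo into a
// compile-time template argument, so each convolver_* instantiation gets inner
// loops and vector load/store helpers specialised for stride 1, 2 or 3.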
template <typename T1, typename T2>
inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

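// Specialisation for F32 1x1 convolution: small input tensors are routed to
// convolver_w1x1_i8x8_f32, an optimised path (gated by run_optim_small_tensor)
// for which configure() selects 8 elements written per iteration instead of 4.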
template <>
inline void convolve_1x1<float, float>(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                                       const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    if(run_optim_small_tensor(input))
    {
        switch(conv_stride_x)
        {
            case 1:
                convolver_w1x1_i8x8_f32<1>::convolve(window, input, weights, output, conv_info);
                break;
            case 2:
                convolver_w1x1_i8x8_f32<2>::convolve(window, input, weights, output, conv_info);
                break;
            case 3:
                convolver_w1x1_i8x8_f32<3>::convolve(window, input, weights, output, conv_info);
                break;
            default:
                ARM_COMPUTE_ERROR("Not implemented");
        }
    }
    else
    {
        switch(conv_stride_x)
        {
            case 1:
                convolver_1x1<float, float, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            case 2:
                convolver_1x1<float, float, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            case 3:
                convolver_1x1<float, float, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            default:
                ARM_COMPUTE_ERROR("Not implemented");
        }
    }
}

template <typename T1, typename T2>
inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <typename T1, typename T2>
inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_5x5<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_5x5<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_5x5<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

} // namespace

NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
    : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0),
      _num_elems_written_per_iteration(0)
{
}

BorderSize NEDirectConvolutionLayerKernel::border_size() const
{
    return _border_size;
}

void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::QS16, DataType::F32);
    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 1 && (std::get<0>(conv_info.pad()) || std::get<1>(conv_info.pad())),
                             "Pad > 0 not supported for 1x1 weights");
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 3 && (std::get<0>(conv_info.pad()) > 1 || std::get<1>(conv_info.pad()) > 1),
                             "Pad > 1 not supported for 3x3 weights");
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 5 && (std::get<0>(conv_info.pad()) > 2 || std::get<1>(conv_info.pad()) > 2),
                             "Pad > 2 not supported for 5x5 weights");

    ARM_COMPUTE_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
    ARM_COMPUTE_ERROR_ON(weights->info()->dimension(2) != input->info()->dimension(2));
    ARM_COMPUTE_ERROR_ON(weights->info()->dimension(0) != weights->info()->dimension(1));
    ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4);

    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    const unsigned int conv_pad_x    = std::get<0>(conv_info.pad());
    const unsigned int conv_pad_y    = std::get<1>(conv_info.pad());

    _input       = input;
    _weights     = weights;
    _output      = output;
    _conv_info   = conv_info;
    _kernel_size = weights->info()->dimension(0);
    _border_size = BorderSize(conv_pad_y, conv_pad_x);
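    // Note: this initial border only reflects the user padding; its right and
    // bottom components are enlarged further down, once the per-iteration
    // element counts are known.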

    const unsigned int kernel_size = weights->info()->dimension(0);

    // Get convolved dimensions
    unsigned int output_width  = 0;
    unsigned int output_height = 0;
    std::tie(output_width, output_height) = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1), kernel_size, kernel_size, conv_info);
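    // (For reference, each output dimension is roughly (input + 2 * pad - kernel_size) / stride + 1;
    // scaled_dimensions() applies the exact rounding.)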

    TensorShape output_shape = input->info()->tensor_shape();
    output_shape.set(0, output_width);
    output_shape.set(1, output_height);
    output_shape.set(2, weights->info()->dimension(3));

    DataType data_type = input->info()->data_type();

    if(is_data_type_fixed_point(data_type))
    {
        // Promote the data type in case of fixed point: partial results are
        // accumulated at higher precision (QS8 -> QS16, QS16 -> QS32)
        data_type = ((data_type == DataType::QS8) ? DataType::QS16 : DataType::QS32);
    }

    // Output auto initialization if not yet initialized
    auto_init_if_empty(*output->info(), output_shape, 1, data_type, input->info()->fixed_point_position());

    ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, output->info()->data_type());

    switch(_kernel_size)
    {
        case 1:
        {
            switch(input->info()->data_type())
            {
#ifdef ARM_COMPUTE_AARCH64_V8_2
                case DataType::F16:
#endif /* ARM_COMPUTE_AARCH64_V8_2 */
                case DataType::QS8:
                case DataType::QS16:
                    _num_elems_written_per_iteration = 8;
                    break;
                case DataType::F32:
                    if(run_optim_small_tensor(input))
                    {
                        _num_elems_written_per_iteration = 8;
                    }
                    else
                    {
                        _num_elems_written_per_iteration = 4;
                    }
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported.");
                    break;
            }
            _num_weight_elems_read_per_row = kernel_size;
            _num_elems_read_per_iteration  = conv_stride_x * _num_elems_written_per_iteration;
            break;
        }
        case 3:
        case 5:
        {
            switch(input->info()->data_type())
            {
                case DataType::F32:
                    _num_weight_elems_read_per_row   = 4 + _kernel_size - 1;
                    _num_elems_read_per_iteration    = 12;
                    _num_elems_written_per_iteration = 16 >> conv_stride_x;
                    break;
#ifdef ARM_COMPUTE_AARCH64_V8_2
                case DataType::F16:
#endif /* ARM_COMPUTE_AARCH64_V8_2 */
                case DataType::QS8:
                case DataType::QS16:
                    _num_weight_elems_read_per_row   = 8 + _kernel_size - 1;
                    _num_elems_read_per_iteration    = 24;
                    _num_elems_written_per_iteration = 32 >> conv_stride_x;
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported.");
                    break;
            }
        }
        break;
        default:
        {
            ARM_COMPUTE_ERROR("Not implemented");
            break;
        }
    }

    // Calculate the right and bottom border, i.e. how far past the input the
    // vectorised loops may read when processing the last elements of a row/column
    const unsigned int conv_stride_y = std::get<1>(_conv_info.stride());
    const int          input_width   = input->info()->dimension(0);
    const int          input_height  = input->info()->dimension(1);
    const int          upper_bound_w = ceil_to_multiple(((output->info()->dimension(0) - 1) * conv_stride_x + _kernel_size), _num_elems_read_per_iteration) - conv_pad_x - input_width;
    const int          upper_bound_h = ((output->info()->dimension(1) - 1) * conv_stride_y - conv_pad_y + _kernel_size) - input_height;
    _border_size.right  = std::max(upper_bound_w, static_cast<int>(_kernel_size));
    _border_size.bottom = std::max(upper_bound_h, static_cast<int>(_kernel_size));
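    // Declare the exact regions read (input, weights) and written (output) per
    // iteration so that update_window_and_padding() can grow the tensors'
    // paddings to match.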
    Window win = calculate_max_window(*output->info(), Steps(_num_elems_written_per_iteration));
    AccessWindowStatic     input_access(input->info(), -conv_pad_x, -conv_pad_y, input_width + _border_size.right, input_height + _border_size.bottom);
    AccessWindowStatic     weights_access(weights->info(), 0, 0, _num_weight_elems_read_per_row, _kernel_size);
    AccessWindowHorizontal output_access(output->info(), 0, _num_elems_written_per_iteration);
    update_window_and_padding(win, input_access, weights_access, output_access);
    output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));

    INEKernel::configure(win);
}

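// Typical driving sequence for this kernel, as a minimal sketch only: the tensor
// names and the split dimension passed to the scheduler are illustrative
// assumptions, not taken from this file.
//
//   NEDirectConvolutionLayerKernel kernel;
//   kernel.configure(&src, &weights, &dst, PadStrideInfo(1 /* stride_x */, 1 /* stride_y */, 0 /* pad_x */, 0 /* pad_y */));
//   NEScheduler::get().schedule(&kernel, Window::DimZ); // split dimension is an assumption
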
void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

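    // Dispatch mirrors configure(): first on kernel size, then on the input data
    // type; unsupported combinations fall through to the error macros below.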
    const int kernel_size = _weights->info()->dimension(0);

    switch(kernel_size)
    {
        case 1:
        {
            switch(_input->info()->data_type())
            {
                case DataType::QS8:
                    convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::QS16:
                    convolve_1x1<qint16_t, qint32_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::F32:
                    convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#ifdef ARM_COMPUTE_AARCH64_V8_2
                case DataType::F16:
                    convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#endif /* ARM_COMPUTE_AARCH64_V8_2 */
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        case 3:
        {
            switch(_input->info()->data_type())
            {
                case DataType::QS8:
                    convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::F32:
                    convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#ifdef ARM_COMPUTE_AARCH64_V8_2
                case DataType::F16:
                    convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#endif /* ARM_COMPUTE_AARCH64_V8_2 */
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        case 5:
        {
            switch(_input->info()->data_type())
            {
                case DataType::F32:
                    convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }

        default:
        {
            ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
            break;
        }
    }
}