/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"
#include "arm_compute/core/NEON/kernels/convolution/NEDirectConvolutionDetail.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"

#include <algorithm>
#include <arm_neon.h>

using namespace arm_compute;
using namespace arm_compute::detail;

namespace
{
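// The internal_vld1q<stridex> overloads below implement strided vector loads:
// stride 1 is a plain contiguous load, while strides 2 and 3 use the
// de-interleaving NEON loads (vld2q/vld3q and their variants) and keep only
// val[0], which gathers every 2nd (respectively 3rd) element in one load.
// For example, given int16_t in[24] = { 0, 1, 2, ..., 23 }:
//   vld2q_s16(in).val[0] == { 0, 2, 4, 6, 8, 10, 12, 14 }   // stride 2
//   vld3q_s16(in).val[0] == { 0, 3, 6, 9, 12, 15, 18, 21 }  // stride 3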
template <unsigned int stridex>
qint16x8_t internal_vld1q(const qint16_t *in);

template <>
qint16x8_t internal_vld1q<1>(const qint16_t *in)
{
    return vld1q_qs16(in);
}

template <>
qint16x8_t internal_vld1q<2>(const qint16_t *in)
{
    const int16x8x2_t tmp = vld2q_s16(in);
    return tmp.val[0];
}

template <>
qint16x8_t internal_vld1q<3>(const qint16_t *in)
{
    const int16x8x3_t tmp = vld3q_s16(in);
    return tmp.val[0];
}

inline qint16x8_t internal_vdupq_n(qint16_t v)
{
    return vdupq_n_qs16(v);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <unsigned int stridex>
float16x8_t internal_vld1q(const float16_t *in);

template <>
float16x8_t internal_vld1q<1>(const float16_t *in)
{
    return vld1q_f16(in);
}

template <>
float16x8_t internal_vld1q<2>(const float16_t *in)
{
    const float16x8x2_t tmp = vld2q_f16(in);
    return tmp.val[0];
}

template <>
float16x8_t internal_vld1q<3>(const float16_t *in)
{
    const float16x8x3_t tmp = vld3q_f16(in);
    return tmp.val[0];
}

inline float16x8_t internal_vdupq_n(float16_t v)
{
    return vdupq_n_f16(v);
}

inline void internal_vst1q(float16_t *p, const float16x8_t &v)
{
    vst1q_f16(p, v);
}

float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f16(x, y);
}

inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vaddq_f16(x, vmulq_f16(y, z));
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

template <unsigned int stridex>
float32x4_t internal_vld1q(const float *in);

template <>
float32x4_t internal_vld1q<1>(const float *in)
{
    return vld1q_f32(in);
}

template <>
float32x4_t internal_vld1q<2>(const float *in)
{
    const float32x4x2_t tmp = vld2q_f32(in);
    return tmp.val[0];
}

template <>
float32x4_t internal_vld1q<3>(const float *in)
{
    const float32x4x3_t tmp = vld3q_f32(in);
    return tmp.val[0];
}

inline float32x4_t internal_vdupq_n(float v)
{
    return vdupq_n_f32(v);
}

inline void internal_vst1q(float *p, const float32x4_t &v)
{
    vst1q_f32(p, v);
}

float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f32(x, y);
}

inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmlaq_f32(x, y, z);
}

template <unsigned int stridex>
qint8x8_t internal_vld1q(const qint8_t *in);

template <>
qint8x8_t internal_vld1q<1>(const qint8_t *in)
{
    return vld1_qs8(in);
}

template <>
qint8x8_t internal_vld1q<2>(const qint8_t *in)
{
    const qint8x8x2_t tmp = vld2_s8(in);
    return tmp.val[0];
}

template <>
qint8x8_t internal_vld1q<3>(const qint8_t *in)
{
    const qint8x8x3_t tmp = vld3_s8(in);
    return tmp.val[0];
}

inline qint8x8_t internal_vdupq_n(qint8_t v)
{
    return vdup_n_qs8(v);
}

inline qint16x8_t internal_vmull(const qint8x8_t &x, const qint8x8_t &y, int fixed_point_position)
{
    return vmull_qs8(x, y, fixed_point_position);
}

inline qint16x8_t internal_vmlal(const qint16x8_t &x, const qint8x8_t &y, const qint8x8_t &z, int fixed_point_position)
{
    return vqmlal_qs8(x, y, z, fixed_point_position);
}

inline void internal_vst1q(qint16_t *p, const qint16x8_t &v)
{
    vst1q_qs16(p, v);
}

inline void internal_vst1q(int32_t *p, const qint32x4x2_t &v)
{
    vst1q_s32(p, v.val[0]);
    vst1q_s32(p + 4, v.val[1]);
}

template <unsigned int stridex>
qint32x4x2_t internal_vld1q(const qint32_t *in);

template <>
qint32x4x2_t internal_vld1q<1>(const qint32_t *in)
{
    const qint32x4x2_t r =
    {
        {
            vld1q_s32(in),
            vld1q_s32(in + 4)
        }
    };
    return r;
}

inline qint32x4x2_t internal_vmull(const qint16x8_t &x, const qint16x8_t &y, int fixed_point_position)
{
    const qint32x4x2_t r =
    {
        {
            vmull_qs16(vget_low_s16(x), vget_low_s16(y), fixed_point_position),
            vmull_qs16(vget_high_s16(x), vget_high_s16(y), fixed_point_position),
        }
    };
    return r;
}

inline qint32x4x2_t internal_vmlal(const qint32x4x2_t &x, const qint16x8_t &y, const qint16x8_t &z, int fixed_point_position)
{
    const qint32x4x2_t r =
    {
        {
            vqmlal_qs16(x.val[0], vget_low_s16(y), vget_low_s16(z), fixed_point_position),
            vqmlal_qs16(x.val[1], vget_high_s16(y), vget_high_s16(z), fixed_point_position)
        }
    };
    return r;
}

constexpr int SmallTensorSizeOptim = 8;
inline bool run_optim_small_tensor(const ITensor *t)
{
    return t->info()->dimension(Window::DimX) <= SmallTensorSizeOptim && t->info()->dimension(Window::DimY) <= SmallTensorSizeOptim;
}

// Optimized convolver for 1x1 kernels, used only when the input width and height are both <= 8.
// For a big Z, as in Input=7x7x832, this implementation is faster than the general code because it doesn't need to
// store intermediate results in memory. Temporary results are stored in NEON registers directly and then written to the output buffer.
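// With width and height both <= 8, one output row fits in the two 128-bit
// registers accum0[oh] and accum1[oh] below, so the pair of 8-entry
// accumulator arrays can hold a whole 8x8 output plane while the inner
// kernel-depth loop runs.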
template <unsigned int stridex>
class convolver_w1x1_i8x8_f32
{
public:
    static void convolve(const Window &window, const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimX) > SmallTensorSizeOptim);
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimY) > SmallTensorSizeOptim);

        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_h = output->info()->dimension(1);
        const int range_z = window.z().end() - window.z().start();
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());

        // Set up the output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // Set up the input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr();
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            float32x4_t accum0[SmallTensorSizeOptim] = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            float32x4_t accum1[SmallTensorSizeOptim] = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            for(int oz = 0; oz < range_z; ++oz)
            {
                accum0[0] = accum0[1] = accum0[2] = accum0[3] = accum0[4] = accum0[5] = accum0[6] = accum0[7] = vdupq_n_f32(0.f);
                accum1[0] = accum1[1] = accum1[2] = accum1[3] = accum1[4] = accum1[5] = accum1[6] = accum1[7] = vdupq_n_f32(0.f);
                auto p_out_base = out_ptr + oz * output_stride_z;
                for(int p = 0; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const float *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk0 = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const float *>(input_ptr + p * input_stride_z + offset_xy);
                        auto v_in0 = internal_vld1q<stridex>(in_val);
                        auto v_in1 = internal_vld1q<stridex>(in_val + 4);
                        accum0[oh] = vmlaq_f32(accum0[oh], vk0, v_in0);
                        accum1[oh] = vmlaq_f32(accum1[oh], vk0, v_in1);
                    }
                }
                for(oh = 0; oh < output_h; ++oh)
                {
                    auto p_out = reinterpret_cast<float *>(p_out_base + oh * output_stride_y);
                    vst1q_f32(p_out, accum0[oh]);
                    vst1q_f32(p_out + 4, accum1[oh]);
                }
            }
        },
        in, out);
    }
};

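// Generic 1x1 convolver: a 1x1 kernel reduces to a scalar multiply-accumulate
// across the input planes. Step 1 initializes each output plane with
// plane 0 * weight 0 (internal_vmull); Step 2 accumulates the remaining
// kernel_depth - 1 planes on top of it (internal_vmlal).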
template <typename T1, typename T2, unsigned int stridex>
class convolver_1x1
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int range_z = window.z().end() - window.z().start();
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const int fixed_point_position = input->info()->fixed_point_position();

        // Set up the output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // Set up the input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            /*
                For a detailed explanation of how the algorithm works, refer to the convolver_3x3 class below.
            */
            const uint8_t *input_ptr = in.ptr();
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            for(int oz = 0; oz < range_z; ++oz)
            {
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }

                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }
            }
        },
        in, out);
    }
};

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

template <unsigned int stridex>
void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

template <>
void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
    vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
}

#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

template <unsigned int stridex>
float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                           const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position);

inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
{
    const float32x4x3_t m00 =
    {
        {
            vld1q_dup_f32(m0),
            vld1q_dup_f32(m1),
            vld1q_dup_f32(m2)
        }
    };
    return m00;
}

inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4)
{
    const float32x4x2_t m00 =
    {
        {
            vld1q_dup_f32(m3),
            vld1q_dup_f32(m4)
        }
    };
    return m00;
}

inline float32x4x3_t load_input(const float *const in)
{
    const float32x4x3_t vin =
    {
        {
            vld1q_f32(in),
            vld1q_f32(in + 4),
            vld1q_f32(in + 8)
        }
    };
    return vin;
}

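// convolve_5x5<1> produces 8 consecutive output elements per call. Each input
// row is fetched once as 12 consecutive floats (load_input) and the shifted
// 5-tap windows are built with vextq_f32, which concatenates two registers
// and extracts a vector starting at the given lane, e.g. with
// a = { a0, a1, a2, a3 } and b = { b0, b1, b2, b3 }:
//   vextq_f32(a, b, 1) == { a1, a2, a3, b0 }
// This avoids any additional unaligned loads for the shifted taps.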
template <>
inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    const float32x4x3_t vin0 = load_input(in_0);
    const float32x4x3_t vin1 = load_input(in_1);
    const float32x4x3_t vin2 = load_input(in_2);
    const float32x4x3_t vin3 = load_input(in_3);
    const float32x4x3_t vin4 = load_input(in_4);
    const float32x4x3_t m00 = load_matrix_hi(m0, 1 + m0, 2 + m0);
    const float32x4x2_t m01 = load_matrix_lo(3 + m0, 4 + m0);
    const float32x4x3_t m10 = load_matrix_hi(m1, 1 + m1, 2 + m1);
    const float32x4x2_t m11 = load_matrix_lo(3 + m1, 4 + m1);
    const float32x4x3_t m20 = load_matrix_hi(m2, 1 + m2, 2 + m2);
    const float32x4x2_t m21 = load_matrix_lo(3 + m2, 4 + m2);
    const float32x4x3_t m30 = load_matrix_hi(m3, 1 + m3, 2 + m3);
    const float32x4x2_t m31 = load_matrix_lo(3 + m3, 4 + m3);
    const float32x4x3_t m40 = load_matrix_hi(m4, 1 + m4, 2 + m4);
    const float32x4x2_t m41 = load_matrix_lo(3 + m4, 4 + m4);

    float32x4x2_t out =
    {
        {
            vmulq_f32(vin0.val[0], m00.val[0]),
            vmulq_f32(vin0.val[1], m00.val[0])
        }
    };

    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 3), m11.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], m11.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]);

    return out;
}

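// For strides 2 and 3 the stride-1 result is computed first and the lanes that
// survive the subsampling are compacted into out.val[0]: stride 2 keeps
// outputs { o0, o2, o4, o6 }, stride 3 keeps { o0, o3 } (only the low half of
// out.val[0] is stored in that case, see accumulate_results<3> below).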
template <>
inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}

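// The accumulate_results<stridex> overloads add partial sums to an
// already-initialized output buffer. Larger strides write fewer elements
// because the convolve_* helpers compact their results: 8, 4 and 2 floats
// (16, 8 and 4 values for the 16-bit types) for strides 1, 2 and 3.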
template <unsigned int stridex>
void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
    vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
}

template <>
void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}

template <unsigned int stridex>
void accumulate_results(qint16_t *buffer, const qint16x8x2_t &values);

template <>
void accumulate_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
    vst1q_qs16(buffer + 8, vqaddq_qs16(vld1q_qs16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1_qs16(buffer, vqadd_qs16(vld1_qs16(buffer), vget_low_s16(values.val[0])));
}

template <typename T1, typename T2, unsigned int stridex>
class convolver_3x3
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        const int input_stride_x = input->info()->strides_in_bytes().x();
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int num_planes_z = window.z().end() - window.z().start();
        const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_x = std::get<0>(conv_info.pad());
        const unsigned int conv_pad_y = std::get<1>(conv_info.pad());
        const int fixed_point_position = input->info()->fixed_point_position();

        // Set up the output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // Set up the input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_x * input_stride_x - conv_pad_y * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            /*
                Each thread executing this kernel computes one or more planes of the output volume.

                Say the 3rd dimension of the output volume is 32: the first thread will compute the output for Z = [0,7], the second thread for Z = [8,15],
                the third thread for Z = [16,23] and the fourth thread for Z = [24,31].

                The algorithm's outer loop iterates over Z, P, Y, X, where P is the depth/3rd dimension of each kernel. This order is not arbitrary: its main benefit
                is that we set up the NEON registers containing the kernel's values only once and then compute each XY using the preloaded registers, as opposed to reloading them for every XY value.

                The algorithm does not require allocating any additional memory and computes the results directly in-place in two stages:
                    1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
                    2) Convolve the remaining planes and accumulate the results into the output plane initialized in step 1.
            */
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset = id.z() + oz;
                uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto vk_r0 = load_matrix_row(ptr_k_r0);
                    const auto vk_r1 = load_matrix_row(ptr_k_r1);
                    const auto vk_r2 = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w;
                    const uint8_t *input_base = input_ptr + p * input_stride_z;
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(ptr_k_base);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y * 2);
                    const auto vk_r0 = load_matrix_row(ptr_k_r0);
                    const auto vk_r1 = load_matrix_row(ptr_k_r1);
                    const auto vk_r2 = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_base + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_base + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_base + (ih + 2) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};

template <typename T1, typename T2, unsigned int stridex>
class convolver_5x5
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        const int input_stride_x = input->info()->strides_in_bytes().x();
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int num_planes_z = window.z().end() - window.z().start();
        const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_x = std::get<0>(conv_info.pad());
        const unsigned int conv_pad_y = std::get<1>(conv_info.pad());
        const int fixed_point_position = input->info()->fixed_point_position();

        // Set up the output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // Set up the input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_x * input_stride_x - conv_pad_y * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset = id.z() + oz;
                uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);

                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};

template <typename T1, typename T2>
inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

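// F32 specialization: tensors whose X and Y dimensions both fit within
// SmallTensorSizeOptim are routed to the register-resident small-tensor
// convolver above; everything else falls back to the generic convolver_1x1.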
template <>
inline void convolve_1x1<float, float>(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                                       const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    if(run_optim_small_tensor(input))
    {
        switch(conv_stride_x)
        {
            case 1:
                convolver_w1x1_i8x8_f32<1>::convolve(window, input, weights, output, conv_info);
                break;
            case 2:
                convolver_w1x1_i8x8_f32<2>::convolve(window, input, weights, output, conv_info);
                break;
            case 3:
                convolver_w1x1_i8x8_f32<3>::convolve(window, input, weights, output, conv_info);
                break;
            default:
                ARM_COMPUTE_ERROR("Not implemented");
        }
    }
    else
    {
        switch(conv_stride_x)
        {
            case 1:
                convolver_1x1<float, float, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            case 2:
                convolver_1x1<float, float, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            case 3:
                convolver_1x1<float, float, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            default:
                ARM_COMPUTE_ERROR("Not implemented");
        }
    }
}

template <typename T1, typename T2>
inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <typename T1, typename T2>
inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_5x5<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_5x5<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_5x5<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

} // namespace

NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
    : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0),
      _num_elems_written_per_iteration(0)
{
}

BorderSize NEDirectConvolutionLayerKernel::border_size() const
{
    return _border_size;
}

void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::QS16, DataType::F32);
    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 1 && (std::get<0>(conv_info.pad()) || std::get<1>(conv_info.pad())),
                             "Pad > 0 not supported for 1x1 weights");
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 3 && (std::get<0>(conv_info.pad()) > 1 || std::get<1>(conv_info.pad()) > 1),
                             "Pad > 1 not supported for 3x3 weights");
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 5 && (std::get<0>(conv_info.pad()) > 2 || std::get<1>(conv_info.pad()) > 2),
                             "Pad > 2 not supported for 5x5 weights");

    ARM_COMPUTE_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
    ARM_COMPUTE_ERROR_ON(weights->info()->dimension(2) != input->info()->dimension(2));
    ARM_COMPUTE_ERROR_ON(weights->info()->dimension(0) != weights->info()->dimension(1));
    ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4);

    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    const unsigned int conv_pad_x = std::get<0>(conv_info.pad());
    const unsigned int conv_pad_y = std::get<1>(conv_info.pad());

    _input = input;
    _weights = weights;
    _output = output;
    _conv_info = conv_info;
    _kernel_size = weights->info()->dimension(0);
    _border_size = BorderSize(conv_pad_y, conv_pad_x);

    const unsigned int kernel_size = weights->info()->dimension(0);

    // Get convolved dimensions
    unsigned int output_width = 0;
    unsigned int output_height = 0;
    std::tie(output_width, output_height) = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1), kernel_size, kernel_size, conv_info);

    TensorShape output_shape = input->info()->tensor_shape();
    output_shape.set(0, output_width);
    output_shape.set(1, output_height);
    output_shape.set(2, weights->info()->dimension(3));

    DataType data_type = input->info()->data_type();

    if(is_data_type_fixed_point(data_type))
    {
        // Promote data type in case of fixed point
        data_type = ((data_type == DataType::QS8) ? DataType::QS16 : DataType::QS32);
    }
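    // The promotion mirrors the widening multiplies used by the convolvers
    // (internal_vmull/internal_vmlal): QS8 inputs accumulate into QS16 and
    // QS16 inputs into QS32, which is why the output data type differs from
    // the input's for fixed-point tensors.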

    // Auto-initialize the output if it has not been initialized yet
    auto_init_if_empty(*output->info(), output_shape, 1, data_type, input->info()->fixed_point_position());

    ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, output->info()->data_type());

    switch(_kernel_size)
    {
        case 1:
        {
            switch(input->info()->data_type())
            {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                case DataType::QS8:
                case DataType::QS16:
                    _num_elems_written_per_iteration = 8;
                    break;
                case DataType::F32:
                    if(run_optim_small_tensor(input))
                    {
                        _num_elems_written_per_iteration = 8;
                    }
                    else
                    {
                        _num_elems_written_per_iteration = 4;
                    }
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported.");
                    break;
            }
            _num_weight_elems_read_per_row = kernel_size;
            _num_elems_read_per_iteration = conv_stride_x * _num_elems_written_per_iteration;
            break;
        }
        case 3:
        case 5:
        {
            switch(input->info()->data_type())
            {
                case DataType::F32:
                    _num_weight_elems_read_per_row = 4 + _kernel_size - 1;
                    _num_elems_read_per_iteration = 12;
                    _num_elems_written_per_iteration = 16 >> conv_stride_x;
                    break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                case DataType::QS8:
                case DataType::QS16:
                    _num_weight_elems_read_per_row = 8 + _kernel_size - 1;
                    _num_elems_read_per_iteration = 24;
                    _num_elems_written_per_iteration = 32 >> conv_stride_x;
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported.");
                    break;
            }
        }
        break;
        default:
        {
            ARM_COMPUTE_ERROR("Not implemented");
            break;
        }
    }

    // Calculate right and bottom border
    const unsigned int conv_stride_y = std::get<1>(_conv_info.stride());
    const int input_width = input->info()->dimension(0);
    const int input_height = input->info()->dimension(1);
    const int upper_bound_w = ceil_to_multiple(((output->info()->dimension(0) - 1) * conv_stride_x + _kernel_size), _num_elems_read_per_iteration) - conv_pad_x - input_width;
    const int upper_bound_h = ((output->info()->dimension(1) - 1) * conv_stride_y - conv_pad_y + _kernel_size) - input_height;
    _border_size.right = std::max(upper_bound_w, static_cast<int>(_kernel_size));
    _border_size.bottom = std::max(upper_bound_h, static_cast<int>(_kernel_size));
    Window win = calculate_max_window(*output->info(), Steps(_num_elems_written_per_iteration));
    AccessWindowStatic input_access(input->info(), -conv_pad_x, -conv_pad_y, input_width + _border_size.right, input_height + _border_size.bottom);
    AccessWindowStatic weights_access(weights->info(), 0, 0, _num_weight_elems_read_per_row, _kernel_size);
    AccessWindowHorizontal output_access(output->info(), 0, _num_elems_written_per_iteration);
    update_window_and_padding(win, input_access, weights_access, output_access);
    output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));

    INEKernel::configure(win);
}

void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

    const int kernel_size = _weights->info()->dimension(0);

    switch(kernel_size)
    {
        case 1:
        {
            switch(_input->info()->data_type())
            {
                case DataType::QS8:
                    convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::QS16:
                    convolve_1x1<qint16_t, qint32_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::F32:
                    convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        case 3:
        {
            switch(_input->info()->data_type())
            {
                case DataType::QS8:
                    convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::F32:
                    convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                case DataType::F16:
                    convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        case 5:
        {
            switch(_input->info()->data_type())
            {
                case DataType::F32:
                    convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }

        default:
        {
            ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
            break;
        }
    }
}