blob: 5eafdf03631763208dfd618310ad10471b9350e2 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michalis Spyrou621965e2018-01-08 17:11:26 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"
Georgios Pinitas4074c992018-01-30 18:13:46 +000025#include "arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
27#include "arm_compute/core/AccessWindowStatic.h"
28#include "arm_compute/core/Error.h"
29#include "arm_compute/core/Helpers.h"
30#include "arm_compute/core/IAccessWindow.h"
31#include "arm_compute/core/ITensor.h"
32#include "arm_compute/core/NEON/NEFixedPoint.h"
33#include "arm_compute/core/Types.h"
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +010034#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010035#include "arm_compute/core/Validate.h"
Giorgio Arenac0f54432018-03-16 14:02:34 +000036#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010037
38#include <algorithm>
39#include <arm_neon.h>
40
41using namespace arm_compute;
Michalis Spyrou7362f0d2017-10-18 17:58:22 +010042using namespace arm_compute::detail;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010043
44namespace
45{
// Strided QS16 loads: each specialization returns a vector holding every
// stridex-th element starting at *in (strides 2 and 3 use the de-interleaving
// vld2q/vld3q loads and keep only the first de-interleaved lane group).
template <unsigned int stridex>
qint16x8_t internal_vld1q(const qint16_t *in);

template <>
qint16x8_t internal_vld1q<1>(const qint16_t *in)
{
    return vld1q_qs16(in);
}

template <>
qint16x8_t internal_vld1q<2>(const qint16_t *in)
{
    return vld2q_s16(in).val[0];
}

template <>
qint16x8_t internal_vld1q<3>(const qint16_t *in)
{
    return vld3q_s16(in).val[0];
}

// Broadcast a single QS16 scalar to all eight lanes.
inline qint16x8_t internal_vdupq_n(qint16_t v)
{
    return vdupq_n_qs16(v);
}
73
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
// FP16 variants of the strided-load / broadcast / store / multiply helpers.
// Only compiled when the target supports half-precision vector arithmetic.
template <unsigned int stridex>
float16x8_t internal_vld1q(const float16_t *in);

template <>
float16x8_t internal_vld1q<1>(const float16_t *in)
{
    return vld1q_f16(in);
}

template <>
float16x8_t internal_vld1q<2>(const float16_t *in)
{
    return vld2q_f16(in).val[0];
}

template <>
float16x8_t internal_vld1q<3>(const float16_t *in)
{
    return vld3q_f16(in).val[0];
}

// Broadcast a single half-precision scalar to all eight lanes.
inline float16x8_t internal_vdupq_n(float16_t v)
{
    return vdupq_n_f16(v);
}

// Store eight half-precision values to p.
inline void internal_vst1q(float16_t *p, const float16x8_t &v)
{
    vst1q_f16(p, v);
}

// Element-wise multiply; the fixed-point position is meaningless for floats.
float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f16(x, y);
}

// Multiply-accumulate: returns x + y * z (fixed_point_position ignored).
inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    const float16x8_t product = vmulq_f16(y, z);
    return vaddq_f16(x, product);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tello0d176142017-07-06 16:43:14 +0100120
// Strided F32 loads: each specialization returns every stridex-th float
// starting at *in (strides 2 and 3 use de-interleaving loads).
template <unsigned int stridex>
float32x4_t internal_vld1q(const float *in);

template <>
float32x4_t internal_vld1q<1>(const float *in)
{
    return vld1q_f32(in);
}

template <>
float32x4_t internal_vld1q<2>(const float *in)
{
    return vld2q_f32(in).val[0];
}

template <>
float32x4_t internal_vld1q<3>(const float *in)
{
    return vld3q_f32(in).val[0];
}

// Broadcast a single float to all four lanes.
inline float32x4_t internal_vdupq_n(float v)
{
    return vdupq_n_f32(v);
}

// Store four floats to p.
inline void internal_vst1q(float *p, const float32x4_t &v)
{
    vst1q_f32(p, v);
}

// Element-wise multiply; the fixed-point position is meaningless for floats.
float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f32(x, y);
}

// Multiply-accumulate: returns x + y * z (fixed_point_position ignored).
inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmlaq_f32(x, y, z);
}
165
// Strided QS8 loads (8-lane): each specialization returns every stridex-th
// byte starting at *in.
template <unsigned int stridex>
qint8x8_t internal_vld1q(const qint8_t *in);

template <>
qint8x8_t internal_vld1q<1>(const qint8_t *in)
{
    return vld1_qs8(in);
}

template <>
qint8x8_t internal_vld1q<2>(const qint8_t *in)
{
    return vld2_s8(in).val[0];
}

template <>
qint8x8_t internal_vld1q<3>(const qint8_t *in)
{
    return vld3_s8(in).val[0];
}

// Broadcast a single QS8 scalar to all eight lanes.
inline qint8x8_t internal_vdupq_n(qint8_t v)
{
    return vdup_n_qs8(v);
}

// Fixed-point widening multiply: QS8 x QS8 -> QS16 at the given Q position.
inline qint16x8_t internal_vmull(const qint8x8_t &x, const qint8x8_t &y, int fixed_point_position)
{
    return vmull_qs8(x, y, fixed_point_position);
}

// Saturating fixed-point multiply-accumulate: x + y * z (QS8 in, QS16 accumulator).
inline qint16x8_t internal_vmlal(const qint16x8_t &x, const qint8x8_t &y, const qint8x8_t &z, int fixed_point_position)
{
    return vqmlal_qs8(x, y, z, fixed_point_position);
}

// Store eight QS16 values to p.
inline void internal_vst1q(qint16_t *p, const qint16x8_t &v)
{
    vst1q_qs16(p, v);
}
208
// Store eight 32-bit accumulator values (two 4-lane halves) to p.
inline void internal_vst1q(int32_t *p, const qint32x4x2_t &v)
{
    vst1q_s32(p, v.val[0]);
    vst1q_s32(p + 4, v.val[1]);
}

template <unsigned int stridex>
qint32x4x2_t internal_vld1q(const qint32_t *in);

// Contiguous load of eight QS32 values as two 4-lane vectors.
template <>
qint32x4x2_t internal_vld1q<1>(const qint32_t *in)
{
    qint32x4x2_t loaded;
    loaded.val[0] = vld1q_s32(in);
    loaded.val[1] = vld1q_s32(in + 4);
    return loaded;
}

// Fixed-point widening multiply: QS16 x QS16 -> QS32, low and high halves.
inline qint32x4x2_t internal_vmull(const qint16x8_t &x, const qint16x8_t &y, int fixed_point_position)
{
    qint32x4x2_t product;
    product.val[0] = vmull_qs16(vget_low_s16(x), vget_low_s16(y), fixed_point_position);
    product.val[1] = vmull_qs16(vget_high_s16(x), vget_high_s16(y), fixed_point_position);
    return product;
}

// Saturating fixed-point widening multiply-accumulate: x + y * z
// (QS16 operands, QS32 two-half accumulator).
inline qint32x4x2_t internal_vmlal(const qint32x4x2_t &x, const qint16x8_t &y, const qint16x8_t &z, int fixed_point_position)
{
    qint32x4x2_t acc;
    acc.val[0] = vqmlal_qs16(x.val[0], vget_low_s16(y), vget_low_s16(z), fixed_point_position);
    acc.val[1] = vqmlal_qs16(x.val[1], vget_high_s16(y), vget_high_s16(z), fixed_point_position);
    return acc;
}
254
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000255constexpr int small_tensor_size_optim = 8;
256inline bool run_optim_small_tensor_info(const ITensorInfo *t)
257{
258 return t->dimension(Window::DimX) <= small_tensor_size_optim && t->dimension(Window::DimY) <= small_tensor_size_optim;
259}
260
Pablo Telloc09314a2017-09-21 13:59:14 +0100261inline bool run_optim_small_tensor(const ITensor *t)
262{
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000263 return run_optim_small_tensor_info(t->info());
Pablo Telloc09314a2017-09-21 13:59:14 +0100264}
265
// Optimized convolver for 1x1 kernels used only where input width and height are both <= 8
// For big Z as in Input=7x7x832, this implementation is faster than the general code because it doesn't need to
// store intermediate results in memory. Temporary results are stored in NEON registers directly and then written to the output buffer.
template <unsigned int stridex>
class convolver_w1x1_i8x8_f32
{
public:
    // Runs the full 1x1 F32 convolution over `window`. Each output row
    // (at most 8 floats wide) is held in two float32x4_t accumulators for the
    // entire kernel-depth reduction and written to memory exactly once.
    static void convolve(const Window &window, const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        // This path is only legal for small tensors (see run_optim_small_tensor_info()).
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimX) > small_tensor_size_optim);
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimY) > small_tensor_size_optim);

        // Byte strides and extents used for manual pointer arithmetic below.
        const int input_stride_x = input->info()->strides_in_bytes().x();
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_h = output->info()->dimension(1);
        const int range_z = window.z().end() - window.z().start();
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left = conv_info.pad_left();
        const unsigned int conv_pad_top = conv_info.pad_top();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            // Rewind the input pointer from the valid region back to the (padded) top-left origin.
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            // One accumulator pair per output row: accum0 covers columns 0-3, accum1 columns 4-7.
            float32x4_t accum0[small_tensor_size_optim] = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            float32x4_t accum1[small_tensor_size_optim] = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            for(int oz = 0; oz < range_z; ++oz)
            {
                // Reset the accumulators for each output feature map.
                accum0[0] = accum0[1] = accum0[2] = accum0[3] = accum0[4] = accum0[5] = accum0[6] = accum0[7] = vdupq_n_f32(0.f);
                accum1[0] = accum1[1] = accum1[2] = accum1[3] = accum1[4] = accum1[5] = accum1[6] = accum1[7] = vdupq_n_f32(0.f);
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Reduce over kernel depth: a 1x1 kernel contributes one scalar weight per input channel.
                for(int p = 0; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const float *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk0 = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const float *>(input_ptr + p * input_stride_z + offset_xy);
                        auto v_in0 = internal_vld1q<stridex>(in_val);
                        auto v_in1 = internal_vld1q<stridex>(in_val + 4);
                        accum0[oh] = vmlaq_f32(accum0[oh], vk0, v_in0);
                        accum1[oh] = vmlaq_f32(accum1[oh], vk0, v_in1);
                    }
                }
                // Depth reduction done: store each accumulated row to the output once.
                for(oh = 0; oh < output_h; ++oh)
                {
                    auto p_out = reinterpret_cast<float *>(p_out_base + oh * output_stride_y);
                    vst1q_f32(p_out, accum0[oh]);
                    vst1q_f32(p_out + 4, accum1[oh]);
                }
            }
        },
        in, out);
    }
};
350
// Generic 1x1 convolver. T1 is the input/weights element type, T2 the
// accumulator/output element type (e.g. float/float or qint8_t/qint16_t);
// stridex is the convolution stride along X.
template <typename T1, typename T2, unsigned int stridex>
class convolver_1x1
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        // Byte strides and extents used for manual pointer arithmetic below.
        const int input_stride_x = input->info()->strides_in_bytes().x();
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int range_z = window.z().end() - window.z().start();
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left = conv_info.pad_left();
        const unsigned int conv_pad_top = conv_info.pad_top();
        // Only meaningful for fixed-point types; the float helpers ignore it.
        const int fixed_point_position = input->info()->fixed_point_position();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            /*
            For a detailed explanation on how the algorithm works refer to template <> class convolver_3x3<1>
            */
            // Rewind the input pointer from the valid region back to the (padded) top-left origin.
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            for(int oz = 0; oz < range_z; ++oz)
            {
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Step 1: input channel 0 initialises the output (plain multiply, no accumulate).
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }

                // Step 2: the remaining channels read back the stored values and accumulate onto them.
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }
            }
        },
        in, out);
    }
};
443
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

// Add the FP16 partial sums in `values` onto the contents of `buffer`.
// The number of elements updated shrinks with the stride:
// stride 1 writes 16 halfwords, stride 2 writes 8, stride 3 writes 4.
template <unsigned int stridex>
void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

template <>
void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    const float16x8_t sum_lo = vaddq_f16(vld1q_f16(buffer), values.val[0]);
    const float16x8_t sum_hi = vaddq_f16(vld1q_f16(buffer + 8), values.val[1]);
    vst1q_f16(buffer, sum_lo);
    vst1q_f16(buffer + 8, sum_hi);
}

template <>
void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    const float16x8_t sum = vaddq_f16(vld1q_f16(buffer), values.val[0]);
    vst1q_f16(buffer, sum);
}

template <>
void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    const auto sum = vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0]));
    vst1_f16(buffer, sum);
}

#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tello0d176142017-07-06 16:43:14 +0100469
// 5x5 convolution producing two adjacent output vectors; specialized per
// stride below (in_0..in_4 are five input rows, m0..m4 five kernel rows).
template <unsigned int stridex>
float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                           const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position);

// Broadcast the first three taps of a kernel row, one vector per tap.
inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
{
    float32x4x3_t taps;
    taps.val[0] = vld1q_dup_f32(m0);
    taps.val[1] = vld1q_dup_f32(m1);
    taps.val[2] = vld1q_dup_f32(m2);
    return taps;
}

// Broadcast the last two taps of a kernel row, one vector per tap.
inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4)
{
    float32x4x2_t taps;
    taps.val[0] = vld1q_dup_f32(m3);
    taps.val[1] = vld1q_dup_f32(m4);
    return taps;
}

// Load twelve consecutive input floats as three 4-lane vectors.
inline float32x4x3_t load_input(const float *const in)
{
    float32x4x3_t vin;
    vin.val[0] = vld1q_f32(in);
    vin.val[1] = vld1q_f32(in + 4);
    vin.val[2] = vld1q_f32(in + 8);
    return vin;
}
// Stride-1 5x5 convolution: reads 12 consecutive floats from each of the five
// input rows (in_0..in_4) and the 25 scalar taps (m0..m4, five taps per row),
// producing 8 consecutive output values as two float32x4_t. vextq_f32 shifts
// the input window by 1..3 lanes so each tap multiplies the correctly offset
// input; fixed_point_position is unused for floats.
template <>
inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    const float32x4x3_t vin0 = load_input(in_0);
    const float32x4x3_t vin1 = load_input(in_1);
    const float32x4x3_t vin2 = load_input(in_2);
    const float32x4x3_t vin3 = load_input(in_3);
    const float32x4x3_t vin4 = load_input(in_4);
    // Per kernel row r: mr0 holds broadcast taps 0-2, mr1 holds taps 3-4.
    const float32x4x3_t m00 = load_matrix_hi(m0, 1 + m0, 2 + m0);
    const float32x4x2_t m01 = load_matrix_lo(3 + m0, 4 + m0);
    const float32x4x3_t m10 = load_matrix_hi(m1, 1 + m1, 2 + m1);
    const float32x4x2_t m11 = load_matrix_lo(3 + m1, 4 + m1);
    const float32x4x3_t m20 = load_matrix_hi(m2, 1 + m2, 2 + m2);
    const float32x4x2_t m21 = load_matrix_lo(3 + m2, 4 + m2);
    const float32x4x3_t m30 = load_matrix_hi(m3, 1 + m3, 2 + m3);
    const float32x4x2_t m31 = load_matrix_lo(3 + m3, 4 + m3);
    const float32x4x3_t m40 = load_matrix_hi(m4, 1 + m4, 2 + m4);
    const float32x4x2_t m41 = load_matrix_lo(3 + m4, 4 + m4);

    // Initialise both output vectors with row 0, tap 0.
    float32x4x2_t out =
    {
        {
            vmulq_f32(vin0.val[0], m00.val[0]),
            vmulq_f32(vin0.val[1], m00.val[0])
        }
    };

    // First output vector (outputs 0-3): remaining taps of row 0.
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]);

    // Row 1 contribution.
    out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 3), m11.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], m11.val[1]);

    // Row 2 contribution.
    out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]);

    // Row 3 contribution.
    out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]);

    // Row 4 contribution.
    out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]);

    // Second output vector (outputs 4-7): same pattern shifted one vector over
    // (row 0 tap 0 was already applied in the initialisation above).
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]);

    return out;
}
// Stride-2 variant: computes the stride-1 result, then compacts the
// even-indexed outputs (lanes 0,2 of val[0] and lanes 0,2 of val[1]) into
// val[0]. The lane moves must run in this order: lane 2 of val[0] is read
// before it is overwritten.
template <>
inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}
613
// Stride-3 variant: derives from the stride-1 result and keeps outputs 0 and
// 3 (lane 3 of the first vector is moved into lane 1).
template <>
inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    float32x4x2_t result = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
    result.val[0] = vsetq_lane_f32(vgetq_lane_f32(result.val[0], 3), result.val[0], 1);
    return result;
}
622
// Add the F32 partial sums in `values` onto the contents of `buffer`.
// The number of elements updated shrinks with the stride:
// stride 1 writes 8 floats, stride 2 writes 4, stride 3 writes 2.
template <unsigned int stridex>
void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    const float32x4_t sum_lo = vaddq_f32(vld1q_f32(buffer), values.val[0]);
    const float32x4_t sum_hi = vaddq_f32(vld1q_f32(buffer + 4), values.val[1]);
    vst1q_f32(buffer, sum_lo);
    vst1q_f32(buffer + 4, sum_hi);
}

template <>
void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    const float32x4_t sum = vaddq_f32(vld1q_f32(buffer), values.val[0]);
    vst1q_f32(buffer, sum);
}

template <>
void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    const float32x2_t sum = vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0]));
    vst1_f32(buffer, sum);
}
644
// Saturating accumulation of QS16 partial sums onto `buffer`.
// The number of elements updated shrinks with the stride:
// stride 1 writes 16 values, stride 2 writes 8, stride 3 writes 4.
template <unsigned int stridex>
void accumulate_results(qint16_t *buffer, const qint16x8x2_t &values);

template <>
void accumulate_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
{
    const qint16x8_t sum_lo = vqaddq_qs16(vld1q_qs16(buffer), values.val[0]);
    const qint16x8_t sum_hi = vqaddq_qs16(vld1q_qs16(buffer + 8), values.val[1]);
    vst1q_qs16(buffer, sum_lo);
    vst1q_qs16(buffer + 8, sum_hi);
}

template <>
void accumulate_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
{
    const qint16x8_t sum = vqaddq_qs16(vld1q_qs16(buffer), values.val[0]);
    vst1q_qs16(buffer, sum);
}

template <>
void accumulate_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
{
    const auto sum = vqadd_qs16(vld1_qs16(buffer), vget_low_s16(values.val[0]));
    vst1_qs16(buffer, sum);
}
666
// NHWC-layout convolver: with this layout dimension 0 is the channel axis,
// so the innermost loop vectorises along channels. T1 is the element type
// of input, weights and output.
template <typename T1>
class convolver_nhwc
{
public:
    static void convolve(const Window &window, int kernel_size, unsigned int num_elems_read_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        // Extents and byte strides used for manual pointer arithmetic below.
        const int input_width = input->info()->dimension(0);
        const int input_depth = input->info()->dimension(2);
        const int input_stride_x = input->info()->strides_in_bytes().x();
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_x = output->info()->strides_in_bytes().x();
        const int kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int conv_pad_top = conv_info.pad_top();
        const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const T1 zero = 0;

        // Setup input window for the input iterator
        Window window_in = window;
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        // Setup input window for the output iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, 1, 1));

        // Setup input window for the weights iterator: iterate only over
        // dimension 3 (one iteration per output feature map / kernel).
        Window window_k = calculate_max_window(*weights->info(), Steps());
        window_k.set(Window::DimX, Window::Dimension(0, 1, 1));
        window_k.set(Window::DimY, Window::Dimension(0, 1, 1));
        window_k.set(Window::DimZ, Window::Dimension(0, 1, 1));
        window_k.set(3, Window::Dimension(0, weights->info()->dimension(3), 1));

        Iterator in(input, window_in);
        Iterator out(output, window_out);
        Iterator k(weights, window_k);

        execute_window_loop(window_k, [&](const Coordinates & id_k)
        {
            execute_window_loop(window_out, [&](const Coordinates & id)
            {
                // Top-left input coordinate of this output point (may be negative in the pad region).
                // NOTE(review): id.y()/id.z() are paired with conv_stride_x/conv_stride_y here, i.e.
                // dim1 = spatial X and dim2 = spatial Y in NHWC — confirm against the layer's window setup.
                const auto in_y = static_cast<int>(id.y() * conv_stride_x - conv_info.pad_left());
                const auto in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top);

                const uint8_t *in_ptr = in.ptr() + in_y * input_stride_y + in_z * input_stride_z;
                uint8_t *out_ptr = out.ptr() + id_k[3] * output_stride_x;

                // Scalar accumulator for one output element.
                T1 out_val = 0;

                auto in_addr_base0 = in_ptr;
                auto we_addr_base0 = k.ptr();

                // Walk the kernel window row by row (z) and column by column (y).
                for(int z = 0; z < kernel_size; ++z, in_addr_base0 += input_stride_z, we_addr_base0 += kernel_stride_z)
                {
                    // Shadows the outer in_z on purpose: absolute input row for this kernel tap.
                    const int in_z = id.z() * conv_stride_y + z - conv_pad_top;

                    if(in_z >= 0 && in_z < input_depth) // If false, pad top/bottom
                    {
                        auto in_addr_base1 = in_addr_base0;
                        auto we_addr_base1 = we_addr_base0;

                        for(int y = 0; y < kernel_size; ++y, in_addr_base1 += input_stride_y, we_addr_base1 += kernel_stride_y)
                        {
                            // Vector accumulator for the channel reduction of this tap.
                            auto out_values = internal_vdupq_n(zero);

                            int x = 0;
                            int no_leftover = input_width - num_elems_read_per_iteration;

                            // Vectorised channel loop: multiply-accumulate input
                            // against weights, num_elems_read_per_iteration channels at a time.
                            for(; x < no_leftover; x += num_elems_read_per_iteration)
                            {
                                const auto in_addr = reinterpret_cast<const T1 *>(in_addr_base1 + x * input_stride_x);
                                const auto in_values = internal_vld1q<1>(in_addr);

                                const auto we_addr = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
                                const auto we_values = internal_vld1q<1>(we_addr);

                                out_values = internal_vmlal(out_values, in_values, we_values, 0);
                            }

                            // Horizontal reduction of the vector accumulator
                            // (assumes a 4-lane vector, as for F32).
                            out_val += out_values[0];
                            out_val += out_values[1];
                            out_val += out_values[2];
                            out_val += out_values[3];

                            // Leftover
                            for(; x < input_width; ++x)
                            {
                                const auto in_addr = reinterpret_cast<const T1 *>(in_addr_base1 + x * input_stride_x);
                                const auto in_value = *(in_addr);

                                const auto we_addr = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
                                const auto we_value = *(we_addr);

                                out_val += in_value * we_value;
                            }
                        }
                    }
                }

                // One scalar output per (output point, kernel) pair.
                *(reinterpret_cast<T1 *>(out_ptr)) = out_val;
            },
            in, out);
        },
        k);
    }
};
778
// Direct convolution with a fixed 3x3 kernel window, templated on the input element
// type T1, the accumulator/output element type T2 and the horizontal convolution
// stride (compile-time constant so the inner loops can be specialised per stride).
template <typename T1, typename T2, unsigned int stridex>
class convolver_3x3
{
public:
    /** Compute the output planes selected by @p window.
     *
     * @param window                          Execution window; its Z range selects which output planes this call computes.
     * @param num_elems_read_per_iteration    Unused here (kept so every convolver shares the same call signature).
     * @param num_elems_written_per_iteration Number of output elements produced per inner-loop iteration.
     * @param input                           Input tensor. Assumed NCHW with border padding already in place — the code
     *                                        reads rows at negative offsets (pad_left/pad_top) without bounds checks.
     * @param weights                         Weights tensor; the 4th dimension indexes the per-output-map 3x3 x depth kernels.
     * @param output                          Output tensor, written in place.
     * @param conv_info                       Convolution stride/padding information.
     */
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        // Byte strides of all tensors involved: every pointer offset below is computed in bytes.
        const int          input_stride_x  = input->info()->strides_in_bytes().x();
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_y = output->info()->strides_in_bytes().y();
        const int          output_stride_z = output->info()->strides_in_bytes().z();
        const int          kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int          kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3]; // stride between whole kernels (output maps)
        const int          output_w        = output->info()->dimension(0);
        const int          output_h        = output->info()->dimension(1);
        const int          num_planes_z    = window.z().end() - window.z().start();
        // Input elements consumed per iteration; depends on the compile-time stride.
        const int          delta_input     = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();
        const int          fixed_point_position = input->info()->fixed_point_position();

        // setup output window for the iterator: collapse X/Y to a single step and
        // advance Z by the whole chunk of planes this call owns.
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        // Weights are constant across the loop; keep a raw base pointer.
        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            // Rewind to the top-left padded corner of the receptive field.
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            /*
                    Each thread executing this kernel computes one or more output's volume planes.

                    Let's say the 3rd dimension of the output volume is 32, the first thread will compute the output for Z = [0,7], the second thread will compute the output for Z = [8,15],
                    the third thread [16,23] and the fourth thread [24,31].

                    The algorithm outer loop iterates over Z, P, Y, X where P is the depth/3rd dimension of each kernel. This order is not arbitrary, the main benefit of this
                    is that we setup the neon registers containing the kernel's values only once and then compute each XY using the preloaded registers as opposed as doing this for every XY value.

                    The algorithm does not require allocating any additional memory and computes the results directly in-place in two stages:
                        1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
                        2) Convolve the remaining planes and accumulate the results in the output's plane which has been initialized in step 1.
            */
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset    = id.z() + oz; // which output map / kernel to use
                uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1: input plane 0 initialises the output plane (store, no accumulate).
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    // Preload the three kernel rows into registers once for the whole XY sweep.
                    const auto vk_r0    = load_matrix_row(ptr_k_r0);
                    const auto vk_r1    = load_matrix_row(ptr_k_r1);
                    const auto vk_r2    = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2: remaining input planes accumulate into the initialised output plane.
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w;
                    const uint8_t *input_base = input_ptr + p * input_stride_z;
                    const auto     ptr_k_r0   = reinterpret_cast<const T1 *>(ptr_k_base);
                    const auto     ptr_k_r1   = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y);
                    const auto     ptr_k_r2   = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y * 2);
                    const auto     vk_r0      = load_matrix_row(ptr_k_r0);
                    const auto     vk_r1      = load_matrix_row(ptr_k_r1);
                    const auto     vk_r2      = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_base + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_base + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_base + (ih + 2) * input_stride_y);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};
902
// Direct convolution with a fixed 5x5 kernel window. Same two-stage in-place scheme
// as convolver_3x3 (plane 0 stores, remaining planes accumulate), but the kernel-row
// pointers are passed straight to convolve_5x5 instead of being preloaded here.
template <typename T1, typename T2, unsigned int stridex>
class convolver_5x5
{
public:
    /** Compute the output planes selected by @p window.
     *
     * @param window                          Execution window; its Z range selects which output planes this call computes.
     * @param num_elems_read_per_iteration    Unused here (kept so every convolver shares the same call signature).
     * @param num_elems_written_per_iteration Number of output elements produced per inner-loop iteration.
     * @param input                           Input tensor. Assumed NCHW with border padding already in place.
     * @param weights                         Weights tensor; the 4th dimension indexes the per-output-map 5x5 x depth kernels.
     * @param output                          Output tensor, written in place.
     * @param conv_info                       Convolution stride/padding information.
     */
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        // Byte strides of all tensors involved: every pointer offset below is computed in bytes.
        const int          input_stride_x  = input->info()->strides_in_bytes().x();
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_y = output->info()->strides_in_bytes().y();
        const int          output_stride_z = output->info()->strides_in_bytes().z();
        const int          kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int          kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3]; // stride between whole kernels (output maps)
        const int          output_w        = output->info()->dimension(0);
        const int          output_h        = output->info()->dimension(1);
        const int          num_planes_z    = window.z().end() - window.z().start();
        // Input elements consumed per iteration; depends on the compile-time stride.
        const int          delta_input     = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();
        const int          fixed_point_position = input->info()->fixed_point_position();

        // setup output window for the iterator: collapse X/Y to a single step and
        // advance Z by the whole chunk of planes this call owns.
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        // Weights are constant across the loop; keep a raw base pointer.
        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            // Rewind to the top-left padded corner of the receptive field.
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset    = id.z() + oz; // which output map / kernel to use
                uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1: input plane 0 initialises the output plane (store, no accumulate).
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        // Five consecutive input rows feeding this output row.
                        auto in_0  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2: remaining input planes accumulate into the initialised output plane.
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);

                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};
1014
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001015template <typename T1, typename T2>
1016inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
1017 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
1018{
1019 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
1020 switch(conv_stride_x)
1021 {
1022 case 1:
1023 convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1024 break;
1025 case 2:
1026 convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1027 break;
1028 case 3:
1029 convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1030 break;
1031 default:
1032 ARM_COMPUTE_ERROR("Not implemented");
1033 }
1034}
1035
Pablo Telloc09314a2017-09-21 13:59:14 +01001036template <>
1037inline void convolve_1x1<float, float>(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
1038 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
1039{
1040 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
1041 if(run_optim_small_tensor(input))
1042 {
1043 switch(conv_stride_x)
1044 {
1045 case 1:
1046 convolver_w1x1_i8x8_f32<1>::convolve(window, input, weights, output, conv_info);
1047 break;
1048 case 2:
1049 convolver_w1x1_i8x8_f32<2>::convolve(window, input, weights, output, conv_info);
1050 break;
1051 case 3:
1052 convolver_w1x1_i8x8_f32<3>::convolve(window, input, weights, output, conv_info);
1053 break;
1054 default:
1055 ARM_COMPUTE_ERROR("Not implemented");
1056 }
1057 }
1058 else
1059 {
1060 switch(conv_stride_x)
1061 {
1062 case 1:
1063 convolver_1x1<float, float, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1064 break;
1065 case 2:
1066 convolver_1x1<float, float, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1067 break;
1068 case 3:
1069 convolver_1x1<float, float, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1070 break;
1071 default:
1072 ARM_COMPUTE_ERROR("Not implemented");
1073 }
1074 }
1075}
1076
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001077template <typename T1, typename T2>
1078inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
1079 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
1080{
1081 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
1082 switch(conv_stride_x)
1083 {
1084 case 1:
1085 convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1086 break;
1087 case 2:
1088 convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1089 break;
1090 case 3:
1091 convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1092 break;
1093 default:
1094 ARM_COMPUTE_ERROR("Not implemented");
1095 }
1096}
Pablo Tello06da39d2017-08-10 15:10:40 +01001097
1098template <typename T1, typename T2>
1099inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
1100 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
1101{
1102 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
1103 switch(conv_stride_x)
1104 {
1105 case 1:
1106 convolver_5x5<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1107 break;
1108 case 2:
1109 convolver_5x5<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1110 break;
1111 case 3:
1112 convolver_5x5<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1113 break;
1114 default:
1115 ARM_COMPUTE_ERROR("Not implemented");
1116 }
1117}
1118
/** Validate input/weights/output descriptors and the convolution parameters.
 *
 * Each ARM_COMPUTE_RETURN_* macro returns early with an error Status, so the
 * order of the checks below determines which error is reported first — keep it.
 *
 * @return An empty Status on success, an error Status otherwise.
 */
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
    ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);

    // Resolve the layout-dependent indices of the W/H/C dimensions.
    const DataLayout data_layout = input->data_layout();
    const int        width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
    const int        height_idx  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
    const int        channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(channel_idx) != input->dimension(channel_idx));
    // Only square kernels are supported.
    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != weights->dimension(height_idx));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
    // NHWC path is implemented for F32 only.
    ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && input->data_type() != DataType::F32);

    // Checks performed when output is configured
    if(output->total_size() != 0)
    {
        TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info);

        DataType data_type = input->data_type();
        if(is_data_type_fixed_point(data_type))
        {
            // Promote data type in case of fixed point (QS8 accumulates in QS16, QS16 in QS32).
            data_type = ((data_type == DataType::QS8) ? DataType::QS16 : DataType::QS32);
        }

        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
        ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != data_type);
    }

    return Status{};
}
1155
/** Compute the execution window and border requirements for the kernel.
 *
 * Fills the out-parameters (elements read/written per iteration, weight elements
 * per row and @p border_size) as a side effect; callers rely on these values.
 *
 * @return {Status, Window}; the Status is an error if the tensors do not carry
 *         enough padding for the access patterns requested here.
 */
std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int &num_weight_elems_read_per_row,
                                                        unsigned int &num_elems_read_per_iteration, unsigned int &num_elems_written_per_iteration, BorderSize &border_size)
{
    ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);

    const DataLayout data_layout = input->data_layout();
    const int        width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);

    // Calculate right and bottom border
    unsigned int kernel_size   = weights->dimension(width_idx); // kernels are square, so width == height
    const int    conv_stride_x = std::get<0>(conv_info.stride());
    const int    conv_stride_y = std::get<1>(conv_info.stride());
    const int    input_width   = input->dimension(width_idx);

    Window win{};
    bool   window_changed = false;

    if(data_layout == DataLayout::NCHW)
    {
        // Pick vector widths per kernel size and data type; these must match the
        // convolver implementations above.
        switch(kernel_size)
        {
            case 1:
            {
                switch(input->data_type())
                {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    case DataType::F16:
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                    case DataType::QS8:
                    case DataType::QS16:
                        num_elems_written_per_iteration = 8;
                        break;
                    case DataType::F32:
                        // The small-tensor optimised path writes wider vectors.
                        if(run_optim_small_tensor_info(input))
                        {
                            num_elems_written_per_iteration = 8;
                        }
                        else
                        {
                            num_elems_written_per_iteration = 4;
                        }
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported.");
                        break;
                }
                num_weight_elems_read_per_row = kernel_size;
                num_elems_read_per_iteration  = conv_stride_x * num_elems_written_per_iteration;
                break;
            }
            case 3:
            case 5:
            {
                switch(input->data_type())
                {
                    case DataType::F32:
                        num_weight_elems_read_per_row   = 4 + kernel_size - 1;
                        num_elems_read_per_iteration    = 12;
                        num_elems_written_per_iteration = 16 >> conv_stride_x;
                        break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    case DataType::F16:
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                    case DataType::QS8:
                    case DataType::QS16:
                        num_weight_elems_read_per_row   = 8 + kernel_size - 1;
                        num_elems_read_per_iteration    = 24;
                        num_elems_written_per_iteration = 32 >> conv_stride_x;
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported.");
                        break;
                }
            }
            break;
            default:
            {
                ARM_COMPUTE_ERROR("Not implemented");
                break;
            }
        }

        // Calculate right pad: how far the vectorised reads can overrun the input row.
        int start_x       = kernel_size / 2 - static_cast<int>(conv_info.pad_left());
        int end_x         = ceil_to_multiple(static_cast<int>(output->dimension(0)), num_elems_written_per_iteration) * conv_stride_x;
        int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width;

        // Calculate border
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();
        const unsigned int conv_pad_right  = std::max(upper_bound_w, 0);
        const unsigned int conv_pad_bottom = conv_info.pad_bottom();

        border_size.left   = conv_pad_left;
        border_size.top    = conv_pad_top;
        border_size.right  = conv_pad_right;
        border_size.bottom = conv_pad_bottom;

        // Configure window
        win = calculate_max_window(*output, Steps(num_elems_written_per_iteration));

        AccessWindowRectangle input_access(input, -conv_pad_left, -conv_pad_top,
                                           num_elems_read_per_iteration, kernel_size,
                                           conv_stride_x, conv_stride_y);
        AccessWindowStatic     weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size);
        AccessWindowHorizontal output_access(output, 0, num_elems_written_per_iteration);
        window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
        output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
    }
    else
    {
        // NHWC: dimension 0 is channels, so the image width runs along the window's
        // Y dimension — its padding therefore maps to the top/bottom borders.
        border_size.left   = 0;
        border_size.top    = conv_info.pad_left();
        border_size.right  = 0;
        border_size.bottom = conv_info.pad_right();

        // One full 16-byte NEON register worth of elements per read.
        num_elems_read_per_iteration = 16 / element_size_from_data_type(input->data_type());

        win = calculate_max_window(*output, Steps());

        AccessWindowRectangle input_access(input, 0, -border_size.top, num_elems_read_per_iteration, kernel_size, 1.f, conv_stride_x);
        AccessWindowRectangle weights_access(weights, 0, 0, num_elems_read_per_iteration, kernel_size);
        window_changed = update_window_and_padding(win, input_access, weights_access);
    }

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001284} // namespace
1285
// Default constructor: all tensor pointers null and all cached window/border
// parameters zeroed; the real setup happens in configure().
NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
    : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0),
      _num_elems_written_per_iteration(0)
{
}
1291
// Border (padding) the kernel requires around the input, as computed by configure().
BorderSize NEDirectConvolutionLayerKernel::border_size() const
{
    return _border_size;
}
1296
/** Set up the kernel: cache the tensors, auto-initialise the output if needed,
 * validate all arguments and configure the execution window.
 *
 * @param input     Input tensor (not nullptr).
 * @param weights   Weights tensor (not nullptr).
 * @param output    Output tensor; auto-initialised here if its info is empty.
 * @param conv_info Convolution stride/padding information.
 */
void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);

    _input       = input;
    _weights     = weights;
    _output      = output;
    _conv_info   = conv_info;
    // Kernel is square; the width dimension index depends on the data layout.
    _kernel_size = weights->info()->dimension(get_data_layout_dimension_index(weights->info()->data_layout(), DataLayoutDimension::WIDTH));

    // Provisional border from the user-provided padding; validate_and_configure_window()
    // below may enlarge the right border for vectorised over-reads.
    const unsigned int conv_pad_left   = conv_info.pad_left();
    const unsigned int conv_pad_top    = conv_info.pad_top();
    const unsigned int conv_pad_right  = conv_info.pad_right();
    const unsigned int conv_pad_bottom = conv_info.pad_bottom();
    _border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);

    // Get convolved dimensions
    TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input->info(), *weights->info(), conv_info);

    DataType data_type = input->info()->data_type();

    if(is_data_type_fixed_point(data_type))
    {
        // Promote data type in case of fixed point (QS8 -> QS16, QS16 -> QS32)
        data_type = ((data_type == DataType::QS8) ? DataType::QS16 : DataType::QS32);
    }

    // Output auto initialization if not yet initialized (must precede validation,
    // which checks the output descriptor).
    auto_init_if_empty(*output->info(), output_shape, 1, data_type, input->info()->fixed_point_position());

    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), output->info(), conv_info));

    // Configure kernel window (also finalises _border_size and the per-iteration counts).
    auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, _num_weight_elems_read_per_row,
                                                    _num_elems_read_per_iteration, _num_elems_written_per_iteration, _border_size);
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
    INEKernel::configure(win_config.second);
}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001336
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001337Status NEDirectConvolutionLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
1338{
1339 unsigned int num_weight_elems_read_per_row = 0;
1340 unsigned int num_elems_read_per_iteration = 0;
1341 unsigned int num_elems_written_per_iteration = 0;
Georgios Pinitas15997872018-02-19 13:58:22 +00001342 BorderSize border_size = {};
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001343 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, weights, output, conv_info));
Georgios Pinitas0223a782017-12-12 11:44:44 +00001344 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
1345 weights->clone().get(),
1346 output->clone().get(),
1347 conv_info,
1348 num_weight_elems_read_per_row,
1349 num_elems_read_per_iteration,
1350 num_elems_written_per_iteration,
1351 border_size)
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001352 .first);
Georgios Pinitas898a8062017-09-12 19:19:12 +01001353
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001354 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001355}
1356
// Execute the direct convolution on the given window slice.
// Dispatches on data layout, then kernel width, then data type:
//  - NCHW: specialised convolve_1x1 / convolve_3x3 / convolve_5x5 paths;
//  - otherwise (NHWC): a single generic convolver_nhwc path (F32 only).
// Unsupported combinations abort via ARM_COMPUTE_ERROR.
void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

    // Kernel spatial width, read layout-agnostically from the weights tensor
    const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_weights->info()->data_layout(), DataLayoutDimension::WIDTH));

    if(_input->info()->data_layout() == DataLayout::NCHW)
    {
        switch(kernel_size)
        {
            case 1:
            {
                // 1x1: supports QS8 (QS16 accumulator), QS16 (QS32 accumulator), F32,
                // and F16 when the target provides FP16 vector arithmetic.
                switch(_input->info()->data_type())
                {
                    case DataType::QS8:
                        convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                        break;
                    case DataType::QS16:
                        convolve_1x1<qint16_t, qint32_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                        break;
                    case DataType::F32:
                        convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                        break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    case DataType::F16:
                        convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                        break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported");
                        break;
                }
                break;
            }
            case 3:
            {
                // 3x3: QS8, F32, and (when available) F16. Note QS16 is not
                // handled by this path, unlike the 1x1 case above.
                switch(_input->info()->data_type())
                {
                    case DataType::QS8:
                        convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                        break;
                    case DataType::F32:
                        convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                        break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    case DataType::F16:
                        convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                        break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported");
                        break;
                }
                break;
            }
            case 5:
            {
                // 5x5: F32 only
                switch(_input->info()->data_type())
                {
                    case DataType::F32:
                        convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported");
                        break;
                }
                break;
            }

            default:
            {
                ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
                break;
            }
        }
    }
    else
    {
        // Non-NCHW layout (NHWC): generic convolver parameterised on kernel_size
        switch(_input->info()->data_type())
        {
            case DataType::F32:
                convolver_nhwc<float>::convolve(window, kernel_size, _num_elems_read_per_iteration, _input, _weights, _output, _conv_info);
                break;
            default:
                ARM_COMPUTE_ERROR("Data type not supported");
                break;
        }
    }
}