Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2017-2020 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"
Georgios Pinitas4074c992018-01-30 18:13:46 +000025#include "arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
27#include "arm_compute/core/AccessWindowStatic.h"
Anthony Barbiereaefd002018-07-20 17:49:35 +010028#include "arm_compute/core/CPP/Validate.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/Error.h"
30#include "arm_compute/core/Helpers.h"
31#include "arm_compute/core/IAccessWindow.h"
32#include "arm_compute/core/ITensor.h"
33#include "arm_compute/core/NEON/NEFixedPoint.h"
34#include "arm_compute/core/Types.h"
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +010035#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010036#include "arm_compute/core/Validate.h"
Giorgio Arenac0f54432018-03-16 14:02:34 +000037#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010038
Michalis Spyrou201c37c2018-10-25 17:25:54 +010039#include "arm_compute/core/NEON/wrapper/wrapper.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010040#include <algorithm>
41#include <arm_neon.h>
42
43using namespace arm_compute;
Michalis Spyrou7362f0d2017-10-18 17:58:22 +010044using namespace arm_compute::detail;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010045
46namespace
47{
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +000048#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tello0d176142017-07-06 16:43:14 +010049template <unsigned int stridex>
50float16x8_t internal_vld1q(const float16_t *in);
51
52template <>
53float16x8_t internal_vld1q<1>(const float16_t *in)
54{
55 return vld1q_f16(in);
56}
57
58template <>
59float16x8_t internal_vld1q<2>(const float16_t *in)
60{
61 const float16x8x2_t tmp = vld2q_f16(in);
62 return tmp.val[0];
63}
64
65template <>
66float16x8_t internal_vld1q<3>(const float16_t *in)
67{
68 const float16x8x3_t tmp = vld3q_f16(in);
69 return tmp.val[0];
70}
71
72inline float16x8_t internal_vdupq_n(float16_t v)
73{
74 return vdupq_n_f16(v);
75}
76
77inline void internal_vst1q(float16_t *p, const float16x8_t &v)
78{
79 vst1q_f16(p, v);
80}
81
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010082float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y)
Pablo Tello0d176142017-07-06 16:43:14 +010083{
Pablo Tello0d176142017-07-06 16:43:14 +010084 return vmulq_f16(x, y);
85}
86
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +010087inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z)
Pablo Tello0d176142017-07-06 16:43:14 +010088{
Pablo Tello0d176142017-07-06 16:43:14 +010089 return vaddq_f16(x, vmulq_f16(y, z));
90}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +000091#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tello0d176142017-07-06 16:43:14 +010092
Anthony Barbier6ff3b192017-09-04 18:44:23 +010093template <unsigned int stridex>
94float32x4_t internal_vld1q(const float *in);
95
96template <>
97float32x4_t internal_vld1q<1>(const float *in)
98{
99 return vld1q_f32(in);
100}
101
102template <>
103float32x4_t internal_vld1q<2>(const float *in)
104{
105 const float32x4x2_t tmp = vld2q_f32(in);
106 return tmp.val[0];
107}
108
109template <>
110float32x4_t internal_vld1q<3>(const float *in)
111{
112 const float32x4x3_t tmp = vld3q_f32(in);
113 return tmp.val[0];
114}
115
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100116inline float32x4_t internal_vdupq_n(float v)
117{
118 return vdupq_n_f32(v);
119}
120
121inline void internal_vst1q(float *p, const float32x4_t &v)
122{
123 vst1q_f32(p, v);
124}
125
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100126float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y)
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100127{
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100128 return vmulq_f32(x, y);
129}
130
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100131inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z)
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100132{
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100133 return vmlaq_f32(x, y, z);
134}
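// The internal_vdupq_n/internal_vld1q/internal_vst1q/internal_vmull/internal_vmlal helpers give
// float and (when FP16 vector arithmetic is available) float16_t the same interface, so the
// convolver templates below can be instantiated for both element types without duplicating the
// kernels. For float, internal_vmlal maps to the multiply-accumulate vmlaq_f32(x, y, z), i.e.
// x + y * z per lane.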
135
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000136constexpr int small_tensor_size_optim = 8;
137inline bool run_optim_small_tensor_info(const ITensorInfo *t)
138{
139 return t->dimension(Window::DimX) <= small_tensor_size_optim && t->dimension(Window::DimY) <= small_tensor_size_optim;
140}
141
Pablo Telloc09314a2017-09-21 13:59:14 +0100142inline bool run_optim_small_tensor(const ITensor *t)
143{
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000144 return run_optim_small_tensor_info(t->info());
Pablo Telloc09314a2017-09-21 13:59:14 +0100145}
146
147// Optimized convolver for 1x1 kernels used only where input width and height are both <= 8
148// For big Z as in Input=7x7x832, this implementation is faster than the general code because it doesn't need to
149// store intermediate results in memory. Temporary results are stored in NEON registers directly and then written to the output buffer.
150template <unsigned int stridex>
151class convolver_w1x1_i8x8_f32
152{
153public:
154 static void convolve(const Window &window, const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
155 {
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000156 ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimX) > small_tensor_size_optim);
157 ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimY) > small_tensor_size_optim);
Pablo Telloc09314a2017-09-21 13:59:14 +0100158
Georgios Pinitas15997872018-02-19 13:58:22 +0000159 const int input_stride_x = input->info()->strides_in_bytes().x();
Pablo Telloc09314a2017-09-21 13:59:14 +0100160 const int input_stride_y = input->info()->strides_in_bytes().y();
161 const int input_stride_z = input->info()->strides_in_bytes().z();
162 const int output_stride_y = output->info()->strides_in_bytes().y();
163 const int output_stride_z = output->info()->strides_in_bytes().z();
164 const int kernel_stride_z = weights->info()->strides_in_bytes().z();
165 const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
166 const int output_h = output->info()->dimension(1);
167 const int range_z = window.z().end() - window.z().start();
168 const int kernel_depth = weights->info()->dimension(Window::DimZ);
169 const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
Georgios Pinitas15997872018-02-19 13:58:22 +0000170 const unsigned int conv_pad_left = conv_info.pad_left();
171 const unsigned int conv_pad_top = conv_info.pad_top();
Pablo Telloc09314a2017-09-21 13:59:14 +0100172
173 // setup output window for the iterator
174 Window window_out = window;
175 window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
176 window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
177 window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));
178
179 // setup input window for the iterator
180 Window window_in = window;
181 // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
182 window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
183 window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
184 window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
185
186 Window window_k = calculate_max_window(*weights->info(), Steps(1u));
187 Iterator out(output, window_out);
188 Iterator in(input, window_in);
189 Iterator k(weights, window_k);
190
191 const uint8_t *k_ptr = k.ptr();
192
193 execute_window_loop(window_out, [&](const Coordinates & id)
194 {
Michalis Spyroubcfd09a2019-05-01 13:03:59 +0100195 const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
196 uint8_t *out_ptr = out.ptr();
197 int ih = 0;
198 int oh = 0;
199 std::array<float32x4_t, 8> accum0 = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
200 std::array<float32x4_t, 8> accum1 = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
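            // Register-resident accumulators (valid because input/output width and height are both <= 8):
            // accum0[oh] holds output columns 0-3 of output row 'oh' and accum1[oh] holds columns 4-7,
            // so a full 8x8 output plane is accumulated in NEON registers before being stored below.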
Pablo Telloc09314a2017-09-21 13:59:14 +0100201 for(int oz = 0; oz < range_z; ++oz)
202 {
203 accum0[0] = accum0[1] = accum0[2] = accum0[3] = accum0[4] = accum0[5] = accum0[6] = accum0[7] = vdupq_n_f32(0.f);
204 accum1[0] = accum1[1] = accum1[2] = accum1[3] = accum1[4] = accum1[5] = accum1[6] = accum1[7] = vdupq_n_f32(0.f);
205 auto p_out_base = out_ptr + oz * output_stride_z;
206 for(int p = 0; p < kernel_depth; ++p)
207 {
208 const auto k_val = reinterpret_cast<const float *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
209 const auto vk0 = internal_vdupq_n(*k_val);
210 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
211 {
212 const int offset_xy = ih * input_stride_y;
213 auto in_val = reinterpret_cast<const float *>(input_ptr + p * input_stride_z + offset_xy);
214 auto v_in0 = internal_vld1q<stridex>(in_val);
215 auto v_in1 = internal_vld1q<stridex>(in_val + 4);
216 accum0[oh] = vmlaq_f32(accum0[oh], vk0, v_in0);
217 accum1[oh] = vmlaq_f32(accum1[oh], vk0, v_in1);
218 }
219 }
220 for(oh = 0; oh < output_h; ++oh)
221 {
222 auto p_out = reinterpret_cast<float *>(p_out_base + oh * output_stride_y);
223 vst1q_f32(p_out, accum0[oh]);
224 vst1q_f32(p_out + 4, accum1[oh]);
225 }
226 }
227 },
228 in, out);
229 }
230};
231
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100232template <typename T1, typename T2, unsigned int stridex>
233class convolver_1x1
234{
235public:
236 static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
237 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
238 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100239 const int input_stride_x = input->info()->strides_in_bytes().x();
240 const int input_stride_y = input->info()->strides_in_bytes().y();
241 const int input_stride_z = input->info()->strides_in_bytes().z();
242 const int output_stride_y = output->info()->strides_in_bytes().y();
243 const int output_stride_z = output->info()->strides_in_bytes().z();
244 const int kernel_stride_z = weights->info()->strides_in_bytes().z();
245 const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
246 const int output_w = output->info()->dimension(0);
247 const int output_h = output->info()->dimension(1);
248 const int range_z = window.z().end() - window.z().start();
249 const int kernel_depth = weights->info()->dimension(Window::DimZ);
250 const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
251 const unsigned int conv_pad_left = conv_info.pad_left();
252 const unsigned int conv_pad_top = conv_info.pad_top();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100253
254 // setup output window for the iterator
255 Window window_out = window;
256 window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
257 window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
258 window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));
259
260 // setup input window for the iterator
261 Window window_in = window;
262 // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
263 window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
264 window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
265 window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
266
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100267 Window window_k = calculate_max_window(*weights->info(), Steps(1u));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100268 Iterator out(output, window_out);
269 Iterator in(input, window_in);
270 Iterator k(weights, window_k);
271
272 const uint8_t *k_ptr = k.ptr();
273
274 execute_window_loop(window_out, [&](const Coordinates & id)
275 {
276 /*
277 For a detailed explanation on how the algorithm works refer to template <> class convolver_3x3<1>
278 */
Georgios Pinitas15997872018-02-19 13:58:22 +0000279 const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100280 uint8_t *out_ptr = out.ptr();
281 int ih = 0;
282 int oh = 0;
283 for(int oz = 0; oz < range_z; ++oz)
284 {
285 auto p_out_base = out_ptr + oz * output_stride_z;
286 // Step 1
287 {
288 const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
289 const auto vk = internal_vdupq_n(*k_val);
290 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
291 {
292 const int offset_xy = ih * input_stride_y;
293 auto in_val = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
294 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
295 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
296 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100297 internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val)));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100298 }
299 }
300 }
Pablo Telloc09314a2017-09-21 13:59:14 +0100301
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100302 // Step 2
303 for(int p = 1; p < kernel_depth; ++p)
304 {
305 const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
306 const auto vk = internal_vdupq_n(*k_val);
307 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
308 {
309 const int offset_xy = ih * input_stride_y;
310 auto in_val = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
311 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
312 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
313 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100314 internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val)));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100315 }
316 }
317 }
318 }
319 },
320 in, out);
321 }
322};
323
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100324template <unsigned int stridex>
Pablo Tello06da39d2017-08-10 15:10:40 +0100325float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100326 const float *m0, const float *m1, const float *m2, const float *m3, const float *m4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100327
328inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
329{
330 const float32x4x3_t m00 =
331 {
332 {
333 vld1q_dup_f32(m0),
334 vld1q_dup_f32(m1),
335 vld1q_dup_f32(m2)
336 }
337 };
338 return m00;
339}
340
341inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4)
342{
343 const float32x4x2_t m00 =
344 {
345 {
346 vld1q_dup_f32(m3),
347 vld1q_dup_f32(m4)
348 }
349 };
350 return m00;
351}
352
353inline float32x4x3_t load_input(const float *const in)
354{
355 const float32x4x3_t vin =
356 {
357 {
358 vld1q_f32(in),
359 vld1q_f32(in + 4),
360 vld1q_f32(in + 8)
361 }
362 };
363 return vin;
364}
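// Helper layout for the 5x5 convolver below: load_matrix_hi/load_matrix_lo broadcast the five
// coefficients of one kernel row into 3 + 2 vectors (one coefficient per vector, duplicated across
// all lanes), while load_input fetches 12 consecutive input values - enough to compute 8 adjacent
// stride-1 outputs of a 5-tap row, since output positions 0..7 read inputs 0..11.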
365
366template <>
367inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100368 const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
Pablo Tello06da39d2017-08-10 15:10:40 +0100369{
Pablo Tello06da39d2017-08-10 15:10:40 +0100370 const float32x4x3_t vin0 = load_input(in_0);
371 const float32x4x3_t vin1 = load_input(in_1);
372 const float32x4x3_t vin2 = load_input(in_2);
373 const float32x4x3_t vin3 = load_input(in_3);
374 const float32x4x3_t vin4 = load_input(in_4);
375 const float32x4x3_t m00 = load_matrix_hi(m0, 1 + m0, 2 + m0);
376 const float32x4x2_t m01 = load_matrix_lo(3 + m0, 4 + m0);
377 const float32x4x3_t m10 = load_matrix_hi(m1, 1 + m1, 2 + m1);
378 const float32x4x2_t m11 = load_matrix_lo(3 + m1, 4 + m1);
379 const float32x4x3_t m20 = load_matrix_hi(m2, 1 + m2, 2 + m2);
380 const float32x4x2_t m21 = load_matrix_lo(3 + m2, 4 + m2);
381 const float32x4x3_t m30 = load_matrix_hi(m3, 1 + m3, 2 + m3);
382 const float32x4x2_t m31 = load_matrix_lo(3 + m3, 4 + m3);
383 const float32x4x3_t m40 = load_matrix_hi(m4, 1 + m4, 2 + m4);
384 const float32x4x2_t m41 = load_matrix_lo(3 + m4, 4 + m4);
385
386 float32x4x2_t out =
387 {
388 {
389 vmulq_f32(vin0.val[0], m00.val[0]),
390 vmulq_f32(vin0.val[1], m00.val[0])
391 }
392 };
393
394 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]);
395 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]);
396 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]);
397 out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]);
398
399 out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]);
400 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]);
401 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]);
402 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 3), m11.val[0]);
403 out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], m11.val[1]);
404
405 out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]);
406 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]);
407 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]);
408 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]);
409 out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]);
410
411 out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]);
412 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]);
413 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]);
414 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]);
415 out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]);
416
417 out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]);
418 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]);
419 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]);
420 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]);
421 out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]);
422
423 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]);
424 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]);
425 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]);
426 out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]);
427
428 out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]);
429 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]);
430 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]);
431 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]);
432 out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]);
433
434 out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]);
435 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]);
436 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]);
437 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]);
438 out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]);
439
440 out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]);
441 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]);
442 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]);
443 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]);
444 out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]);
445
446 out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]);
447 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]);
448 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]);
449 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]);
450 out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]);
451
452 return out;
453}
454
455template <>
456inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100457 const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
Pablo Tello06da39d2017-08-10 15:10:40 +0100458{
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100459 float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100460 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
461 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
462 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
463 return out;
464}
465
466template <>
467inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100468 const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
Pablo Tello06da39d2017-08-10 15:10:40 +0100469{
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100470 float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100471 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
472 return out;
473}
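// The stride-2 and stride-3 specialisations above reuse the stride-1 result, which holds 8
// consecutive output positions (out.val[0] lanes = positions 0-3, out.val[1] lanes = positions 4-7),
// and compact the lanes that remain valid for the larger stride:
//   stride 2: out.val[0] becomes { pos0, pos2, pos4, pos6 }
//   stride 3: lane 1 of out.val[0] becomes pos3, so the first two lanes hold { pos0, pos3 }
// The stridex-templated store_results/accumulate_results helpers are then expected to consume only
// the lanes that are meaningful for that stride.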
474
Giorgio Arenac0f54432018-03-16 14:02:34 +0000475template <typename T1>
476class convolver_nhwc
477{
478public:
Michalis Spyrou8ef57062020-01-14 12:15:48 +0000479 static void convolve(const Window &window, uint32_t kernel_size, unsigned int num_elems_read_per_iteration,
Giorgio Arenac0f54432018-03-16 14:02:34 +0000480 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
481 {
482 const int input_width = input->info()->dimension(0);
483 const int input_depth = input->info()->dimension(2);
484 const int input_stride_x = input->info()->strides_in_bytes().x();
485 const int input_stride_y = input->info()->strides_in_bytes().y();
486 const int input_stride_z = input->info()->strides_in_bytes().z();
487 const int output_stride_x = output->info()->strides_in_bytes().x();
488 const int kernel_stride_x = weights->info()->strides_in_bytes().x();
489 const int kernel_stride_y = weights->info()->strides_in_bytes().y();
490 const int kernel_stride_z = weights->info()->strides_in_bytes().z();
491 const int conv_pad_top = conv_info.pad_top();
492 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
493 const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
494 const T1 zero = 0;
495
496 // Setup input window for the input iterator
497 Window window_in = window;
498 window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
499 window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
500 window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
501
502 // Setup output window for the output iterator
503 Window window_out = window;
504 window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
505
506 // Setup window for the weights iterator
507 Window window_k = calculate_max_window(*weights->info(), Steps());
508 window_k.set(Window::DimX, Window::Dimension(0, 1, 1));
509 window_k.set(Window::DimY, Window::Dimension(0, 1, 1));
510 window_k.set(Window::DimZ, Window::Dimension(0, 1, 1));
511 window_k.set(3, Window::Dimension(0, weights->info()->dimension(3), 1));
512
513 Iterator in(input, window_in);
514 Iterator out(output, window_out);
515 Iterator k(weights, window_k);
516
517 execute_window_loop(window_k, [&](const Coordinates & id_k)
518 {
519 execute_window_loop(window_out, [&](const Coordinates & id)
520 {
521 const auto in_y = static_cast<int>(id.y() * conv_stride_x - conv_info.pad_left());
522 const auto in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top);
523
524 const uint8_t *in_ptr = in.ptr() + in_y * input_stride_y + in_z * input_stride_z;
525 uint8_t *out_ptr = out.ptr() + id_k[3] * output_stride_x;
526
527 T1 out_val = 0;
528
529 auto in_addr_base0 = in_ptr;
530 auto we_addr_base0 = k.ptr();
531
Michalis Spyrou8ef57062020-01-14 12:15:48 +0000532 for(uint32_t z = 0; z < kernel_size; ++z, in_addr_base0 += input_stride_z, we_addr_base0 += kernel_stride_z)
Giorgio Arenac0f54432018-03-16 14:02:34 +0000533 {
534 const int in_z = id.z() * conv_stride_y + z - conv_pad_top;
535
536 if(in_z >= 0 && in_z < input_depth) // If false, pad top/bottom
537 {
538 auto in_addr_base1 = in_addr_base0;
539 auto we_addr_base1 = we_addr_base0;
540
Michalis Spyrou8ef57062020-01-14 12:15:48 +0000541 for(uint32_t y = 0; y < kernel_size; ++y, in_addr_base1 += input_stride_y, we_addr_base1 += kernel_stride_y)
Giorgio Arenac0f54432018-03-16 14:02:34 +0000542 {
543 auto out_values = internal_vdupq_n(zero);
544
545 int x = 0;
546 int no_leftover = input_width - num_elems_read_per_iteration;
547
548 for(; x < no_leftover; x += num_elems_read_per_iteration)
549 {
550 const auto in_addr = reinterpret_cast<const T1 *>(in_addr_base1 + x * input_stride_x);
551 const auto in_values = internal_vld1q<1>(in_addr);
552
553 const auto we_addr = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
554 const auto we_values = internal_vld1q<1>(we_addr);
555
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100556 out_values = internal_vmlal(out_values, in_values, we_values);
Giorgio Arenac0f54432018-03-16 14:02:34 +0000557 }
558
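                        // Horizontal reduction of the 4 per-lane partial sums into a scalar:
                        // vpadd(high, low) pairwise-adds the two halves and the second vpadd collapses the pair,
                        // e.g. out_values = {a, b, c, d} -> {c+d, a+b} -> {a+b+c+d, a+b+c+d}; lane 0 is added to out_val.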
Michalis Spyrou201c37c2018-10-25 17:25:54 +0100559 auto carry_addition = wrapper::vpadd(wrapper::vgethigh(out_values), wrapper::vgetlow(out_values));
560 carry_addition = wrapper::vpadd(carry_addition, carry_addition);
561 out_val += wrapper::vgetlane(carry_addition, 0);
Giorgio Arenac0f54432018-03-16 14:02:34 +0000562
563 // Leftover
564 for(; x < input_width; ++x)
565 {
566 const auto in_addr = reinterpret_cast<const T1 *>(in_addr_base1 + x * input_stride_x);
567 const auto in_value = *(in_addr);
568
569 const auto we_addr = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
570 const auto we_value = *(we_addr);
571
572 out_val += in_value * we_value;
573 }
574 }
575 }
576 }
577
578 *(reinterpret_cast<T1 *>(out_ptr)) = out_val;
579 },
580 in, out);
581 },
582 k);
583 }
584};
585
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100586template <typename T1, typename T2, unsigned int stridex>
587class convolver_3x3
588{
589public:
590 static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
591 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
592 {
593 ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100594 const int input_stride_x = input->info()->strides_in_bytes().x();
595 const int input_stride_y = input->info()->strides_in_bytes().y();
596 const int input_stride_z = input->info()->strides_in_bytes().z();
597 const int output_stride_y = output->info()->strides_in_bytes().y();
598 const int output_stride_z = output->info()->strides_in_bytes().z();
599 const int kernel_stride_x = weights->info()->strides_in_bytes().x();
600 const int kernel_stride_y = weights->info()->strides_in_bytes().y();
601 const int kernel_stride_z = weights->info()->strides_in_bytes().z();
602 const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
603 const int output_w = output->info()->dimension(0);
604 const int output_h = output->info()->dimension(1);
605 const int num_planes_z = window.z().end() - window.z().start();
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000606 const int delta_input = get_input_num_elems_processed(num_elems_written_per_iteration, stridex);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100607 const int kernel_depth = weights->info()->dimension(Window::DimZ);
608 const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
609 const unsigned int conv_pad_left = conv_info.pad_left();
610 const unsigned int conv_pad_top = conv_info.pad_top();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100611
612 // setup output window for the iterator
613 Window window_out = window;
614 window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
615 window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
616 window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));
617
618 // setup input window for the iterator
619 Window window_in = window;
620 // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
621 window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
622 window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
623 window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
624
625 Window window_k = calculate_max_window(*weights->info(), Steps(1u));
626
627 Iterator out(output, window_out);
628 Iterator in(input, window_in);
629 Iterator k(weights, window_k);
630
631 const uint8_t *k_ptr = k.ptr();
632
633 execute_window_loop(window_out, [&](const Coordinates & id)
634 {
Georgios Pinitas15997872018-02-19 13:58:22 +0000635 const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100636 uint8_t *out_ptr = out.ptr();
637 int ih = 0;
638 int oh = 0;
639 /*
640 Each thread executing this kernel computes one or more planes of the output volume.
641
642 Let's say the 3rd dimension of the output volume is 32: the first thread will compute the output for Z = [0,7], the second thread will compute the output for Z = [8,15],
643 the third thread for Z = [16,23] and the fourth thread for Z = [24,31].
644
645 The algorithm outer loop iterates over Z, P, Y, X where P is the depth/3rd dimension of each kernel. This order is not arbitrary: the main benefit of this
Anthony Barbiere5007472017-10-27 15:01:44 +0100646 is that we set up the NEON registers containing the kernel's values only once and then compute each XY using the preloaded registers, as opposed to reloading them for every XY value.
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100647
648 The algorithm does not require allocating any additional memory and computes the results directly in-place in two stages:
649 1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
650 2) Convolve the remaining planes and accumulate the results in the output's plane which has been initialized in step 1.
651 */
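            /*
            Illustration with kernel_depth == 3 for one output plane:
                Step 1: out  = conv2d_3x3(in_plane_0, k_plane_0)   // initialises the output plane
                Step 2: out += conv2d_3x3(in_plane_1, k_plane_1)
                        out += conv2d_3x3(in_plane_2, k_plane_2)   // accumulates the remaining planes
            */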
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100652 for(int oz = 0; oz < num_planes_z; ++oz)
653 {
Pablo Tello0d176142017-07-06 16:43:14 +0100654 const int zoffset = id.z() + oz;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100655 uint8_t *p_out_base = out_ptr + oz * output_stride_z;
656 // Step 1
657 {
Pablo Tello0d176142017-07-06 16:43:14 +0100658 const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
659 const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
660 const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100661 const auto vk_r0 = load_matrix_row(ptr_k_r0);
662 const auto vk_r1 = load_matrix_row(ptr_k_r1);
663 const auto vk_r2 = load_matrix_row(ptr_k_r2);
664 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
665 {
666 auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
667 auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
668 auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
669 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
670 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
671 in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
672 {
Georgios Pinitasa26e1662020-03-04 15:31:25 +0000673 convolve_3x3<false>(in_top, in_mid, in_low, p_out, vk_r0, vk_r1, vk_r2, stridex);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100674 }
675 }
676 }
677 // Step 2
678 for(int p = 1; p < kernel_depth; ++p)
679 {
Pablo Tello06da39d2017-08-10 15:10:40 +0100680 const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w;
681 const uint8_t *input_base = input_ptr + p * input_stride_z;
682 const auto ptr_k_r0 = reinterpret_cast<const T1 *>(ptr_k_base);
683 const auto ptr_k_r1 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y);
684 const auto ptr_k_r2 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y * 2);
685 const auto vk_r0 = load_matrix_row(ptr_k_r0);
686 const auto vk_r1 = load_matrix_row(ptr_k_r1);
687 const auto vk_r2 = load_matrix_row(ptr_k_r2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100688 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
689 {
Pablo Tello06da39d2017-08-10 15:10:40 +0100690 auto in_top = reinterpret_cast<const T1 *>(input_base + (ih + 0) * input_stride_y);
691 auto in_mid = reinterpret_cast<const T1 *>(input_base + (ih + 1) * input_stride_y);
692 auto in_low = reinterpret_cast<const T1 *>(input_base + (ih + 2) * input_stride_y);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100693 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
694 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
695 in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
696 {
Georgios Pinitasa26e1662020-03-04 15:31:25 +0000697 convolve_3x3<true>(in_top, in_mid, in_low, p_out, vk_r0, vk_r1, vk_r2, stridex);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100698 }
699 }
700 }
701 }
702 },
703 in, out);
704 }
705};
706
Pablo Tello06da39d2017-08-10 15:10:40 +0100707template <typename T1, typename T2, unsigned int stridex>
708class convolver_5x5
709{
710public:
711 static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
712 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
713 {
714 ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100715 const int input_stride_x = input->info()->strides_in_bytes().x();
716 const int input_stride_y = input->info()->strides_in_bytes().y();
717 const int input_stride_z = input->info()->strides_in_bytes().z();
718 const int output_stride_y = output->info()->strides_in_bytes().y();
719 const int output_stride_z = output->info()->strides_in_bytes().z();
720 const int kernel_stride_x = weights->info()->strides_in_bytes().x();
721 const int kernel_stride_y = weights->info()->strides_in_bytes().y();
722 const int kernel_stride_z = weights->info()->strides_in_bytes().z();
723 const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
724 const int output_w = output->info()->dimension(0);
725 const int output_h = output->info()->dimension(1);
726 const int num_planes_z = window.z().end() - window.z().start();
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000727 const int delta_input = get_input_num_elems_processed(num_elems_written_per_iteration, stridex);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100728 const int kernel_depth = weights->info()->dimension(Window::DimZ);
729 const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
730 const unsigned int conv_pad_left = conv_info.pad_left();
731 const unsigned int conv_pad_top = conv_info.pad_top();
Pablo Tello06da39d2017-08-10 15:10:40 +0100732
733 // setup output window for the iterator
734 Window window_out = window;
735 window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
736 window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
737 window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));
738
739 // setup input window for the iterator
740 Window window_in = window;
741 // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
742 window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
743 window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
744 window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
745
746 Window window_k = calculate_max_window(*weights->info(), Steps(1u));
747
748 Iterator out(output, window_out);
749 Iterator in(input, window_in);
750 Iterator k(weights, window_k);
751
752 const uint8_t *k_ptr = k.ptr();
753
754 execute_window_loop(window_out, [&](const Coordinates & id)
755 {
Georgios Pinitas15997872018-02-19 13:58:22 +0000756 const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
Pablo Tello06da39d2017-08-10 15:10:40 +0100757 uint8_t *out_ptr = out.ptr();
758 int ih = 0;
759 int oh = 0;
760 for(int oz = 0; oz < num_planes_z; ++oz)
761 {
762 const int zoffset = id.z() + oz;
763 uint8_t *p_out_base = out_ptr + oz * output_stride_z;
764 // Step 1
765 {
766 const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
767 const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
768 const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
769 const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
770 const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
771 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
772 {
773 auto in_0 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
774 auto in_1 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
775 auto in_2 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
776 auto in_3 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
777 auto in_4 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
778 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
779 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
780 in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
781 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100782 auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100783 store_results<stridex>(p_out, vres);
784 }
785 }
786 }
787 // Step 2
788 for(int p = 1; p < kernel_depth; ++p)
789 {
790 const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
791 const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
792 const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
793 const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
794 const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
795
796 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
797 {
798 auto in_0 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
799 auto in_1 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
800 auto in_2 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
801 auto in_3 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
802 auto in_4 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
803 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
804 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
805 in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
806 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100807 auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100808 accumulate_results<stridex>(p_out, vres);
809 }
810 }
811 }
812 }
813 },
814 in, out);
815 }
816};
817
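// Accumulates one 1x9 row of a 9x9 NHWC convolution for four adjacent output positions.
// row_ptr/weights_ptr point at 4 channels each; twelve input positions (0..11) are loaded so that
// outputs 0..3 can each consume the nine positions starting at their own offset along Y.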
Gian Marco Iodice95f93612019-06-13 15:58:32 +0100818inline void convolve_row1x9_nhwc(const float *row_ptr, const float *weights_ptr, size_t src_stride_y, size_t weights_stride_y,
819 float32x4_t &acc0, float32x4_t &acc1, float32x4_t &acc2, float32x4_t &acc3)
820{
821 // Load 4 channels for each of the 12 input values along the same X spatial dimension
822 const float32x4_t src0 = wrapper::vloadq(row_ptr);
823 const float32x4_t src1 = wrapper::vloadq(row_ptr + 1 * src_stride_y);
824 const float32x4_t src2 = wrapper::vloadq(row_ptr + 2 * src_stride_y);
825 const float32x4_t src3 = wrapper::vloadq(row_ptr + 3 * src_stride_y);
826 const float32x4_t src4 = wrapper::vloadq(row_ptr + 4 * src_stride_y);
827 const float32x4_t src5 = wrapper::vloadq(row_ptr + 5 * src_stride_y);
828 const float32x4_t src6 = wrapper::vloadq(row_ptr + 6 * src_stride_y);
829 const float32x4_t src7 = wrapper::vloadq(row_ptr + 7 * src_stride_y);
830 const float32x4_t src8 = wrapper::vloadq(row_ptr + 8 * src_stride_y);
831 const float32x4_t src9 = wrapper::vloadq(row_ptr + 9 * src_stride_y);
832 const float32x4_t src10 = wrapper::vloadq(row_ptr + 10 * src_stride_y);
833 const float32x4_t src11 = wrapper::vloadq(row_ptr + 11 * src_stride_y);
834
835 // Load 4 channels for each of the 9 weight values along the same X spatial dimension
836 const float32x4_t w0 = wrapper::vloadq(weights_ptr);
837 const float32x4_t w1 = wrapper::vloadq(weights_ptr + 1 * weights_stride_y);
838 const float32x4_t w2 = wrapper::vloadq(weights_ptr + 2 * weights_stride_y);
839 const float32x4_t w3 = wrapper::vloadq(weights_ptr + 3 * weights_stride_y);
840 const float32x4_t w4 = wrapper::vloadq(weights_ptr + 4 * weights_stride_y);
841 const float32x4_t w5 = wrapper::vloadq(weights_ptr + 5 * weights_stride_y);
842 const float32x4_t w6 = wrapper::vloadq(weights_ptr + 6 * weights_stride_y);
843 const float32x4_t w7 = wrapper::vloadq(weights_ptr + 7 * weights_stride_y);
844 const float32x4_t w8 = wrapper::vloadq(weights_ptr + 8 * weights_stride_y);
845
846 // Accumulate 4 channels for each of the 4 output values along the same X spatial dimension
847 acc0 = wrapper::vmla(acc0, w0, src0);
848 acc0 = wrapper::vmla(acc0, w1, src1);
849 acc0 = wrapper::vmla(acc0, w2, src2);
850 acc0 = wrapper::vmla(acc0, w3, src3);
851 acc0 = wrapper::vmla(acc0, w4, src4);
852 acc0 = wrapper::vmla(acc0, w5, src5);
853 acc0 = wrapper::vmla(acc0, w6, src6);
854 acc0 = wrapper::vmla(acc0, w7, src7);
855 acc0 = wrapper::vmla(acc0, w8, src8);
856
857 acc1 = wrapper::vmla(acc1, w0, src1);
858 acc1 = wrapper::vmla(acc1, w1, src2);
859 acc1 = wrapper::vmla(acc1, w2, src3);
860 acc1 = wrapper::vmla(acc1, w3, src4);
861 acc1 = wrapper::vmla(acc1, w4, src5);
862 acc1 = wrapper::vmla(acc1, w5, src6);
863 acc1 = wrapper::vmla(acc1, w6, src7);
864 acc1 = wrapper::vmla(acc1, w7, src8);
865 acc1 = wrapper::vmla(acc1, w8, src9);
866
867 acc2 = wrapper::vmla(acc2, w0, src2);
868 acc2 = wrapper::vmla(acc2, w1, src3);
869 acc2 = wrapper::vmla(acc2, w2, src4);
870 acc2 = wrapper::vmla(acc2, w3, src5);
871 acc2 = wrapper::vmla(acc2, w4, src6);
872 acc2 = wrapper::vmla(acc2, w5, src7);
873 acc2 = wrapper::vmla(acc2, w6, src8);
874 acc2 = wrapper::vmla(acc2, w7, src9);
875 acc2 = wrapper::vmla(acc2, w8, src10);
876
877 acc3 = wrapper::vmla(acc3, w0, src3);
878 acc3 = wrapper::vmla(acc3, w1, src4);
879 acc3 = wrapper::vmla(acc3, w2, src5);
880 acc3 = wrapper::vmla(acc3, w3, src6);
881 acc3 = wrapper::vmla(acc3, w4, src7);
882 acc3 = wrapper::vmla(acc3, w5, src8);
883 acc3 = wrapper::vmla(acc3, w6, src9);
884 acc3 = wrapper::vmla(acc3, w7, src10);
885 acc3 = wrapper::vmla(acc3, w8, src11);
886}
887
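// Horizontal sum of the four lanes of v, e.g. for v = {1, 2, 3, 4}:
// vgethigh = {3, 4}, vgetlow = {1, 2}, vadd = {4, 6}, lane0 + lane1 = 10.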
888float vreduce(const float32x4_t &v)
889{
890 auto v0 = wrapper::vgethigh(v);
891 auto v1 = wrapper::vgetlow(v);
892 auto v_out = wrapper::vadd(v0, v1);
893
894 float a = wrapper::vgetlane(v_out, 0);
895 float b = wrapper::vgetlane(v_out, 1);
896 return a + b;
897}
898
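// Direct convolver for 9x9 kernels in NHWC: for each kernel (4th weights dimension) and each group
// of four adjacent output positions along dimension 1, it accumulates the nine kernel rows with
// convolve_row1x9_nhwc while reducing along the channel dimension (dimension 0), then collapses
// each vector accumulator to a scalar with vreduce before storing.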
899template <typename V>
900class convolver_9x9_nhwc
901{
902public:
903 static void convolve(const Window &window, unsigned int num_elems_read_per_iteration,
904 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
905 {
906 // Declare useful types
907 using vector_type = typename V::type;
908 using scalar_type = typename V::scalar_type;
909 using tag_type = typename V::tag_type;
910
911 // Scalar quantities
912 const int element_size = input->info()->element_size();
913 const int input_width = input->info()->dimension(0);
914 const int input_depth = input->info()->dimension(2);
915 const int input_stride_y = input->info()->strides_in_bytes().y() / element_size;
916 const int input_stride_z = input->info()->strides_in_bytes().z() / element_size;
917 const int input_stride_w = input->info()->strides_in_bytes()[3];
918 const int output_stride_x = output->info()->strides_in_bytes().x();
919 const int output_stride_y = output->info()->strides_in_bytes().y();
920 const int kernel_stride_y = weights->info()->strides_in_bytes().y() / element_size;
921 const int kernel_stride_z = weights->info()->strides_in_bytes().z() / element_size;
922 const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
923 const unsigned int conv_pad_top = conv_info.pad_top();
924 const unsigned int conv_pad_left = conv_info.pad_left();
925
926 // Setup input window for the input iterator
927 Window window_in = window;
928 window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
929 window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
930 window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
931
932 // Setup output window for the output iterator
933 Window window_out = window;
934 window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
935
936 // Setup window for the weights iterator
937 Window window_k = calculate_max_window(*weights->info(), Steps());
938 window_k.set(Window::DimX, Window::Dimension(0, 1, 1));
939 window_k.set(Window::DimY, Window::Dimension(0, 1, 1));
940 window_k.set(Window::DimZ, Window::Dimension(0, 1, 1));
941 window_k.set(3, Window::Dimension(0, weights->info()->dimension(3), 1));
942
943 Iterator in(input, window_in);
944 Iterator out(output, window_out);
945 Iterator k(weights, window_k);
946
947 // Calculate the max_offset.
948 // max_offset is the offset of the last non-valid value in the Z dimension (spatial dimension Y for NHWC)
949 // |******************|
950 // | pad_top |
951 // |******************|
952 // | |
953 // | plane0 |
954 // | batch0 |
955 // |__________________|
956 // |******************| Batch 0
957 // | pad_bottom |
958 // | pad_top |
959 // |******************|
960 // | |
961 // | plane1 |
962 // | batch0 |
963 // |__________________|-----> max_offset
964 // |******************|
965 // | pad_bottom |
966 // | pad_top |
967 // |******************|
968 // | |
969 // | plane0 |
970 // | batch1 |
971 // |__________________|
972 // |******************| Batch 1
973 // | pad_bottom |
974 // | pad_top |
975 // |******************|
976 // | |
977 // | plane1 |
978 // | batch1 |
979 // |__________________|
980 // | pad_bottom |
981 // |******************|
982 const int max_offset = input_stride_z * input_depth - (input->info()->padding().bottom + input->info()->padding().top) * input_stride_y;
983 execute_window_loop(window_k, [&](const Coordinates & id_k) // loop over the kernels (4th dimension of the weights)
984 {
985
986 execute_window_loop(window_out, [&](const Coordinates & id)
987 {
988 const auto y_offset = int(id.y() - conv_pad_left) * input_stride_y;
989
990 // Buffer pointers
991 const scalar_type *in_ptr = reinterpret_cast<scalar_type *>(input->buffer() + input->info()->offset_first_element_in_bytes() + id[3] * input_stride_w);
992 const scalar_type *weights_ptr = reinterpret_cast<scalar_type *>(k.ptr());
993 uint8_t *out_ptr = out.ptr() + id_k[3] * output_stride_x;
994
995 // Output elements
996 vector_type out0 = wrapper::vdup_n(scalar_type(0), tag_type());
997 vector_type out1 = wrapper::vdup_n(scalar_type(0), tag_type());
998 vector_type out2 = wrapper::vdup_n(scalar_type(0), tag_type());
999 vector_type out3 = wrapper::vdup_n(scalar_type(0), tag_type());
1000
1001 // Reduce along the feature maps
1002 for(int x = 0; x < input_width; x += num_elems_read_per_iteration)
1003 {
1004 // z == 0
1005 auto in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top);
1006 in_z = std::min(static_cast<unsigned int>(in_z), static_cast<unsigned int>(input_depth));
1007 auto offset = y_offset + in_z * input_stride_z;
1008 offset = std::min(offset, max_offset);
1009 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 0 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1010
1011 // z == 1
1012 in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top + 1);
1013 in_z = std::min(static_cast<unsigned int>(in_z), static_cast<unsigned int>(input_depth));
1014 offset = y_offset + in_z * input_stride_z;
1015 offset = std::min(offset, max_offset);
1016 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 1 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1017
1018 // z == 2
1019 in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top + 2);
1020 in_z = std::min(static_cast<unsigned int>(in_z), static_cast<unsigned int>(input_depth));
1021 offset = y_offset + in_z * input_stride_z;
1022 offset = std::min(offset, max_offset);
1023 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 2 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1024
1025 // z == 3
1026 in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top + 3);
1027 offset = y_offset + in_z * input_stride_z;
1028 offset = std::min(offset, max_offset);
1029 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 3 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1030
1031 // z == 4
1032 in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top + 4);
1033 offset = y_offset + in_z * input_stride_z;
1034 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 4 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1035
1036 // z == 5
1037 offset += input_stride_z;
1038 offset = std::min(offset, max_offset);
1039 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 5 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1040
1041 // z == 6
1042 offset += input_stride_z;
1043 offset = std::min(offset, max_offset);
1044 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 6 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1045
1046 // z == 7
1047 offset += input_stride_z;
1048 offset = std::min(offset, max_offset);
1049 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 7 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1050
1051 // z == 8
1052 offset += input_stride_z;
1053 offset = std::min(offset, max_offset);
1054 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 8 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1055 }
1056
1057 *(reinterpret_cast<scalar_type *>(out_ptr + 0 * output_stride_y)) = vreduce(out0);
1058 *(reinterpret_cast<scalar_type *>(out_ptr + 1 * output_stride_y)) = vreduce(out1);
1059 *(reinterpret_cast<scalar_type *>(out_ptr + 2 * output_stride_y)) = vreduce(out2);
1060 *(reinterpret_cast<scalar_type *>(out_ptr + 3 * output_stride_y)) = vreduce(out3);
1061 },
1062 in, out);
1063 },
1064 k);
1065 }
1066};
1067
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001068template <typename T1, typename T2>
1069inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
1070 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
1071{
1072 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
1073 switch(conv_stride_x)
1074 {
1075 case 1:
1076 convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1077 break;
1078 case 2:
1079 convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1080 break;
1081 case 3:
1082 convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1083 break;
1084 default:
1085 ARM_COMPUTE_ERROR("Not implemented");
1086 }
1087}

template <>
inline void convolve_1x1<float, float>(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                                       const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    if(run_optim_small_tensor(input))
    {
        switch(conv_stride_x)
        {
            case 1:
                convolver_w1x1_i8x8_f32<1>::convolve(window, input, weights, output, conv_info);
                break;
            case 2:
                convolver_w1x1_i8x8_f32<2>::convolve(window, input, weights, output, conv_info);
                break;
            case 3:
                convolver_w1x1_i8x8_f32<3>::convolve(window, input, weights, output, conv_info);
                break;
            default:
                ARM_COMPUTE_ERROR("Not implemented");
        }
    }
    else
    {
        switch(conv_stride_x)
        {
            case 1:
                convolver_1x1<float, float, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            case 2:
                convolver_1x1<float, float, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            case 3:
                convolver_1x1<float, float, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            default:
                ARM_COMPUTE_ERROR("Not implemented");
        }
    }
}

template <typename T1, typename T2>
inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <typename T1, typename T2>
inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_5x5<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_5x5<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_5x5<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <typename V>
inline void convolve_9x9_nhwc(const Window &window, unsigned int num_elems_read_per_iteration,
                              const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_9x9_nhwc<V>::convolve(window, num_elems_read_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
    ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);

    const DataLayout data_layout = input->data_layout();
    const int        width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
    const int        height_idx  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
    const int        channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(channel_idx) != input->dimension(channel_idx));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != weights->dimension(height_idx));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
    ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && input->data_type() != DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(width_idx) > 3) && (input->data_type() == DataType::F16));
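
    // Summary of the checks above: NCHW accepts F32 and F16 (F16 only for kernels up to 3x3),
    // NHWC requires F32, strides above 3 are rejected, and only square kernels are supported.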

    // Checks performed when output is configured
    if(output->total_size() != 0)
    {
        TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info);

        DataType data_type = input->data_type();

        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
        ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != data_type);
    }

    return Status{};
}

std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int &num_weight_elems_read_per_row,
                                                        unsigned int &num_elems_read_per_iteration, unsigned int &num_elems_written_per_iteration, BorderSize &border_size)
{
    ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);

    const DataLayout data_layout = input->data_layout();
    const int        width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);

    // Calculate right and bottom border
    unsigned int kernel_size   = weights->dimension(width_idx);
    const int    conv_stride_x = std::get<0>(conv_info.stride());
    const int    conv_stride_y = std::get<1>(conv_info.stride());
    const int    input_width   = input->dimension(width_idx);

    Window win{};
    bool   window_changed = false;

    if(data_layout == DataLayout::NCHW)
    {
        switch(kernel_size)
        {
            case 1:
            {
                switch(input->data_type())
                {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    case DataType::F16:
                        num_elems_written_per_iteration = 8;
                        break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                    case DataType::F32:
                        if(run_optim_small_tensor_info(input))
                        {
                            num_elems_written_per_iteration = 8;
                        }
                        else
                        {
                            num_elems_written_per_iteration = 4;
                        }
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported.");
                        break;
                }
                num_weight_elems_read_per_row = kernel_size;
                num_elems_read_per_iteration  = conv_stride_x * num_elems_written_per_iteration;
                break;
            }
            case 3:
                switch(input->data_type())
                {
                    case DataType::F32:
                        num_weight_elems_read_per_row   = 4 + kernel_size - 1;
                        num_elems_read_per_iteration    = 12;
                        num_elems_written_per_iteration = 16 >> conv_stride_x;
                        break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    case DataType::F16:
                        num_weight_elems_read_per_row   = 8 + kernel_size - 1;
                        num_elems_read_per_iteration    = 24;
                        num_elems_written_per_iteration = 32 >> conv_stride_x;
                        break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported.");
                        break;
                }
                break;
            case 5:
            {
                switch(input->data_type())
                {
                    case DataType::F32:
                        num_weight_elems_read_per_row   = 4 + kernel_size - 1;
                        num_elems_read_per_iteration    = 12;
                        num_elems_written_per_iteration = 16 >> conv_stride_x;
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported.");
                        break;
                }
            }
            break;
            default:
            {
                ARM_COMPUTE_ERROR("Not implemented");
                break;
            }
        }

        // Calculate right pad
        int start_x       = kernel_size / 2 - static_cast<int>(conv_info.pad_left());
        int end_x         = ceil_to_multiple(static_cast<int>(output->dimension(0)), num_elems_written_per_iteration) * conv_stride_x;
        int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width;
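
        // Worked example with assumed values (illustrative only, not taken from the library's tests):
        // a 3x3 F32 kernel with conv_stride_x = 1, pad_left = 1, input/output width = 20,
        // num_elems_written_per_iteration = 8 and num_elems_read_per_iteration = 12 gives
        //   start_x       = 3 / 2 - 1                         = 0
        //   end_x         = ceil_to_multiple(20, 8) * 1       = 24
        //   upper_bound_w = ceil_to_multiple(0 + 24, 12) - 20 = 4
        // so four elements of right padding are needed to keep the vector loads in bounds.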

        // Calculate border
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();
        const unsigned int conv_pad_right  = std::max(upper_bound_w, 0);
        const unsigned int conv_pad_bottom = conv_info.pad_bottom();

        border_size.left   = conv_pad_left;
        border_size.top    = conv_pad_top;
        border_size.right  = conv_pad_right;
        border_size.bottom = conv_pad_bottom;

        // Configure window
        win = calculate_max_window(*output, Steps(num_elems_written_per_iteration));

        AccessWindowRectangle input_access(input, -conv_pad_left, -conv_pad_top,
                                           num_elems_read_per_iteration, kernel_size,
                                           conv_stride_x, conv_stride_y);
        AccessWindowStatic     weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size);
        AccessWindowHorizontal output_access(output, 0, num_elems_written_per_iteration);
        window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
        output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
    }
    else
    {
        if(kernel_size == 9)
        {
            border_size.left = 0;
            border_size.top  = conv_info.pad_left();

            const int num_elems_read_per_iteration_x    = 4;
            const int num_elems_written_per_iteration_x = 1;
            const int num_elems_read_per_iteration_y    = 12;
            const int num_elems_written_per_iteration_y = 4;

            num_elems_read_per_iteration    = num_elems_read_per_iteration_x;
            num_elems_written_per_iteration = num_elems_written_per_iteration_x;

            border_size.right = num_elems_read_per_iteration_x;
            if((conv_info.pad_bottom() != 0) || (conv_info.pad_top() != 0))
            {
                // If bottom or top padding is set, we need to be able to read num_elems_read_per_iteration_y
                // rows per iteration (zero-padded where they fall outside the input). Since
                // num_elems_read_per_iteration_y is always greater than conv_info.pad_right(), we can set
                // the bottom border to num_elems_read_per_iteration_y.
                border_size.bottom = num_elems_read_per_iteration_y;
            }
            else if(conv_info.pad_right() != 0)
            {
                // Conventional border padding. Fill the bottom padding so that rows can be read in batches of num_elems_read_per_iteration_y
                border_size.bottom = ceil_to_multiple(input->dimension(1) + conv_info.pad_right(), num_elems_read_per_iteration_y) - input->dimension(1);
            }
            else
            {
                // No padding
                border_size.bottom = 0;
            }

            win = calculate_max_window(*output, Steps(num_elems_written_per_iteration_x, num_elems_written_per_iteration_y));

            AccessWindowStatic input_access(input, 0, -border_size.top,
                                            ceil_to_multiple(input->dimension(0), num_elems_read_per_iteration_x),
                                            input->dimension(1) + border_size.bottom);

            AccessWindowStatic    weights_access(weights, 0, 0, ceil_to_multiple(weights->dimension(0), num_elems_read_per_iteration_x), weights->dimension(1));
            AccessWindowRectangle output_access(output, 0, 0, num_elems_written_per_iteration_x, num_elems_written_per_iteration_y);
            window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
            output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
        }
        else
        {
            border_size.left   = 0;
            border_size.top    = conv_info.pad_left();
            border_size.right  = 0;
            border_size.bottom = conv_info.pad_right();
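            // Read one full 16-byte vector per iteration: 16 / element_size elements, i.e. 4 for F32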
            num_elems_read_per_iteration = 16 / element_size_from_data_type(input->data_type());
            win                          = calculate_max_window(*output, Steps());

            AccessWindowRectangle input_access(input, 0, -border_size.top, num_elems_read_per_iteration, kernel_size, 1.f, conv_stride_x);
            AccessWindowRectangle weights_access(weights, 0, 0, num_elems_read_per_iteration, kernel_size);
            window_changed = update_window_and_padding(win, input_access, weights_access);
        }
    }

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
} // namespace

NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
    : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0),
      _num_elems_written_per_iteration(0)
{
}

BorderSize NEDirectConvolutionLayerKernel::border_size() const
{
    return _border_size;
}

void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);

    _input       = input;
    _weights     = weights;
    _output      = output;
    _conv_info   = conv_info;
    _kernel_size = weights->info()->dimension(get_data_layout_dimension_index(weights->info()->data_layout(), DataLayoutDimension::WIDTH));

    const unsigned int conv_pad_left   = conv_info.pad_left();
    const unsigned int conv_pad_top    = conv_info.pad_top();
    const unsigned int conv_pad_right  = conv_info.pad_right();
    const unsigned int conv_pad_bottom = conv_info.pad_bottom();
    _border_size                       = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);

    // Get convolved dimensions
    TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input->info(), *weights->info(), conv_info);

    DataType data_type = input->info()->data_type();

    // Output auto-initialization if not yet initialized
    auto_init_if_empty(*output->info(), output_shape, 1, data_type);

    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), output->info(), conv_info));

    // Configure kernel window
    auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, _num_weight_elems_read_per_row,
                                                    _num_elems_read_per_iteration, _num_elems_written_per_iteration, _border_size);
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
    INEKernel::configure(win_config.second);
}
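
// Minimal usage sketch for this kernel in isolation (illustrative only; the tensor objects, the
// PadStrideInfo values and the split dimension below are assumptions, not part of this file).
// In practice the NEDirectConvolutionLayer runtime function configures and schedules this kernel
// together with the border-handling and output-stage kernels:
//
//     NEDirectConvolutionLayerKernel kernel;
//     ARM_COMPUTE_ERROR_THROW_ON(NEDirectConvolutionLayerKernel::validate(
//         src.info(), weights.info(), dst.info(), PadStrideInfo(1, 1, 1, 1)));
//     kernel.configure(&src, &weights, &dst, PadStrideInfo(1, 1, 1, 1));
//     NEScheduler::get().schedule(&kernel, Window::DimZ);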

Status NEDirectConvolutionLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
{
    unsigned int num_weight_elems_read_per_row   = 0;
    unsigned int num_elems_read_per_iteration    = 0;
    unsigned int num_elems_written_per_iteration = 0;
    BorderSize   border_size                     = {};
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, weights, output, conv_info));
    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
                                                              weights->clone().get(),
                                                              output->clone().get(),
                                                              conv_info,
                                                              num_weight_elems_read_per_row,
                                                              num_elems_read_per_iteration,
                                                              num_elems_written_per_iteration,
                                                              border_size)
                                .first);

    return Status{};
}

void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

    const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_weights->info()->data_layout(), DataLayoutDimension::WIDTH));

    if(_input->info()->data_layout() == DataLayout::NCHW)
    {
        switch(kernel_size)
        {
            case 1:
            {
                switch(_input->info()->data_type())
                {
                    case DataType::F32:
                        convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                        break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    case DataType::F16:
                        convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                        break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported");
                        break;
                }
                break;
            }
            case 3:
            {
                switch(_input->info()->data_type())
                {
                    case DataType::F32:
                        convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                        break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    case DataType::F16:
                        convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                        break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported");
                        break;
                }
                break;
            }
            case 5:
            {
                switch(_input->info()->data_type())
                {
                    case DataType::F32:
                        convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported");
                        break;
                }
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
                break;
            }
        }
    }
    else
    {
        const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_weights->info()->data_layout(), DataLayoutDimension::WIDTH));
        const int stride_x    = std::get<0>(_conv_info.stride());
        const int stride_y    = std::get<1>(_conv_info.stride());

        switch(_input->info()->data_type())
        {
            case DataType::F32:
            {
                if(kernel_size == 9 && stride_x == 1 && stride_y == 1)
                {
                    using vtype = wrapper::traits::neon_vector<float, 4>;
                    convolve_9x9_nhwc<vtype>(window, _num_elems_read_per_iteration, _input, _weights, _output, _conv_info);
                }
                else
                {
                    convolver_nhwc<float>::convolve(window, kernel_size, _num_elems_read_per_iteration, _input, _weights, _output, _conv_info);
                }
                break;
            }
            default:
                ARM_COMPUTE_ERROR("Data type not supported");
                break;
        }
    }
}