blob: bcf70b3ad8cf2d05a10f515576a3309cd17645f9 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michalis Spyroubcfd09a2019-05-01 13:03:59 +01002 * Copyright (c) 2017-2019 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"
Georgios Pinitas4074c992018-01-30 18:13:46 +000025#include "arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
27#include "arm_compute/core/AccessWindowStatic.h"
Anthony Barbiereaefd002018-07-20 17:49:35 +010028#include "arm_compute/core/CPP/Validate.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/Error.h"
30#include "arm_compute/core/Helpers.h"
31#include "arm_compute/core/IAccessWindow.h"
32#include "arm_compute/core/ITensor.h"
33#include "arm_compute/core/NEON/NEFixedPoint.h"
34#include "arm_compute/core/Types.h"
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +010035#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010036#include "arm_compute/core/Validate.h"
Giorgio Arenac0f54432018-03-16 14:02:34 +000037#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010038
Michalis Spyrou201c37c2018-10-25 17:25:54 +010039#include "arm_compute/core/NEON/wrapper/wrapper.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010040#include <algorithm>
41#include <arm_neon.h>
42
43using namespace arm_compute;
Michalis Spyrou7362f0d2017-10-18 17:58:22 +010044using namespace arm_compute::detail;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010045
46namespace
47{
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +000048#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Load eight half-precision values starting at @p in, picking one element every @p stridex.
 *
 * Specialisations are provided for stridex = 1, 2 and 3.
 *
 * @param[in] in Pointer to the first element to read.
 *
 * @return Vector holding the selected elements.
 */
template <unsigned int stridex>
float16x8_t internal_vld1q(const float16_t *in);

/** Stride 1: contiguous load. */
template <>
float16x8_t internal_vld1q<1>(const float16_t *in)
{
    const float16x8_t loaded = vld1q_f16(in);
    return loaded;
}

/** Stride 2: de-interleaving load, only the first de-interleaved vector is kept. */
template <>
float16x8_t internal_vld1q<2>(const float16_t *in)
{
    return vld2q_f16(in).val[0];
}

/** Stride 3: de-interleaving load, only the first de-interleaved vector is kept. */
template <>
float16x8_t internal_vld1q<3>(const float16_t *in)
{
    return vld3q_f16(in).val[0];
}
71
/** Broadcast the half-precision scalar @p v to every lane of a vector. */
inline float16x8_t internal_vdupq_n(float16_t v)
{
    const float16x8_t broadcast = vdupq_n_f16(v);
    return broadcast;
}
76
/** Store the eight half-precision lanes of @p v contiguously at @p p. */
inline void internal_vst1q(float16_t *p, const float16x8_t &v)
{
    float16_t *const destination = p;
    vst1q_f16(destination, v);
}
81
/** Lane-wise product of two half-precision vectors.
 *
 * Marked inline like the other internal_* helpers of this file so that an unused
 * instantiation path does not leave an unreferenced non-inline function behind.
 *
 * @param[in] x First operand.
 * @param[in] y Second operand.
 *
 * @return x * y computed per lane.
 */
inline float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y)
{
    return vmulq_f16(x, y);
}
86
/** Lane-wise multiply-accumulate on half-precision vectors: x + (y * z).
 *
 * The product is formed with a separate multiply and then added to @p x.
 */
inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z)
{
    const float16x8_t product = vmulq_f16(y, z);
    return vaddq_f16(x, product);
}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +000091#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tello0d176142017-07-06 16:43:14 +010092
/** Load four single-precision values starting at @p in, picking one element every @p stridex.
 *
 * Specialisations are provided for stridex = 1, 2 and 3.
 *
 * @param[in] in Pointer to the first element to read.
 *
 * @return Vector holding the selected elements.
 */
template <unsigned int stridex>
float32x4_t internal_vld1q(const float *in);

/** Stride 1: contiguous load. */
template <>
float32x4_t internal_vld1q<1>(const float *in)
{
    const float32x4_t loaded = vld1q_f32(in);
    return loaded;
}

/** Stride 2: de-interleaving load, only the first de-interleaved vector is kept. */
template <>
float32x4_t internal_vld1q<2>(const float *in)
{
    return vld2q_f32(in).val[0];
}

/** Stride 3: de-interleaving load, only the first de-interleaved vector is kept. */
template <>
float32x4_t internal_vld1q<3>(const float *in)
{
    return vld3q_f32(in).val[0];
}
115
/** Broadcast the single-precision scalar @p v to every lane of a vector. */
inline float32x4_t internal_vdupq_n(float v)
{
    const float32x4_t broadcast = vdupq_n_f32(v);
    return broadcast;
}
120
/** Store the four single-precision lanes of @p v contiguously at @p p. */
inline void internal_vst1q(float *p, const float32x4_t &v)
{
    float *const destination = p;
    vst1q_f32(destination, v);
}
125
/** Lane-wise product of two single-precision vectors.
 *
 * Marked inline like the other internal_* helpers of this file so that an unused
 * instantiation path does not leave an unreferenced non-inline function behind.
 *
 * @param[in] x First operand.
 * @param[in] y Second operand.
 *
 * @return x * y computed per lane.
 */
inline float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y)
{
    return vmulq_f32(x, y);
}
130
/** Lane-wise multiply-accumulate on single-precision vectors: x + (y * z). */
inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z)
{
    const float32x4_t accumulated = vmlaq_f32(x, y, z);
    return accumulated;
}
135
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000136constexpr int small_tensor_size_optim = 8;
137inline bool run_optim_small_tensor_info(const ITensorInfo *t)
138{
139 return t->dimension(Window::DimX) <= small_tensor_size_optim && t->dimension(Window::DimY) <= small_tensor_size_optim;
140}
141
Pablo Telloc09314a2017-09-21 13:59:14 +0100142inline bool run_optim_small_tensor(const ITensor *t)
143{
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000144 return run_optim_small_tensor_info(t->info());
Pablo Telloc09314a2017-09-21 13:59:14 +0100145}
146
// Optimized convolver for 1x1 kernels used only where input width and height are both <= 8
// For big Z as in Input=7x7x832, this implementation is faster than the general code because it doesn't need to
// store intermediate results in memory. Temporary results are stored in NEON registers directly and then written to the output buffer.
template <unsigned int stridex>
class convolver_w1x1_i8x8_f32
{
public:
    /** Run the 1x1 F32 convolution on a small input.
     *
     * For every output plane, each output row is accumulated over the whole kernel depth in two
     * float32x4_t registers and written out once, as 8 consecutive floats.
     *
     * @param[in]  window    Execution window; its Z range selects the output planes to compute.
     * @param[in]  input     Source tensor. Width and height must not exceed small_tensor_size_optim.
     * @param[in]  weights   Weights tensor. Dimension Z is the kernel depth, dimension 3 is indexed by output plane.
     * @param[out] output    Destination tensor. Its height must not exceed small_tensor_size_optim.
     * @param[in]  conv_info Padding and stride information.
     */
    static void convolve(const Window &window, const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimX) > small_tensor_size_optim);
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimY) > small_tensor_size_optim);

        const int          input_stride_x  = input->info()->strides_in_bytes().x();
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_y = output->info()->strides_in_bytes().y();
        const int          output_stride_z = output->info()->strides_in_bytes().z();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int          output_h        = output->info()->dimension(1);
        const int          range_z         = window.z().end() - window.z().start();
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();

        // The per-row accumulators below hold small_tensor_size_optim entries and are indexed by output row:
        // a taller output (e.g. caused by padding) would index them out of bounds.
        ARM_COMPUTE_ERROR_ON(output_h > small_tensor_size_optim);

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window   window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            // Step back by the left/top padding so that row 0 / column 0 address the padded origin
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();

            const float32x4_t vzero = vdupq_n_f32(0.f);

            for(int oz = 0; oz < range_z; ++oz)
            {
                // One pair of accumulators per output row, cleared once per output plane
                std::array<float32x4_t, small_tensor_size_optim> accum0;
                std::array<float32x4_t, small_tensor_size_optim> accum1;
                accum0.fill(vzero);
                accum1.fill(vzero);

                auto p_out_base = out_ptr + oz * output_stride_z;

                // Accumulate the contribution of every input plane without touching the output buffer
                for(int p = 0; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const float *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk0   = internal_vdupq_n(*k_val);
                    for(int ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto      in_val    = reinterpret_cast<const float *>(input_ptr + p * input_stride_z + offset_xy);
                        auto      v_in0     = internal_vld1q<stridex>(in_val);
                        auto      v_in1     = internal_vld1q<stridex>(in_val + 4);
                        accum0[oh]          = vmlaq_f32(accum0[oh], vk0, v_in0);
                        accum1[oh]          = vmlaq_f32(accum1[oh], vk0, v_in1);
                    }
                }

                // Write the finished rows of this output plane
                for(int oh = 0; oh < output_h; ++oh)
                {
                    auto p_out = reinterpret_cast<float *>(p_out_base + oh * output_stride_y);
                    vst1q_f32(p_out, accum0[oh]);
                    vst1q_f32(p_out + 4, accum1[oh]);
                }
            }
        },
        in, out);
    }
};
231
/** Generic 1x1 convolver.
 *
 * The first input plane initialises the output buffer, every following plane is accumulated on top of it.
 *
 * @tparam T1      Element type read from input and weights.
 * @tparam T2      Element type written to the output.
 * @tparam stridex Horizontal convolution stride.
 */
template <typename T1, typename T2, unsigned int stridex>
class convolver_1x1
{
public:
    /** Run the 1x1 convolution.
     *
     * @param[in]  window                          Execution window; its Z range selects the output planes to compute.
     * @param[in]  num_elems_read_per_iteration    Input elements consumed by one vector step.
     * @param[in]  num_elems_written_per_iteration Output elements produced by one vector step.
     * @param[in]  input                           Source tensor.
     * @param[in]  weights                         Weights tensor.
     * @param[out] output                          Destination tensor.
     * @param[in]  conv_info                       Padding and stride information.
     */
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        const auto src_info = input->info();
        const auto dst_info = output->info();
        const auto wei_info = weights->info();

        const int          input_stride_x  = src_info->strides_in_bytes().x();
        const int          input_stride_y  = src_info->strides_in_bytes().y();
        const int          input_stride_z  = src_info->strides_in_bytes().z();
        const int          output_stride_y = dst_info->strides_in_bytes().y();
        const int          output_stride_z = dst_info->strides_in_bytes().z();
        const int          kernel_stride_z = wei_info->strides_in_bytes().z();
        const int          kernel_stride_w = wei_info->strides_in_bytes()[3];
        const int          output_w        = dst_info->dimension(0);
        const int          output_h        = dst_info->dimension(1);
        const int          range_z         = window.z().end() - window.z().start();
        const int          kernel_depth    = wei_info->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();

        // Output iterator window: a single step covers the whole XY plane and the whole Z range of this call
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // Input iterator window: only the dimensions above Z are iterated, X/Y/Z are addressed by hand below
        Window window_in = window;
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            // Refer to convolver_3x3<1> for a detailed explanation of the algorithm.
            const uint8_t *src_base = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t       *dst_base = out.ptr();

            for(int oz = 0; oz < range_z; ++oz)
            {
                uint8_t *const dst_plane = dst_base + oz * output_stride_z;

                // Step 1: the first input plane overwrites the output
                {
                    const auto weight_ptr = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vweight    = internal_vdupq_n(*weight_ptr);

                    int ih = 0;
                    for(int oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int row_offset = ih * input_stride_y;
                        auto      src        = reinterpret_cast<const T1 *>(src_base + (0 * input_stride_z + row_offset));
                        auto      dst        = reinterpret_cast<T2 *>(dst_plane + oh * output_stride_y);

                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration)
                        {
                            internal_vst1q(dst, internal_vmull(vweight, internal_vld1q<stridex>(src)));
                            src += num_elems_read_per_iteration;
                            dst += num_elems_written_per_iteration;
                        }
                    }
                }

                // Step 2: every remaining input plane is accumulated onto the values already stored
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto weight_ptr = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vweight    = internal_vdupq_n(*weight_ptr);

                    int ih = 0;
                    for(int oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int row_offset = ih * input_stride_y;
                        auto      src        = reinterpret_cast<const T1 *>(src_base + p * input_stride_z + row_offset);
                        auto      dst        = reinterpret_cast<T2 *>(dst_plane + oh * output_stride_y);

                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration)
                        {
                            internal_vst1q(dst, internal_vmlal(internal_vld1q<1>(dst), vweight, internal_vld1q<stridex>(src)));
                            src += num_elems_read_per_iteration;
                            dst += num_elems_written_per_iteration;
                        }
                    }
                }
            }
        },
        in, out);
    }
};
323
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000324#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tello0d176142017-07-06 16:43:14 +0100325
/** Add convolution results to the half-precision values already stored in @p buffer.
 *
 * The number of elements updated depends on @p stridex: 16 for stride 1, 8 for stride 2, 4 for stride 3.
 *
 * @param[in,out] buffer Destination values, read, incremented and written back.
 * @param[in]     values Results to add.
 */
template <unsigned int stridex>
void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

/** Stride 1: both vectors of @p values are accumulated (16 elements). */
template <>
void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    const float16x8_t first_sum = vaddq_f16(vld1q_f16(buffer), values.val[0]);
    vst1q_f16(buffer, first_sum);

    float16_t *const  second     = buffer + 8;
    const float16x8_t second_sum = vaddq_f16(vld1q_f16(second), values.val[1]);
    vst1q_f16(second, second_sum);
}

/** Stride 2: only the first vector of @p values is accumulated (8 elements). */
template <>
void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    const float16x8_t sum = vaddq_f16(vld1q_f16(buffer), values.val[0]);
    vst1q_f16(buffer, sum);
}

/** Stride 3: only the low half of the first vector of @p values is accumulated (4 elements). */
template <>
void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    const float16x4_t low_half = vget_low_f16(values.val[0]);
    const float16x4_t sum      = vadd_f16(vld1_f16(buffer), low_half);
    vst1_f16(buffer, sum);
}
347
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000348#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tello0d176142017-07-06 16:43:14 +0100349
/** Convolve five input rows with a 5x5 kernel.
 *
 * Specialisations are provided for stridex = 1, 2 and 3.
 *
 * @param[in] in_0 Pointer to the first input row.
 * @param[in] in_1 Pointer to the second input row.
 * @param[in] in_2 Pointer to the third input row.
 * @param[in] in_3 Pointer to the fourth input row.
 * @param[in] in_4 Pointer to the fifth input row.
 * @param[in] m0   Pointer to the first kernel row.
 * @param[in] m1   Pointer to the second kernel row.
 * @param[in] m2   Pointer to the third kernel row.
 * @param[in] m3   Pointer to the fourth kernel row.
 * @param[in] m4   Pointer to the fifth kernel row.
 *
 * @return Pair of vectors holding the convolution results.
 */
template <unsigned int stridex>
float32x4x2_t convolve_5x5(const float *in_0,
                           const float *in_1,
                           const float *in_2,
                           const float *in_3,
                           const float *in_4,
                           const float *m0,
                           const float *m1,
                           const float *m2,
                           const float *m3,
                           const float *m4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100353
/** Broadcast three kernel coefficients, one per returned vector.
 *
 * @param[in] m0 Coefficient replicated in val[0].
 * @param[in] m1 Coefficient replicated in val[1].
 * @param[in] m2 Coefficient replicated in val[2].
 */
inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
{
    float32x4x3_t coeffs;
    coeffs.val[0] = vld1q_dup_f32(m0);
    coeffs.val[1] = vld1q_dup_f32(m1);
    coeffs.val[2] = vld1q_dup_f32(m2);
    return coeffs;
}
366
/** Broadcast two kernel coefficients, one per returned vector.
 *
 * @param[in] m3 Coefficient replicated in val[0].
 * @param[in] m4 Coefficient replicated in val[1].
 */
inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4)
{
    float32x4x2_t coeffs;
    coeffs.val[0] = vld1q_dup_f32(m3);
    coeffs.val[1] = vld1q_dup_f32(m4);
    return coeffs;
}
378
/** Load twelve consecutive floats starting at @p in as three vectors of four. */
inline float32x4x3_t load_input(const float *const in)
{
    float32x4x3_t row;
    row.val[0] = vld1q_f32(in);
    row.val[1] = vld1q_f32(in + 4);
    row.val[2] = vld1q_f32(in + 8);
    return row;
}
391
/** Stride-1 5x5 convolution: produces eight consecutive results from five input rows.
 *
 * val[0] of the result holds outputs 0..3 and val[1] outputs 4..7. Each accumulator receives its
 * 25 products row by row, coefficient by coefficient; the very first product is a plain multiply.
 */
template <>
inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
    // Adds the contribution of coefficients 1..4 of one kernel row to both accumulators
    const auto add_shifted_taps = [](float32x4x2_t &acc, const float32x4x3_t &row, const float32x4x3_t &k_hi, const float32x4x2_t &k_lo)
    {
        acc.val[0] = vmlaq_f32(acc.val[0], vextq_f32(row.val[0], row.val[1], 1), k_hi.val[1]);
        acc.val[0] = vmlaq_f32(acc.val[0], vextq_f32(row.val[0], row.val[1], 2), k_hi.val[2]);
        acc.val[0] = vmlaq_f32(acc.val[0], vextq_f32(row.val[0], row.val[1], 3), k_lo.val[0]);
        acc.val[0] = vmlaq_f32(acc.val[0], row.val[1], k_lo.val[1]);

        acc.val[1] = vmlaq_f32(acc.val[1], vextq_f32(row.val[1], row.val[2], 1), k_hi.val[1]);
        acc.val[1] = vmlaq_f32(acc.val[1], vextq_f32(row.val[1], row.val[2], 2), k_hi.val[2]);
        acc.val[1] = vmlaq_f32(acc.val[1], vextq_f32(row.val[1], row.val[2], 3), k_lo.val[0]);
        acc.val[1] = vmlaq_f32(acc.val[1], row.val[2], k_lo.val[1]);
    };

    // Adds the contribution of a complete kernel row (coefficients 0..4) to both accumulators
    const auto add_full_row = [&add_shifted_taps](float32x4x2_t &acc, const float *in_row, const float *k_row)
    {
        const float32x4x3_t row  = load_input(in_row);
        const float32x4x3_t k_hi = load_matrix_hi(k_row, 1 + k_row, 2 + k_row);
        const float32x4x2_t k_lo = load_matrix_lo(3 + k_row, 4 + k_row);

        acc.val[0] = vmlaq_f32(acc.val[0], row.val[0], k_hi.val[0]);
        acc.val[1] = vmlaq_f32(acc.val[1], row.val[1], k_hi.val[0]);
        add_shifted_taps(acc, row, k_hi, k_lo);
    };

    // First row: coefficient 0 initialises the accumulators with a multiply instead of a multiply-accumulate
    const float32x4x3_t first_row  = load_input(in_0);
    const float32x4x3_t first_k_hi = load_matrix_hi(m0, 1 + m0, 2 + m0);
    const float32x4x2_t first_k_lo = load_matrix_lo(3 + m0, 4 + m0);

    float32x4x2_t out;
    out.val[0] = vmulq_f32(first_row.val[0], first_k_hi.val[0]);
    out.val[1] = vmulq_f32(first_row.val[1], first_k_hi.val[0]);
    add_shifted_taps(out, first_row, first_k_hi, first_k_lo);

    // Remaining rows
    add_full_row(out, in_1, m1);
    add_full_row(out, in_2, m2);
    add_full_row(out, in_3, m3);
    add_full_row(out, in_4, m4);

    return out;
}
480
/** Stride-2 5x5 convolution.
 *
 * Computes the stride-1 results and packs results 0, 2, 4 and 6 into val[0]; val[1] is returned as computed.
 */
template <>
inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);

    // Pick the even-indexed stride-1 results before any lane is overwritten
    const float result_2 = vgetq_lane_f32(out.val[0], 2);
    const float result_4 = vgetq_lane_f32(out.val[1], 0);
    const float result_6 = vgetq_lane_f32(out.val[1], 2);

    float32x4_t packed = out.val[0];
    packed             = vsetq_lane_f32(result_2, packed, 1);
    packed             = vsetq_lane_f32(result_4, packed, 2);
    packed             = vsetq_lane_f32(result_6, packed, 3);
    out.val[0]         = packed;

    return out;
}
491
/** Stride-3 5x5 convolution.
 *
 * Computes the stride-1 results and moves result 3 next to result 0 in val[0]; all other lanes are returned as computed.
 */
template <>
inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);

    const float result_3 = vgetq_lane_f32(out.val[0], 3);
    out.val[0]           = vsetq_lane_f32(result_3, out.val[0], 1);

    return out;
}
500
/** Add convolution results to the single-precision values already stored in @p buffer.
 *
 * The number of elements updated depends on @p stridex: 8 for stride 1, 4 for stride 2, 2 for stride 3.
 *
 * @param[in,out] buffer Destination values, read, incremented and written back.
 * @param[in]     values Results to add.
 */
template <unsigned int stridex>
void accumulate_results(float *buffer, const float32x4x2_t &values);

/** Stride 1: both vectors of @p values are accumulated (8 elements). */
template <>
void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    const float32x4_t first_sum = vaddq_f32(vld1q_f32(buffer), values.val[0]);
    vst1q_f32(buffer, first_sum);

    float *const      second     = buffer + 4;
    const float32x4_t second_sum = vaddq_f32(vld1q_f32(second), values.val[1]);
    vst1q_f32(second, second_sum);
}

/** Stride 2: only the first vector of @p values is accumulated (4 elements). */
template <>
void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    const float32x4_t sum = vaddq_f32(vld1q_f32(buffer), values.val[0]);
    vst1q_f32(buffer, sum);
}

/** Stride 3: only the low half of the first vector of @p values is accumulated (2 elements). */
template <>
void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    const float32x2_t low_half = vget_low_f32(values.val[0]);
    const float32x2_t sum      = vadd_f32(vld1_f32(buffer), low_half);
    vst1_f32(buffer, sum);
}
522
/** Direct convolver working on tensors whose dimension 0 is traversed contiguously by the inner loops.
 *
 * For every kernel (dimension 3 of the weights) and every point of the execution window, one output
 * value is produced by a dot product over kernel_size x kernel_size rows of dimension-0 elements.
 *
 * @tparam T1 Element type of input, weights and output.
 */
template <typename T1>
class convolver_nhwc
{
public:
    /** Run the convolution.
     *
     * @param[in]  window                       Execution window of the output.
     * @param[in]  kernel_size                  Extent of the kernel along dimensions Y and Z of the weights.
     * @param[in]  num_elems_read_per_iteration Elements consumed by one vector step along dimension 0.
     * @param[in]  input                        Source tensor.
     * @param[in]  weights                      Weights tensor.
     * @param[out] output                       Destination tensor.
     * @param[in]  conv_info                    Padding and stride information.
     */
    static void convolve(const Window &window, int kernel_size, unsigned int num_elems_read_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        const auto src_info = input->info();
        const auto wei_info = weights->info();

        const int          src_dim0      = src_info->dimension(0);
        const int          src_dim2      = src_info->dimension(2);
        const int          src_stride_x  = src_info->strides_in_bytes().x();
        const int          src_stride_y  = src_info->strides_in_bytes().y();
        const int          src_stride_z  = src_info->strides_in_bytes().z();
        const int          dst_stride_x  = output->info()->strides_in_bytes().x();
        const int          wei_stride_x  = wei_info->strides_in_bytes().x();
        const int          wei_stride_y  = wei_info->strides_in_bytes().y();
        const int          wei_stride_z  = wei_info->strides_in_bytes().z();
        const int          conv_pad_top  = conv_info.pad_top();
        const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const T1           zero          = 0;

        // Input iterator: X/Y/Z are addressed by hand, only the higher dimensions are iterated
        Window window_in = window;
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        // Output iterator: dimension X is collapsed, the kernel index selects the element inside it
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, 1, 1));

        // Weights iterator: one step per kernel (dimension 3)
        Window window_k = calculate_max_window(*weights->info(), Steps());
        window_k.set(Window::DimX, Window::Dimension(0, 1, 1));
        window_k.set(Window::DimY, Window::Dimension(0, 1, 1));
        window_k.set(Window::DimZ, Window::Dimension(0, 1, 1));
        window_k.set(3, Window::Dimension(0, weights->info()->dimension(3), 1));

        Iterator in(input, window_in);
        Iterator out(output, window_out);
        Iterator k(weights, window_k);

        execute_window_loop(window_k, [&](const Coordinates & id_k)
        {
            execute_window_loop(window_out, [&](const Coordinates & id)
            {
                // Coordinates of the first input element covered by the kernel (may lie inside the padding)
                const auto first_y = static_cast<int>(id.y() * conv_stride_x - conv_info.pad_left());
                const auto first_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top);

                const uint8_t *src_plane = in.ptr() + first_y * src_stride_y + first_z * src_stride_z;
                const uint8_t *wei_plane = k.ptr();
                uint8_t       *dst       = out.ptr() + id_k[3] * dst_stride_x;

                T1 acc = 0;

                for(int z = 0; z < kernel_size; ++z, src_plane += src_stride_z, wei_plane += wei_stride_z)
                {
                    const int cur_z = id.z() * conv_stride_y + z - conv_pad_top;

                    // Planes falling into the top/bottom padding contribute nothing
                    if(cur_z < 0 || cur_z >= src_dim2)
                    {
                        continue;
                    }

                    const uint8_t *src_row = src_plane;
                    const uint8_t *wei_row = wei_plane;

                    for(int y = 0; y < kernel_size; ++y, src_row += src_stride_y, wei_row += wei_stride_y)
                    {
                        auto vec_acc = internal_vdupq_n(zero);

                        const int vec_limit = src_dim0 - num_elems_read_per_iteration;
                        int       x         = 0;

                        // Vectorised part of the dot product
                        for(; x < vec_limit; x += num_elems_read_per_iteration)
                        {
                            const auto src_vec = internal_vld1q<1>(reinterpret_cast<const T1 *>(src_row + x * src_stride_x));
                            const auto wei_vec = internal_vld1q<1>(reinterpret_cast<const T1 *>(wei_row + x * wei_stride_x));

                            vec_acc = internal_vmlal(vec_acc, src_vec, wei_vec);
                        }

                        // Horizontal reduction of the vector accumulator.
                        // NOTE(review): two pairwise adds fully reduce a 4-lane vector; confirm the behaviour before
                        // instantiating this class with an element type whose vectors have more lanes.
                        auto pairwise = wrapper::vpadd(wrapper::vgethigh(vec_acc), wrapper::vgetlow(vec_acc));
                        pairwise      = wrapper::vpadd(pairwise, pairwise);
                        acc += wrapper::vgetlane(pairwise, 0);

                        // Scalar tail
                        for(; x < src_dim0; ++x)
                        {
                            const auto src_val = *(reinterpret_cast<const T1 *>(src_row + x * src_stride_x));
                            const auto wei_val = *(reinterpret_cast<const T1 *>(wei_row + x * wei_stride_x));

                            acc += src_val * wei_val;
                        }
                    }
                }

                *(reinterpret_cast<T1 *>(dst)) = acc;
            },
            in, out);
        },
        k);
    }
};
633
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100634template <typename T1, typename T2, unsigned int stridex>
635class convolver_3x3
636{
637public:
638 static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
639 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
640 {
641 ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100642 const int input_stride_x = input->info()->strides_in_bytes().x();
643 const int input_stride_y = input->info()->strides_in_bytes().y();
644 const int input_stride_z = input->info()->strides_in_bytes().z();
645 const int output_stride_y = output->info()->strides_in_bytes().y();
646 const int output_stride_z = output->info()->strides_in_bytes().z();
647 const int kernel_stride_x = weights->info()->strides_in_bytes().x();
648 const int kernel_stride_y = weights->info()->strides_in_bytes().y();
649 const int kernel_stride_z = weights->info()->strides_in_bytes().z();
650 const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
651 const int output_w = output->info()->dimension(0);
652 const int output_h = output->info()->dimension(1);
653 const int num_planes_z = window.z().end() - window.z().start();
654 const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
655 const int kernel_depth = weights->info()->dimension(Window::DimZ);
656 const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
657 const unsigned int conv_pad_left = conv_info.pad_left();
658 const unsigned int conv_pad_top = conv_info.pad_top();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100659
660 // setup output window for the iterator
661 Window window_out = window;
662 window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
663 window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
664 window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));
665
666 // setup input window for the iterator
667 Window window_in = window;
668 // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
669 window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
670 window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
671 window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
672
673 Window window_k = calculate_max_window(*weights->info(), Steps(1u));
674
675 Iterator out(output, window_out);
676 Iterator in(input, window_in);
677 Iterator k(weights, window_k);
678
679 const uint8_t *k_ptr = k.ptr();
680
681 execute_window_loop(window_out, [&](const Coordinates & id)
682 {
Georgios Pinitas15997872018-02-19 13:58:22 +0000683 const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100684 uint8_t *out_ptr = out.ptr();
685 int ih = 0;
686 int oh = 0;
687 /*
688 Each thread executing this kernel computes one or more output's volume planes.
689
690 Let's say the 3rd dimension of the output volume is 32, the first thread will compute the output for Z = [0,7], the second thread will compute the output for Z = [8,15],
691 the third thread [16,24] and the fourth thread [25,31].
692
693 The algorithm outer loop iterates over Z, P, Y, X where P is the depth/3rd dimension of each kernel. This order is not arbitrary, the main benefit of this
Anthony Barbiere5007472017-10-27 15:01:44 +0100694 is that we setup the neon registers containing the kernel's values only once and then compute each XY using the preloaded registers as opposed as doing this for every XY value.
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100695
696 The algorithm does not require allocating any additional memory amd computes the results directly in-place in two stages:
697 1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
698 2) Convolve the remaining planes and accumulate the results in the output's plane which has been initialized in step 1.
699 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100700 for(int oz = 0; oz < num_planes_z; ++oz)
701 {
Pablo Tello0d176142017-07-06 16:43:14 +0100702 const int zoffset = id.z() + oz;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100703 uint8_t *p_out_base = out_ptr + oz * output_stride_z;
704 // Step 1
705 {
Pablo Tello0d176142017-07-06 16:43:14 +0100706 const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
707 const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
708 const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100709 const auto vk_r0 = load_matrix_row(ptr_k_r0);
710 const auto vk_r1 = load_matrix_row(ptr_k_r1);
711 const auto vk_r2 = load_matrix_row(ptr_k_r2);
712 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
713 {
714 auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
715 auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
716 auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
717 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
718 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
719 in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
720 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100721 auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100722 store_results<stridex>(p_out, vres);
723 }
724 }
725 }
726 // Step 2
727 for(int p = 1; p < kernel_depth; ++p)
728 {
Pablo Tello06da39d2017-08-10 15:10:40 +0100729 const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w;
730 const uint8_t *input_base = input_ptr + p * input_stride_z;
731 const auto ptr_k_r0 = reinterpret_cast<const T1 *>(ptr_k_base);
732 const auto ptr_k_r1 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y);
733 const auto ptr_k_r2 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y * 2);
734 const auto vk_r0 = load_matrix_row(ptr_k_r0);
735 const auto vk_r1 = load_matrix_row(ptr_k_r1);
736 const auto vk_r2 = load_matrix_row(ptr_k_r2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100737 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
738 {
Pablo Tello06da39d2017-08-10 15:10:40 +0100739 auto in_top = reinterpret_cast<const T1 *>(input_base + (ih + 0) * input_stride_y);
740 auto in_mid = reinterpret_cast<const T1 *>(input_base + (ih + 1) * input_stride_y);
741 auto in_low = reinterpret_cast<const T1 *>(input_base + (ih + 2) * input_stride_y);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100742 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
743 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
744 in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
745 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100746 auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100747 accumulate_results<stridex>(p_out, vres);
748 }
749 }
750 }
751 }
752 },
753 in, out);
754 }
755};
756
Pablo Tello06da39d2017-08-10 15:10:40 +0100757template <typename T1, typename T2, unsigned int stridex>
758class convolver_5x5
759{
760public:
761 static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
762 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
763 {
764 ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100765 const int input_stride_x = input->info()->strides_in_bytes().x();
766 const int input_stride_y = input->info()->strides_in_bytes().y();
767 const int input_stride_z = input->info()->strides_in_bytes().z();
768 const int output_stride_y = output->info()->strides_in_bytes().y();
769 const int output_stride_z = output->info()->strides_in_bytes().z();
770 const int kernel_stride_x = weights->info()->strides_in_bytes().x();
771 const int kernel_stride_y = weights->info()->strides_in_bytes().y();
772 const int kernel_stride_z = weights->info()->strides_in_bytes().z();
773 const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
774 const int output_w = output->info()->dimension(0);
775 const int output_h = output->info()->dimension(1);
776 const int num_planes_z = window.z().end() - window.z().start();
777 const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
778 const int kernel_depth = weights->info()->dimension(Window::DimZ);
779 const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
780 const unsigned int conv_pad_left = conv_info.pad_left();
781 const unsigned int conv_pad_top = conv_info.pad_top();
Pablo Tello06da39d2017-08-10 15:10:40 +0100782
783 // setup output window for the iterator
784 Window window_out = window;
785 window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
786 window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
787 window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));
788
789 // setup input window for the iterator
790 Window window_in = window;
791 // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
792 window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
793 window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
794 window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
795
796 Window window_k = calculate_max_window(*weights->info(), Steps(1u));
797
798 Iterator out(output, window_out);
799 Iterator in(input, window_in);
800 Iterator k(weights, window_k);
801
802 const uint8_t *k_ptr = k.ptr();
803
804 execute_window_loop(window_out, [&](const Coordinates & id)
805 {
Georgios Pinitas15997872018-02-19 13:58:22 +0000806 const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
Pablo Tello06da39d2017-08-10 15:10:40 +0100807 uint8_t *out_ptr = out.ptr();
808 int ih = 0;
809 int oh = 0;
810 for(int oz = 0; oz < num_planes_z; ++oz)
811 {
812 const int zoffset = id.z() + oz;
813 uint8_t *p_out_base = out_ptr + oz * output_stride_z;
814 // Step 1
815 {
816 const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
817 const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
818 const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
819 const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
820 const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
821 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
822 {
823 auto in_0 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
824 auto in_1 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
825 auto in_2 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
826 auto in_3 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
827 auto in_4 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
828 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
829 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
830 in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
831 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100832 auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100833 store_results<stridex>(p_out, vres);
834 }
835 }
836 }
837 // Step 2
838 for(int p = 1; p < kernel_depth; ++p)
839 {
840 const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
841 const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
842 const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
843 const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
844 const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
845
846 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
847 {
848 auto in_0 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
849 auto in_1 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
850 auto in_2 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
851 auto in_3 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
852 auto in_4 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
853 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
854 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
855 in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
856 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100857 auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100858 accumulate_results<stridex>(p_out, vres);
859 }
860 }
861 }
862 }
863 },
864 in, out);
865 }
866};
867
Gian Marco Iodice95f93612019-06-13 15:58:32 +0100868inline void convolve_row1x9_nhwc(const float *row_ptr, const float *weights_ptr, size_t src_stride_y, size_t weights_stride_y,
869 float32x4_t &acc0, float32x4_t &acc1, float32x4_t &acc2, float32x4_t &acc3)
870{
871 // Load 4 channels for each of the 12 inputs values along the same X spatial dimension
872 const float32x4_t src0 = wrapper::vloadq(row_ptr);
873 const float32x4_t src1 = wrapper::vloadq(row_ptr + 1 * src_stride_y);
874 const float32x4_t src2 = wrapper::vloadq(row_ptr + 2 * src_stride_y);
875 const float32x4_t src3 = wrapper::vloadq(row_ptr + 3 * src_stride_y);
876 const float32x4_t src4 = wrapper::vloadq(row_ptr + 4 * src_stride_y);
877 const float32x4_t src5 = wrapper::vloadq(row_ptr + 5 * src_stride_y);
878 const float32x4_t src6 = wrapper::vloadq(row_ptr + 6 * src_stride_y);
879 const float32x4_t src7 = wrapper::vloadq(row_ptr + 7 * src_stride_y);
880 const float32x4_t src8 = wrapper::vloadq(row_ptr + 8 * src_stride_y);
881 const float32x4_t src9 = wrapper::vloadq(row_ptr + 9 * src_stride_y);
882 const float32x4_t src10 = wrapper::vloadq(row_ptr + 10 * src_stride_y);
883 const float32x4_t src11 = wrapper::vloadq(row_ptr + 11 * src_stride_y);
884
885 // Load 4 channels for each of the 9 weights values along the same X spatial dimension
886 const float32x4_t w0 = wrapper::vloadq(weights_ptr);
887 const float32x4_t w1 = wrapper::vloadq(weights_ptr + 1 * weights_stride_y);
888 const float32x4_t w2 = wrapper::vloadq(weights_ptr + 2 * weights_stride_y);
889 const float32x4_t w3 = wrapper::vloadq(weights_ptr + 3 * weights_stride_y);
890 const float32x4_t w4 = wrapper::vloadq(weights_ptr + 4 * weights_stride_y);
891 const float32x4_t w5 = wrapper::vloadq(weights_ptr + 5 * weights_stride_y);
892 const float32x4_t w6 = wrapper::vloadq(weights_ptr + 6 * weights_stride_y);
893 const float32x4_t w7 = wrapper::vloadq(weights_ptr + 7 * weights_stride_y);
894 const float32x4_t w8 = wrapper::vloadq(weights_ptr + 8 * weights_stride_y);
895
896 // Store 4 channels for each of the 4 output values along the same X spatial dimension
897 acc0 = wrapper::vmla(acc0, w0, src0);
898 acc0 = wrapper::vmla(acc0, w1, src1);
899 acc0 = wrapper::vmla(acc0, w2, src2);
900 acc0 = wrapper::vmla(acc0, w3, src3);
901 acc0 = wrapper::vmla(acc0, w4, src4);
902 acc0 = wrapper::vmla(acc0, w5, src5);
903 acc0 = wrapper::vmla(acc0, w6, src6);
904 acc0 = wrapper::vmla(acc0, w7, src7);
905 acc0 = wrapper::vmla(acc0, w8, src8);
906
907 acc1 = wrapper::vmla(acc1, w0, src1);
908 acc1 = wrapper::vmla(acc1, w1, src2);
909 acc1 = wrapper::vmla(acc1, w2, src3);
910 acc1 = wrapper::vmla(acc1, w3, src4);
911 acc1 = wrapper::vmla(acc1, w4, src5);
912 acc1 = wrapper::vmla(acc1, w5, src6);
913 acc1 = wrapper::vmla(acc1, w6, src7);
914 acc1 = wrapper::vmla(acc1, w7, src8);
915 acc1 = wrapper::vmla(acc1, w8, src9);
916
917 acc2 = wrapper::vmla(acc2, w0, src2);
918 acc2 = wrapper::vmla(acc2, w1, src3);
919 acc2 = wrapper::vmla(acc2, w2, src4);
920 acc2 = wrapper::vmla(acc2, w3, src5);
921 acc2 = wrapper::vmla(acc2, w4, src6);
922 acc2 = wrapper::vmla(acc2, w5, src7);
923 acc2 = wrapper::vmla(acc2, w6, src8);
924 acc2 = wrapper::vmla(acc2, w7, src9);
925 acc2 = wrapper::vmla(acc2, w8, src10);
926
927 acc3 = wrapper::vmla(acc3, w0, src3);
928 acc3 = wrapper::vmla(acc3, w1, src4);
929 acc3 = wrapper::vmla(acc3, w2, src5);
930 acc3 = wrapper::vmla(acc3, w3, src6);
931 acc3 = wrapper::vmla(acc3, w4, src7);
932 acc3 = wrapper::vmla(acc3, w5, src8);
933 acc3 = wrapper::vmla(acc3, w6, src9);
934 acc3 = wrapper::vmla(acc3, w7, src10);
935 acc3 = wrapper::vmla(acc3, w8, src11);
936}
937
938float vreduce(const float32x4_t &v)
939{
940 auto v0 = wrapper::vgethigh(v);
941 auto v1 = wrapper::vgetlow(v);
942 auto v_out = wrapper::vadd(v0, v1);
943
944 float a = wrapper::vgetlane(v_out, 0);
945 float b = wrapper::vgetlane(v_out, 1);
946 return a + b;
947}
948
949template <typename V>
950class convolver_9x9_nhwc
951{
952public:
953 static void convolve(const Window &window, unsigned int num_elems_read_per_iteration,
954 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
955 {
956 // Declare useful types
957 using vector_type = typename V::type;
958 using scalar_type = typename V::scalar_type;
959 using tag_type = typename V::tag_type;
960
961 // Scalar quantities
962 const int element_size = input->info()->element_size();
963 const int input_width = input->info()->dimension(0);
964 const int input_depth = input->info()->dimension(2);
965 const int input_stride_y = input->info()->strides_in_bytes().y() / element_size;
966 const int input_stride_z = input->info()->strides_in_bytes().z() / element_size;
967 const int input_stride_w = input->info()->strides_in_bytes()[3];
968 const int output_stride_x = output->info()->strides_in_bytes().x();
969 const int output_stride_y = output->info()->strides_in_bytes().y();
970 const int kernel_stride_y = weights->info()->strides_in_bytes().y() / element_size;
971 const int kernel_stride_z = weights->info()->strides_in_bytes().z() / element_size;
972 const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
973 const unsigned int conv_pad_top = conv_info.pad_top();
974 const unsigned int conv_pad_left = conv_info.pad_left();
975
976 // Setup input window for the input iterator
977 Window window_in = window;
978 window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
979 window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
980 window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
981
982 // Setup input window for the output iterator
983 Window window_out = window;
984 window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
985
986 // Setup input window for the weights iterator
987 Window window_k = calculate_max_window(*weights->info(), Steps());
988 window_k.set(Window::DimX, Window::Dimension(0, 1, 1));
989 window_k.set(Window::DimY, Window::Dimension(0, 1, 1));
990 window_k.set(Window::DimZ, Window::Dimension(0, 1, 1));
991 window_k.set(3, Window::Dimension(0, weights->info()->dimension(3), 1));
992
993 Iterator in(input, window_in);
994 Iterator out(output, window_out);
995 Iterator k(weights, window_k);
996
997 // Calculate the max_offset.
998 // max_offset is the offset for the last NOT valid value in the Z dimension (spatial dimension Y for NHWC)
999 // |******************|
1000 // | pad_top |
1001 // |******************|
1002 // | |
1003 // | plane0 |
1004 // | batch0 |
1005 // |__________________|
1006 // |******************| Batch 0
1007 // | pad_bottom |
1008 // | pad_top |
1009 // |******************|
1010 // | |
1011 // | plane1 |
1012 // | batch0 |
1013 // |__________________|-----> max_offset
1014 // |******************|
1015 // | pad_bottom |
1016 // | pad_top |
1017 // |******************|
1018 // | |
1019 // | plane0 |
1020 // | batch1 |
1021 // |__________________|
1022 // |******************| Batch 1
1023 // | pad_bottom |
1024 // | pad_top |
1025 // |******************|
1026 // | |
1027 // | plane1 |
1028 // | batch1 |
1029 // |__________________|
1030 // | pad_bottom |
1031 // |******************|
1032 const int max_offset = input_stride_z * input_depth - (input->info()->padding().bottom + input->info()->padding().top) * input_stride_y;
1033 execute_window_loop(window_k, [&](const Coordinates & id_k) // loop on the batch size
1034 {
1035
1036 execute_window_loop(window_out, [&](const Coordinates & id)
1037 {
1038 const auto y_offset = int(id.y() - conv_pad_left) * input_stride_y;
1039
1040 // Buffer pointers
1041 const scalar_type *in_ptr = reinterpret_cast<scalar_type *>(input->buffer() + input->info()->offset_first_element_in_bytes() + id[3] * input_stride_w);
1042 const scalar_type *weights_ptr = reinterpret_cast<scalar_type *>(k.ptr());
1043 uint8_t *out_ptr = out.ptr() + id_k[3] * output_stride_x;
1044
1045 // Output elements
1046 vector_type out0 = wrapper::vdup_n(scalar_type(0), tag_type());
1047 vector_type out1 = wrapper::vdup_n(scalar_type(0), tag_type());
1048 vector_type out2 = wrapper::vdup_n(scalar_type(0), tag_type());
1049 vector_type out3 = wrapper::vdup_n(scalar_type(0), tag_type());
1050
1051 // Reduce along the feature maps
1052 for(int x = 0; x < input_width; x += num_elems_read_per_iteration)
1053 {
1054 // z == 0
1055 auto in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top);
1056 in_z = std::min(static_cast<unsigned int>(in_z), static_cast<unsigned int>(input_depth));
1057 auto offset = y_offset + in_z * input_stride_z;
1058 offset = std::min(offset, max_offset);
1059 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 0 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1060
1061 // z == 1
1062 in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top + 1);
1063 in_z = std::min(static_cast<unsigned int>(in_z), static_cast<unsigned int>(input_depth));
1064 offset = y_offset + in_z * input_stride_z;
1065 offset = std::min(offset, max_offset);
1066 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 1 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1067
1068 // z == 2
1069 in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top + 2);
1070 in_z = std::min(static_cast<unsigned int>(in_z), static_cast<unsigned int>(input_depth));
1071 offset = y_offset + in_z * input_stride_z;
1072 offset = std::min(offset, max_offset);
1073 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 2 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1074
1075 // z == 3
1076 in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top + 3);
1077 offset = y_offset + in_z * input_stride_z;
1078 offset = std::min(offset, max_offset);
1079 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 3 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1080
1081 // z == 4
1082 in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top + 4);
1083 offset = y_offset + in_z * input_stride_z;
1084 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 4 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1085
1086 // z == 5
1087 offset += input_stride_z;
1088 offset = std::min(offset, max_offset);
1089 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 5 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1090
1091 // z == 6
1092 offset += input_stride_z;
1093 offset = std::min(offset, max_offset);
1094 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 6 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1095
1096 // z == 7
1097 offset += input_stride_z;
1098 offset = std::min(offset, max_offset);
1099 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 7 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1100
1101 // z == 8
1102 offset += input_stride_z;
1103 offset = std::min(offset, max_offset);
1104 convolve_row1x9_nhwc(in_ptr + offset + x, weights_ptr + 8 * kernel_stride_z + x, input_stride_y, kernel_stride_y, out0, out1, out2, out3);
1105 }
1106
1107 *(reinterpret_cast<scalar_type *>(out_ptr + 0 * output_stride_y)) = vreduce(out0);
1108 *(reinterpret_cast<scalar_type *>(out_ptr + 1 * output_stride_y)) = vreduce(out1);
1109 *(reinterpret_cast<scalar_type *>(out_ptr + 2 * output_stride_y)) = vreduce(out2);
1110 *(reinterpret_cast<scalar_type *>(out_ptr + 3 * output_stride_y)) = vreduce(out3);
1111 },
1112 in, out);
1113 },
1114 k);
1115 }
1116};
1117
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001118template <typename T1, typename T2>
1119inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
1120 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
1121{
1122 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
1123 switch(conv_stride_x)
1124 {
1125 case 1:
1126 convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1127 break;
1128 case 2:
1129 convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1130 break;
1131 case 3:
1132 convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1133 break;
1134 default:
1135 ARM_COMPUTE_ERROR("Not implemented");
1136 }
1137}
1138
Pablo Telloc09314a2017-09-21 13:59:14 +01001139template <>
1140inline void convolve_1x1<float, float>(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
1141 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
1142{
1143 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
1144 if(run_optim_small_tensor(input))
1145 {
1146 switch(conv_stride_x)
1147 {
1148 case 1:
1149 convolver_w1x1_i8x8_f32<1>::convolve(window, input, weights, output, conv_info);
1150 break;
1151 case 2:
1152 convolver_w1x1_i8x8_f32<2>::convolve(window, input, weights, output, conv_info);
1153 break;
1154 case 3:
1155 convolver_w1x1_i8x8_f32<3>::convolve(window, input, weights, output, conv_info);
1156 break;
1157 default:
1158 ARM_COMPUTE_ERROR("Not implemented");
1159 }
1160 }
1161 else
1162 {
1163 switch(conv_stride_x)
1164 {
1165 case 1:
1166 convolver_1x1<float, float, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1167 break;
1168 case 2:
1169 convolver_1x1<float, float, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1170 break;
1171 case 3:
1172 convolver_1x1<float, float, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1173 break;
1174 default:
1175 ARM_COMPUTE_ERROR("Not implemented");
1176 }
1177 }
1178}
1179
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001180template <typename T1, typename T2>
1181inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
1182 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
1183{
1184 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
1185 switch(conv_stride_x)
1186 {
1187 case 1:
1188 convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1189 break;
1190 case 2:
1191 convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1192 break;
1193 case 3:
1194 convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1195 break;
1196 default:
1197 ARM_COMPUTE_ERROR("Not implemented");
1198 }
1199}
Pablo Tello06da39d2017-08-10 15:10:40 +01001200
1201template <typename T1, typename T2>
1202inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
1203 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
1204{
1205 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
1206 switch(conv_stride_x)
1207 {
1208 case 1:
1209 convolver_5x5<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1210 break;
1211 case 2:
1212 convolver_5x5<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1213 break;
1214 case 3:
1215 convolver_5x5<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
1216 break;
1217 default:
1218 ARM_COMPUTE_ERROR("Not implemented");
1219 }
1220}
1221
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001222template <typename V>
1223inline void convolve_9x9_nhwc(const Window &window, unsigned int num_elems_read_per_iteration,
1224 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
1225{
1226 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
1227 switch(conv_stride_x)
1228 {
1229 case 1:
1230 convolver_9x9_nhwc<V>::convolve(window, num_elems_read_per_iteration, input, weights, output, conv_info);
1231 break;
1232 default:
1233 ARM_COMPUTE_ERROR("Not implemented");
1234 }
1235}
1236
/** Check that the tensor infos and convolution info describe a supported configuration.
 *
 * The checks run in the order written below; each *_RETURN_ERROR_* macro presumably returns
 * an error Status as soon as its condition is met (macro definitions are not visible here —
 * confirm against Error.h / Validate.h).
 *
 * @param[in] input     Source tensor info.
 * @param[in] weights   Weights tensor info.
 * @param[in] output    Destination tensor info. Shape and data type are only checked when
 *                      its total_size() is not zero.
 * @param[in] conv_info Padding and stride information.
 *
 * @return An empty Status when no check fails.
 */
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001237Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
1238 {
// Basic sanity: non-null infos, a known data layout, a supported data type, and
// weights that share the input's data type.
1239 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
Giorgio Arenac0f54432018-03-16 14:02:34 +00001240 ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
Anthony Barbiereaefd002018-07-20 17:49:35 +01001241 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +01001242 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001243 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001244
// Resolve which tensor dimension holds width / height / channels for the input's layout.
Giorgio Arenac0f54432018-03-16 14:02:34 +00001245 const DataLayout data_layout = input->data_layout();
1246 const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
1247 const int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
1248 const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
1249
// NOTE(review): only the X stride is bounded here, and only from above; the Y stride and a
// stride of 0 are not rejected by this function.
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001250 ARM_COMPUTE_RETURN_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
// Weights must have as many channels as the input, be square (width == height) and have at most 4 dimensions.
Giorgio Arenac0f54432018-03-16 14:02:34 +00001251 ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(channel_idx) != input->dimension(channel_idx));
1252 ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != weights->dimension(height_idx));
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001253 ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
// NHWC is accepted for F32 only; F16 is accepted only for kernels up to 3 wide.
Giorgio Arenac0f54432018-03-16 14:02:34 +00001254 ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && input->data_type() != DataType::F32);
Gian Marco Iodice41acb762018-08-23 10:25:06 +01001255 ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(width_idx) > 3) && (input->data_type() == DataType::F16));
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001256
1257 // Checks performed when output is configured
1258 if(output->total_size() != 0)
1259 {
// The output must have the shape computed from input, weights and conv_info ...
Giorgio Arenac0f54432018-03-16 14:02:34 +00001260 TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info);
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001261
// ... and the same data type as the input.
1262 DataType data_type = input->data_type();
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001263
1264 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
1265 ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != data_type);
1266 }
1267
// Every check passed.
1268 return Status{};
1269 }
1270
1271std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int &num_weight_elems_read_per_row,
Georgios Pinitas0223a782017-12-12 11:44:44 +00001272 unsigned int &num_elems_read_per_iteration, unsigned int &num_elems_written_per_iteration, BorderSize &border_size)
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001273{
Giorgio Arenac0f54432018-03-16 14:02:34 +00001274 ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
1275
1276 const DataLayout data_layout = input->data_layout();
1277 const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
1278
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001279 // Calculate right and bottom border
Giorgio Arenac0f54432018-03-16 14:02:34 +00001280 unsigned int kernel_size = weights->dimension(width_idx);
Georgios Pinitas1d6d2112018-02-05 17:40:12 +00001281 const int conv_stride_x = std::get<0>(conv_info.stride());
Georgios Pinitas1a03d762018-02-21 14:47:09 +00001282 const int conv_stride_y = std::get<1>(conv_info.stride());
Giorgio Arenac0f54432018-03-16 14:02:34 +00001283 const int input_width = input->dimension(width_idx);
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001284
Giorgio Arenac0f54432018-03-16 14:02:34 +00001285 Window win{};
1286 bool window_changed = false;
1287
1288 if(data_layout == DataLayout::NCHW)
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001289 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001290 switch(kernel_size)
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001291 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001292 case 1:
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001293 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001294 switch(input->data_type())
1295 {
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001296#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +00001297 case DataType::F16:
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001298 num_elems_written_per_iteration = 8;
Giorgio Arenac0f54432018-03-16 14:02:34 +00001299 break;
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +01001300#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +00001301 case DataType::F32:
1302 if(run_optim_small_tensor_info(input))
1303 {
1304 num_elems_written_per_iteration = 8;
1305 }
1306 else
1307 {
1308 num_elems_written_per_iteration = 4;
1309 }
1310 break;
1311 default:
1312 ARM_COMPUTE_ERROR("Data type not supported.");
1313 break;
1314 }
1315 num_weight_elems_read_per_row = kernel_size;
1316 num_elems_read_per_iteration = conv_stride_x * num_elems_written_per_iteration;
1317 break;
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001318 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001319 case 3:
Giorgio Arenac0f54432018-03-16 14:02:34 +00001320 switch(input->data_type())
1321 {
1322 case DataType::F32:
1323 num_weight_elems_read_per_row = 4 + kernel_size - 1;
1324 num_elems_read_per_iteration = 12;
1325 num_elems_written_per_iteration = 16 >> conv_stride_x;
1326 break;
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001327#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +00001328 case DataType::F16:
Giorgio Arenac0f54432018-03-16 14:02:34 +00001329 num_weight_elems_read_per_row = 8 + kernel_size - 1;
1330 num_elems_read_per_iteration = 24;
1331 num_elems_written_per_iteration = 32 >> conv_stride_x;
1332 break;
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +01001333#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +00001334 default:
1335 ARM_COMPUTE_ERROR("Data type not supported.");
1336 break;
1337 }
Gian Marco Iodice41acb762018-08-23 10:25:06 +01001338 break;
1339 case 5:
1340 {
1341 switch(input->data_type())
1342 {
1343 case DataType::F32:
1344 num_weight_elems_read_per_row = 4 + kernel_size - 1;
1345 num_elems_read_per_iteration = 12;
1346 num_elems_written_per_iteration = 16 >> conv_stride_x;
1347 break;
1348 default:
1349 ARM_COMPUTE_ERROR("Data type not supported.");
1350 break;
1351 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001352 }
1353 break;
1354 default:
1355 {
1356 ARM_COMPUTE_ERROR("Not implemented");
1357 break;
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001358 }
1359 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001360
1361 // Calculate right pad
1362 int start_x = kernel_size / 2 - static_cast<int>(conv_info.pad_left());
1363 int end_x = ceil_to_multiple(static_cast<int>(output->dimension(0)), num_elems_written_per_iteration) * conv_stride_x;
1364 int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width;
1365
1366 // Calculate border
1367 const unsigned int conv_pad_left = conv_info.pad_left();
1368 const unsigned int conv_pad_top = conv_info.pad_top();
1369 const unsigned int conv_pad_right = std::max(upper_bound_w, 0);
1370 const unsigned int conv_pad_bottom = conv_info.pad_bottom();
1371
1372 border_size.left = conv_pad_left;
1373 border_size.top = conv_pad_top;
1374 border_size.right = conv_pad_right;
1375 border_size.bottom = conv_pad_bottom;
1376
1377 // Configure window
1378 win = calculate_max_window(*output, Steps(num_elems_written_per_iteration));
1379
1380 AccessWindowRectangle input_access(input, -conv_pad_left, -conv_pad_top,
1381 num_elems_read_per_iteration, kernel_size,
1382 conv_stride_x, conv_stride_y);
1383 AccessWindowStatic weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size);
1384 AccessWindowHorizontal output_access(output, 0, num_elems_written_per_iteration);
1385 window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
1386 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001387 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001388 else
1389 {
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001390 if(kernel_size == 9)
1391 {
1392 border_size.left = 0;
1393 border_size.top = conv_info.pad_left();
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001394
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001395 const int num_elems_read_per_iteration_x = 4;
1396 const int num_elems_written_per_iteration_x = 1;
1397 const int num_elems_read_per_iteration_y = 12;
1398 const int num_elems_written_per_iteration_y = 4;
Georgios Pinitas1d6d2112018-02-05 17:40:12 +00001399
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001400 num_elems_read_per_iteration = num_elems_read_per_iteration_x;
1401 num_elems_written_per_iteration = num_elems_written_per_iteration_x;
Michalis Spyrou621965e2018-01-08 17:11:26 +00001402
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001403 border_size.right = num_elems_read_per_iteration_x;
1404 if((conv_info.pad_bottom() != 0) || (conv_info.pad_top() != 0))
1405 {
1406 // If bottom or top padding are set, we need to read num_elems_read_per_iteration_y rows to zero.
1407 // Since num_elems_read_per_iteration_y is always greater than conv_info.pad_right() we can set
1408 // the bottom padding to num_elems_read_per_iteration_y
1409 border_size.bottom = num_elems_read_per_iteration_y;
1410 }
1411 else if(conv_info.pad_right() != 0)
1412 {
1413 // Convetional border padding. Fill the bottom paddings so that we can read in batch of num_elems_read_per_iteration_y
1414 border_size.bottom = ceil_to_multiple(input->dimension(1) + conv_info.pad_right(), num_elems_read_per_iteration_y) - input->dimension(1);
1415 }
1416 else
1417 {
1418 // No padding
1419 border_size.bottom = 0;
1420 }
1421
1422 win = calculate_max_window(*output, Steps(num_elems_written_per_iteration_x, num_elems_written_per_iteration_y));
1423
1424 AccessWindowStatic input_access(input, 0, -border_size.top,
1425 ceil_to_multiple(input->dimension(0), num_elems_read_per_iteration_x),
1426 input->dimension(1) + border_size.bottom);
1427
1428 AccessWindowStatic weights_access(weights, 0, 0, ceil_to_multiple(weights->dimension(0), num_elems_read_per_iteration_x), weights->dimension(1));
1429 AccessWindowRectangle output_access(output, 0, 0, num_elems_written_per_iteration_x, num_elems_written_per_iteration_y);
1430 window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
1431 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
1432 }
1433 else
1434 {
1435 border_size.left = 0;
1436 border_size.top = conv_info.pad_left();
1437 border_size.right = 0;
1438 border_size.bottom = conv_info.pad_right();
1439 num_elems_read_per_iteration = 16 / element_size_from_data_type(input->data_type());
1440 win = calculate_max_window(*output, Steps());
1441
1442 AccessWindowRectangle input_access(input, 0, -border_size.top, num_elems_read_per_iteration, kernel_size, 1.f, conv_stride_x);
1443 AccessWindowRectangle weights_access(weights, 0, 0, num_elems_read_per_iteration, kernel_size);
1444 window_changed = update_window_and_padding(win, input_access, weights_access);
1445 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001446 }
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001447
1448 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
1449 return std::make_pair(err, win);
1450}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001451} // namespace
1452
1453NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
Georgios Pinitas898a8062017-09-12 19:19:12 +01001454 : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0),
1455 _num_elems_written_per_iteration(0)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001456{
1457}
1458
1459BorderSize NEDirectConvolutionLayerKernel::border_size() const
1460{
1461 return _border_size;
1462}
1463
1464void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
1465{
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001466 ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001467
1468 _input = input;
1469 _weights = weights;
1470 _output = output;
1471 _conv_info = conv_info;
Giorgio Arenac0f54432018-03-16 14:02:34 +00001472 _kernel_size = weights->info()->dimension(get_data_layout_dimension_index(weights->info()->data_layout(), DataLayoutDimension::WIDTH));
Michalis Spyrou621965e2018-01-08 17:11:26 +00001473
1474 const unsigned int conv_pad_left = conv_info.pad_left();
1475 const unsigned int conv_pad_top = conv_info.pad_top();
1476 const unsigned int conv_pad_right = conv_info.pad_right();
1477 const unsigned int conv_pad_bottom = conv_info.pad_bottom();
1478 _border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001479
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001480 // Get convolved dimensions
Giorgio Arenac0f54432018-03-16 14:02:34 +00001481 TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input->info(), *weights->info(), conv_info);
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001482
1483 DataType data_type = input->info()->data_type();
1484
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001485 // Output auto inizialitation if not yet initialized
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +01001486 auto_init_if_empty(*output->info(), output_shape, 1, data_type);
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001487
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001488 // Perform validation step
1489 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), output->info(), conv_info));
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001490
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001491 // Configure kernel window
1492 auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, _num_weight_elems_read_per_row,
Georgios Pinitas0223a782017-12-12 11:44:44 +00001493 _num_elems_read_per_iteration, _num_elems_written_per_iteration, _border_size);
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001494 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
1495 INEKernel::configure(win_config.second);
1496}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001497
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001498Status NEDirectConvolutionLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
1499{
1500 unsigned int num_weight_elems_read_per_row = 0;
1501 unsigned int num_elems_read_per_iteration = 0;
1502 unsigned int num_elems_written_per_iteration = 0;
Georgios Pinitas15997872018-02-19 13:58:22 +00001503 BorderSize border_size = {};
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001504 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, weights, output, conv_info));
Georgios Pinitas0223a782017-12-12 11:44:44 +00001505 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
1506 weights->clone().get(),
1507 output->clone().get(),
1508 conv_info,
1509 num_weight_elems_read_per_row,
1510 num_elems_read_per_iteration,
1511 num_elems_written_per_iteration,
1512 border_size)
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001513 .first);
Georgios Pinitas898a8062017-09-12 19:19:12 +01001514
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001515 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001516}
1517
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001518void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001519{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001520 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001521 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1522 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1523 ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
1524
Giorgio Arenac0f54432018-03-16 14:02:34 +00001525 const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_weights->info()->data_layout(), DataLayoutDimension::WIDTH));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001526
Giorgio Arenac0f54432018-03-16 14:02:34 +00001527 if(_input->info()->data_layout() == DataLayout::NCHW)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001528 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001529 switch(kernel_size)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001530 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001531 case 1:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001532 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001533 switch(_input->info()->data_type())
1534 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001535 case DataType::F32:
1536 convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1537 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001538#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +00001539 case DataType::F16:
1540 convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1541 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001542#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +00001543 default:
1544 ARM_COMPUTE_ERROR("Data type not supported");
1545 break;
1546 }
1547 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001548 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001549 case 3:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001550 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001551 switch(_input->info()->data_type())
1552 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001553 case DataType::F32:
1554 convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1555 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001556#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +00001557 case DataType::F16:
1558 convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1559 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001560#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +00001561 default:
1562 ARM_COMPUTE_ERROR("Data type not supported");
1563 break;
1564 }
1565 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001566 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001567 case 5:
Pablo Tello06da39d2017-08-10 15:10:40 +01001568 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001569 switch(_input->info()->data_type())
1570 {
1571 case DataType::F32:
1572 convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1573 break;
1574 default:
1575 ARM_COMPUTE_ERROR("Data type not supported");
1576 break;
1577 }
1578 break;
Pablo Tello06da39d2017-08-10 15:10:40 +01001579 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001580 default:
1581 {
1582 ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
1583 break;
1584 }
1585 }
1586 }
1587 else
1588 {
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001589 const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_weights->info()->data_layout(), DataLayoutDimension::WIDTH));
1590 const int stride_x = std::get<0>(_conv_info.stride());
1591 const int stride_y = std::get<1>(_conv_info.stride());
1592
Giorgio Arenac0f54432018-03-16 14:02:34 +00001593 switch(_input->info()->data_type())
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001594 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001595 case DataType::F32:
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001596 {
1597 if(kernel_size == 9 && stride_x == 1 && stride_y == 1)
1598 {
1599 using vtype = wrapper::traits::neon_vector<float, 4>;
1600 convolve_9x9_nhwc<vtype>(window, _num_elems_read_per_iteration, _input, _weights, _output, _conv_info);
1601 }
1602 else
1603 {
1604 convolver_nhwc<float>::convolve(window, kernel_size, _num_elems_read_per_iteration, _input, _weights, _output, _conv_info);
1605 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001606 break;
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001607 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001608 default:
1609 ARM_COMPUTE_ERROR("Data type not supported");
1610 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001611 }
1612 }
1613}