blob: db1b5f3c54cfdbb72e0e1f05a1b1eff4b8e41dd3 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Sheri Zhangac6499a2021-02-10 15:32:38 +00002 * Copyright (c) 2017-2021 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Georgios Pinitas7891a732021-08-20 21:39:25 +010024#include "src/cpu/kernels/CpuDirectConv2dKernel.h"
Georgios Pinitasddb93bb2020-10-02 16:38:59 +010025
26#include "src/core/NEON/kernels/detail/NEDirectConvolutionDetail.h"
27#include "src/core/NEON/wrapper/wrapper.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/Error.h"
30#include "arm_compute/core/Helpers.h"
31#include "arm_compute/core/IAccessWindow.h"
32#include "arm_compute/core/ITensor.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010033#include "arm_compute/core/Types.h"
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +010034#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010035#include "arm_compute/core/Validate.h"
Giorgio Arenac0f54432018-03-16 14:02:34 +000036#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010037#include "src/core/AccessWindowStatic.h"
38#include "src/core/CPP/Validate.h"
Georgios Pinitasddb93bb2020-10-02 16:38:59 +010039#include "src/core/NEON/NEFixedPoint.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010040#include "src/core/helpers/AutoConfiguration.h"
41#include "src/core/helpers/WindowHelpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010042
43#include <algorithm>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010044
Michalis Spyrou7362f0d2017-10-18 17:58:22 +010045using namespace arm_compute::detail;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010046
Manuel Bottini87350f42020-09-15 13:03:34 +010047namespace arm_compute
48{
Manuel Bottini327225d2021-04-13 13:09:30 +010049namespace cpu
50{
51namespace kernels
52{
Anthony Barbier6ff3b192017-09-04 18:44:23 +010053namespace
54{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Load 8 FP16 input values for a given convolution stride.
 *
 * The stride-2 and stride-3 specialisations use de-interleaving loads so that
 * only every 2nd / 3rd input element ends up in the returned vector.
 */
template <unsigned int stridex>
float16x8_t internal_vld1q(const float16_t *in);

template <>
float16x8_t internal_vld1q<1>(const float16_t *in)
{
    // Unit stride: plain contiguous load.
    return vld1q_f16(in);
}

template <>
float16x8_t internal_vld1q<2>(const float16_t *in)
{
    // De-interleave by pairs and keep the even-indexed elements.
    return vld2q_f16(in).val[0];
}

template <>
float16x8_t internal_vld1q<3>(const float16_t *in)
{
    // De-interleave by triplets and keep every third element.
    return vld3q_f16(in).val[0];
}

/** Broadcast a single FP16 scalar to all 8 lanes. */
inline float16x8_t internal_vdupq_n(float16_t v)
{
    return vdupq_n_f16(v);
}

/** Store 8 FP16 lanes to @p p. */
inline void internal_vst1q(float16_t *p, const float16x8_t &v)
{
    vst1q_f16(p, v);
}

/** Lane-wise product x * y. */
float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y)
{
    return vmulq_f16(x, y);
}

/** Lane-wise multiply-accumulate x + y * z, implemented as a separate multiply and add. */
inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z)
{
    const float16x8_t prod = vmulq_f16(y, z);
    return vaddq_f16(x, prod);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tello0d176142017-07-06 16:43:14 +010099
/** Load 4 F32 input values for a given convolution stride.
 *
 * The stride-2 and stride-3 specialisations use de-interleaving loads so that
 * only every 2nd / 3rd input element ends up in the returned vector.
 */
template <unsigned int stridex>
float32x4_t internal_vld1q(const float *in);

template <>
float32x4_t internal_vld1q<1>(const float *in)
{
    // Unit stride: plain contiguous load.
    return vld1q_f32(in);
}

template <>
float32x4_t internal_vld1q<2>(const float *in)
{
    // De-interleave by pairs and keep the even-indexed elements.
    return vld2q_f32(in).val[0];
}

template <>
float32x4_t internal_vld1q<3>(const float *in)
{
    // De-interleave by triplets and keep every third element.
    return vld3q_f32(in).val[0];
}

/** Broadcast a single F32 scalar to all 4 lanes. */
inline float32x4_t internal_vdupq_n(float v)
{
    return vdupq_n_f32(v);
}

/** Store 4 F32 lanes to @p p. */
inline void internal_vst1q(float *p, const float32x4_t &v)
{
    vst1q_f32(p, v);
}

/** Lane-wise product x * y. */
float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y)
{
    return vmulq_f32(x, y);
}

/** Lane-wise multiply-accumulate x + y * z using the native mla intrinsic. */
inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z)
{
    return vmlaq_f32(x, y, z);
}
142
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000143constexpr int small_tensor_size_optim = 8;
144inline bool run_optim_small_tensor_info(const ITensorInfo *t)
145{
146 return t->dimension(Window::DimX) <= small_tensor_size_optim && t->dimension(Window::DimY) <= small_tensor_size_optim;
147}
148
Pablo Telloc09314a2017-09-21 13:59:14 +0100149inline bool run_optim_small_tensor(const ITensor *t)
150{
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000151 return run_optim_small_tensor_info(t->info());
Pablo Telloc09314a2017-09-21 13:59:14 +0100152}
153
// Optimized convolver for 1x1 kernels used only where input width and height are both <= 8
// For big Z as in Input=7x7x832, this implementation is faster than the general code because it doesn't need to
// store intermediate results in memory. Temporary results are stored in SIMD registers directly and then written to the output buffer.
template <unsigned int stridex>
class convolver_w1x1_i8x8_f32
{
public:
    /** Run a 1x1 F32 direct convolution over the given window.
     *
     * All per-row results are accumulated in SIMD registers (accum0/accum1 below)
     * and written to @p dst only once per output plane, which is why the input
     * spatial size must not exceed small_tensor_size_optim in either dimension.
     *
     * @param window    Execution window; its Z range selects the output planes handled by this call.
     * @param src       Input tensor (width and height both <= small_tensor_size_optim, enforced below).
     * @param weights   Weights tensor; dimension 2 is iterated as the kernel depth, dimension 3 selects the output plane.
     * @param dst       Output tensor.
     * @param conv_info Stride and padding information of the convolution.
     */
    static void convolve(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_ERROR_ON(src->info()->dimension(Window::DimX) > small_tensor_size_optim);
        ARM_COMPUTE_ERROR_ON(src->info()->dimension(Window::DimY) > small_tensor_size_optim);

        // Byte strides used for manual pointer arithmetic on the raw buffers.
        const int input_stride_x  = src->info()->strides_in_bytes().x();
        const int input_stride_y  = src->info()->strides_in_bytes().y();
        const int input_stride_z  = src->info()->strides_in_bytes().z();
        const int output_stride_y = dst->info()->strides_in_bytes().y();
        const int output_stride_z = dst->info()->strides_in_bytes().z();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_h        = dst->info()->dimension(1);
        const int range_z         = window.z().end() - window.z().start();
        const int kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left = conv_info.pad_left();
        const unsigned int conv_pad_top  = conv_info.pad_top();

        // setup output window for the iterator:
        // X/Y collapsed to a single step (whole rows/planes handled manually below),
        // Z stepped by the full range so the lambda runs once per higher-dim slice.
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, dst->info()->dimension(Window::DimX), dst->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, dst->info()->dimension(Window::DimY), dst->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(dst, window_out);
        Iterator in(src, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            // Rewind to the (possibly padded) top-left origin of the receptive field.
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            // One accumulator pair per output row (up to 8 rows): accum0 holds
            // columns 0-3, accum1 columns 4-7. Declared once and re-zeroed per plane.
            std::array<float32x4_t, 8> accum0 = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            std::array<float32x4_t, 8> accum1 = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            for(int oz = 0; oz < range_z; ++oz)
            {
                // Reset the register accumulators for this output plane.
                accum0[0] = accum0[1] = accum0[2] = accum0[3] = accum0[4] = accum0[5] = accum0[6] = accum0[7] = vdupq_n_f32(0.f);
                accum1[0] = accum1[1] = accum1[2] = accum1[3] = accum1[4] = accum1[5] = accum1[6] = accum1[7] = vdupq_n_f32(0.f);
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Accumulate the contribution of every input plane p of this kernel.
                for(int p = 0; p < kernel_depth; ++p)
                {
                    // Scalar 1x1 weight for (input plane p, output plane id.z() + oz), broadcast to 4 lanes.
                    const auto k_val = reinterpret_cast<const float *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk0   = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const float *>(input_ptr + p * input_stride_z + offset_xy);
                        auto v_in0  = internal_vld1q<stridex>(in_val);
                        auto v_in1  = internal_vld1q<stridex>(in_val + 4);
                        accum0[oh] = vmlaq_f32(accum0[oh], vk0, v_in0);
                        accum1[oh] = vmlaq_f32(accum1[oh], vk0, v_in1);
                    }
                }
                // Single write-out of the fully accumulated plane.
                for(oh = 0; oh < output_h; ++oh)
                {
                    auto p_out = reinterpret_cast<float *>(p_out_base + oh * output_stride_y);
                    vst1q_f32(p_out, accum0[oh]);
                    vst1q_f32(p_out + 4, accum1[oh]);
                }
            }
        },
        in, out);
    }
};
238
template <typename T1, typename T2, unsigned int stridex>
class convolver_1x1
{
public:
    /** General 1x1 direct convolution (no size restriction on the input).
     *
     * Two-stage scheme: stage 1 initialises each output plane with the
     * contribution of input plane 0, stage 2 accumulates the remaining planes
     * in place (see the detailed algorithm description in convolver_3x3).
     *
     * @param window                          Execution window; its Z range selects the output planes handled by this call.
     * @param num_elems_read_per_iteration    Input elements consumed per inner-loop iteration.
     * @param num_elems_written_per_iteration Output elements produced per inner-loop iteration.
     * @param src                             Input tensor.
     * @param weights                         Weights tensor; dimension 2 is iterated as the kernel depth, dimension 3 selects the output plane.
     * @param dst                             Output tensor.
     * @param conv_info                       Stride and padding information of the convolution.
     */
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
    {
        // Byte strides used for manual pointer arithmetic on the raw buffers.
        const int input_stride_x  = src->info()->strides_in_bytes().x();
        const int input_stride_y  = src->info()->strides_in_bytes().y();
        const int input_stride_z  = src->info()->strides_in_bytes().z();
        const int output_stride_y = dst->info()->strides_in_bytes().y();
        const int output_stride_z = dst->info()->strides_in_bytes().z();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w        = dst->info()->dimension(0);
        const int output_h        = dst->info()->dimension(1);
        const int range_z         = window.z().end() - window.z().start();
        const int kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left = conv_info.pad_left();
        const unsigned int conv_pad_top  = conv_info.pad_top();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, dst->info()->dimension(Window::DimX), dst->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, dst->info()->dimension(Window::DimY), dst->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(dst, window_out);
        Iterator in(src, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            /*
                For a detailed explanation on how the algorithm works refer to template <> class convolver_3x3<1>
            */
            // Rewind to the (possibly padded) top-left origin of the receptive field.
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            for(int oz = 0; oz < range_z; ++oz)
            {
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Step 1: input plane 0 initialises the output plane (plain multiply, no accumulate).
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk    = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val)));
                        }
                    }
                }

                // Step 2: remaining input planes accumulate into the output in place
                // (read back the partial result with a stride-1 load, then mla).
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk    = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val)));
                        }
                    }
                }
            }
        },
        in, out);
    }
};
330
/** Compute one horizontal strip of a 5x5 F32 convolution.
 *
 * @p in_0..in_4 point into five input rows, @p m0..m4 to the five rows of
 * kernel coefficients (5 consecutive floats each). Returns two float32x4_t of
 * accumulated results; specialised below for stride 1, 2 and 3.
 */
template <unsigned int stridex>
float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                           const float *m0, const float *m1, const float *m2, const float *m3, const float *m4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100334
/** Broadcast the first three coefficients of a 5-element kernel row,
 * one 4-lane vector per coefficient. */
inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
{
    float32x4x3_t broadcast;
    broadcast.val[0] = vld1q_dup_f32(m0);
    broadcast.val[1] = vld1q_dup_f32(m1);
    broadcast.val[2] = vld1q_dup_f32(m2);
    return broadcast;
}
347
/** Broadcast the last two coefficients of a 5-element kernel row,
 * one 4-lane vector per coefficient. */
inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4)
{
    float32x4x2_t broadcast;
    broadcast.val[0] = vld1q_dup_f32(m3);
    broadcast.val[1] = vld1q_dup_f32(m4);
    return broadcast;
}
359
/** Load 12 consecutive floats from @p in as three 4-lane vectors. */
inline float32x4x3_t load_input(const float *const in)
{
    float32x4x3_t rows;
    rows.val[0] = vld1q_f32(in);
    rows.val[1] = vld1q_f32(in + 4);
    rows.val[2] = vld1q_f32(in + 8);
    return rows;
}
372
/** Stride-1 5x5 strip: produces 8 adjacent outputs as two float32x4_t. */
template <>
inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
    // Load 12 consecutive floats from each input row: the 5-wide horizontal
    // window over 8 outputs needs elements [0, 11].
    const float32x4x3_t vin0 = load_input(in_0);
    const float32x4x3_t vin1 = load_input(in_1);
    const float32x4x3_t vin2 = load_input(in_2);
    const float32x4x3_t vin3 = load_input(in_3);
    const float32x4x3_t vin4 = load_input(in_4);
    // Broadcast the 5 coefficients of each kernel row (3 "hi" + 2 "lo" vectors per row).
    const float32x4x3_t m00 = load_matrix_hi(m0, 1 + m0, 2 + m0);
    const float32x4x2_t m01 = load_matrix_lo(3 + m0, 4 + m0);
    const float32x4x3_t m10 = load_matrix_hi(m1, 1 + m1, 2 + m1);
    const float32x4x2_t m11 = load_matrix_lo(3 + m1, 4 + m1);
    const float32x4x3_t m20 = load_matrix_hi(m2, 1 + m2, 2 + m2);
    const float32x4x2_t m21 = load_matrix_lo(3 + m2, 4 + m2);
    const float32x4x3_t m30 = load_matrix_hi(m3, 1 + m3, 2 + m3);
    const float32x4x2_t m31 = load_matrix_lo(3 + m3, 4 + m3);
    const float32x4x3_t m40 = load_matrix_hi(m4, 1 + m4, 2 + m4);
    const float32x4x2_t m41 = load_matrix_lo(3 + m4, 4 + m4);

    // Initialise both output halves with the row-0 / tap-0 product.
    float32x4x2_t out =
    {
        {
            vmulq_f32(vin0.val[0], m00.val[0]),
            vmulq_f32(vin0.val[1], m00.val[0])
        }
    };

    // Low half (outputs 0-3): vextq_f32 shifts the input window right by one
    // element per tap; rows 0..4 accumulate 5 taps each.
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 3), m11.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], m11.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]);

    // High half (outputs 4-7): same pattern shifted 4 elements to the right
    // (windows taken from val[1]/val[2]).
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]);

    return out;
}
461
/** Stride-2 5x5 strip: computes the 8 stride-1 results, then compacts every
 * second one (results 0, 2, 4, 6) into the lanes of out.val[0]. */
template <>
inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);
    // Lane 0 already holds result 0; gather results 2, 4 and 6 into lanes 1-3.
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}
472
/** Stride-3 5x5 strip: computes the 8 stride-1 results, then moves result 3
 * into lane 1 of out.val[0] (lane 0 already holds result 0). The remaining
 * lanes are left as-is — presumably the caller stores fewer elements per
 * iteration for stride 3; confirm against the kernel's window configuration. */
template <>
inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}
481
template <typename T1, typename T2, unsigned int stridex>
class convolver_3x3
{
public:
    /** 3x3 direct convolution over the given window.
     *
     * @param window                          Execution window; its Z range selects the output planes handled by this call.
     * @param num_elems_read_per_iteration    Unused here (input advance is derived from the write count and stride).
     * @param num_elems_written_per_iteration Output elements produced per inner-loop iteration.
     * @param src                             Input tensor.
     * @param weights                         Weights tensor (3x3 kernels); dimension 2 is iterated as the kernel depth, dimension 3 selects the output plane.
     * @param dst                             Output tensor.
     * @param conv_info                       Stride and padding information of the convolution.
     */
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        // Byte strides used for manual pointer arithmetic on the raw buffers.
        const int input_stride_x  = src->info()->strides_in_bytes().x();
        const int input_stride_y  = src->info()->strides_in_bytes().y();
        const int input_stride_z  = src->info()->strides_in_bytes().z();
        const int output_stride_y = dst->info()->strides_in_bytes().y();
        const int output_stride_z = dst->info()->strides_in_bytes().z();
        const int kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w        = dst->info()->dimension(0);
        const int output_h        = dst->info()->dimension(1);
        const int num_planes_z    = window.z().end() - window.z().start();
        // Input elements to advance per inner iteration, as a function of the write count and stride.
        const int delta_input     = get_input_num_elems_processed(num_elems_written_per_iteration, stridex);
        const int kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left = conv_info.pad_left();
        const unsigned int conv_pad_top  = conv_info.pad_top();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, dst->info()->dimension(Window::DimX), dst->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, dst->info()->dimension(Window::DimY), dst->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(dst, window_out);
        Iterator in(src, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            // Rewind to the (possibly padded) top-left origin of the receptive field.
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            /*
                    Each thread executing this kernel computes one or more output's volume planes.

                    Let's say the 3rd dimension of the output volume is 32, the first thread will compute the output for Z = [0,7], the second thread will compute the output for Z = [8,15],
                    the third thread [16,23] and the fourth thread [24,31].

                    The algorithm outer loop iterates over Z, P, Y, X where P is the depth/3rd dimension of each kernel. This order is not arbitrary, the main benefit of this
                    is that we setup the neon registers containing the kernel's values only once and then compute each XY using the preloaded registers as opposed to doing this for every XY value.

                    The algorithm does not require allocating any additional memory and computes the results directly in-place in two stages:
                        1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
                        2) Convolve the remaining planes and accumulate the results in the output's plane which has been initialized in step 1.
            */
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset    = id.z() + oz;
                uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1: input plane 0 initialises the output plane (accumulate = false).
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto vk_r0    = load_matrix_row(ptr_k_r0);
                    const auto vk_r1    = load_matrix_row(ptr_k_r1);
                    const auto vk_r2    = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        // Three consecutive input rows feeding one output row of the 3x3 kernel.
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            convolve_3x3<false>(in_top, in_mid, in_low, p_out, vk_r0, vk_r1, vk_r2, stridex);
                        }
                    }
                }
                // Step 2: remaining input planes accumulate into the output in place (accumulate = true).
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w;
                    const uint8_t *input_base = input_ptr + p * input_stride_z;
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(ptr_k_base);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y * 2);
                    const auto vk_r0    = load_matrix_row(ptr_k_r0);
                    const auto vk_r1    = load_matrix_row(ptr_k_r1);
                    const auto vk_r2    = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_base + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_base + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_base + (ih + 2) * input_stride_y);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            convolve_3x3<true>(in_top, in_mid, in_low, p_out, vk_r0, vk_r1, vk_r2, stridex);
                        }
                    }
                }
            }
        },
        in, out);
    }
};
602
/** NCHW direct convolution with a 5x5 kernel, specialized on the X stride.
 *
 * @tparam T1      Element type of the input and weights tensors.
 * @tparam T2      Element type written to the output tensor.
 * @tparam stridex Convolution stride along X (1, 2 or 3).
 */
template <typename T1, typename T2, unsigned int stridex>
class convolver_5x5
{
public:
    /** Run the convolution over the given execution window.
     *
     * @param[in]  window                          Execution window (Z dimension spans the output planes handled here).
     * @param[in]  num_elems_read_per_iteration    Input elements consumed per inner-loop iteration (unused here).
     * @param[in]  num_elems_written_per_iteration Output elements produced per inner-loop iteration.
     * @param[in]  src                             Input tensor.
     * @param[in]  weights                         Weights tensor (4D: W, H, IFM, OFM).
     * @param[out] dst                             Output tensor.
     * @param[in]  conv_info                       Padding and stride information.
     */
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        // Byte strides for input, output and weights
        const int          input_stride_x  = src->info()->strides_in_bytes().x();
        const int          input_stride_y  = src->info()->strides_in_bytes().y();
        const int          input_stride_z  = src->info()->strides_in_bytes().z();
        const int          output_stride_y = dst->info()->strides_in_bytes().y();
        const int          output_stride_z = dst->info()->strides_in_bytes().z();
        const int          kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int          kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int          output_w        = dst->info()->dimension(0);
        const int          output_h        = dst->info()->dimension(1);
        const int          num_planes_z    = window.z().end() - window.z().start();
        // Input pointer advance per inner-loop iteration (depends on the X stride)
        const int          delta_input     = get_input_num_elems_processed(num_elems_written_per_iteration, stridex);
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, dst->info()->dimension(Window::DimX), dst->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, dst->info()->dimension(Window::DimY), dst->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(dst, window_out);
        Iterator in(src, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            // Rewind the input pointer so that (0,0) maps to the padded top-left corner
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            // One iteration per output plane; zoffset selects the weight set along the 4th
            // weights dimension for this plane.
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset    = id.z() + oz;
                uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1: convolve input plane 0 and store (initialize) the output
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        // Five consecutive input rows feed one output row of the 5x5 kernel
                        auto in_0  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2: convolve the remaining input planes and accumulate into the output
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);

                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};
713
Gian Marco Iodice95f93612019-06-13 15:58:32 +0100714float vreduce(const float32x4_t &v)
715{
716 auto v0 = wrapper::vgethigh(v);
717 auto v1 = wrapper::vgetlow(v);
718 auto v_out = wrapper::vadd(v0, v1);
719
720 float a = wrapper::vgetlane(v_out, 0);
721 float b = wrapper::vgetlane(v_out, 1);
722 return a + b;
723}
724
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100725template <typename T1, typename T2>
726inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
Manuel Bottini327225d2021-04-13 13:09:30 +0100727 const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100728{
729 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
730 switch(conv_stride_x)
731 {
732 case 1:
Manuel Bottini327225d2021-04-13 13:09:30 +0100733 convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100734 break;
735 case 2:
Manuel Bottini327225d2021-04-13 13:09:30 +0100736 convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100737 break;
738 case 3:
Manuel Bottini327225d2021-04-13 13:09:30 +0100739 convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100740 break;
741 default:
742 ARM_COMPUTE_ERROR("Not implemented");
743 }
744}
745
Pablo Telloc09314a2017-09-21 13:59:14 +0100746template <>
747inline void convolve_1x1<float, float>(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
Manuel Bottini327225d2021-04-13 13:09:30 +0100748 const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
Pablo Telloc09314a2017-09-21 13:59:14 +0100749{
750 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
Manuel Bottini327225d2021-04-13 13:09:30 +0100751 if(run_optim_small_tensor(src))
Pablo Telloc09314a2017-09-21 13:59:14 +0100752 {
753 switch(conv_stride_x)
754 {
755 case 1:
Manuel Bottini327225d2021-04-13 13:09:30 +0100756 convolver_w1x1_i8x8_f32<1>::convolve(window, src, weights, dst, conv_info);
Pablo Telloc09314a2017-09-21 13:59:14 +0100757 break;
758 case 2:
Manuel Bottini327225d2021-04-13 13:09:30 +0100759 convolver_w1x1_i8x8_f32<2>::convolve(window, src, weights, dst, conv_info);
Pablo Telloc09314a2017-09-21 13:59:14 +0100760 break;
761 case 3:
Manuel Bottini327225d2021-04-13 13:09:30 +0100762 convolver_w1x1_i8x8_f32<3>::convolve(window, src, weights, dst, conv_info);
Pablo Telloc09314a2017-09-21 13:59:14 +0100763 break;
764 default:
765 ARM_COMPUTE_ERROR("Not implemented");
766 }
767 }
768 else
769 {
770 switch(conv_stride_x)
771 {
772 case 1:
Manuel Bottini327225d2021-04-13 13:09:30 +0100773 convolver_1x1<float, float, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Pablo Telloc09314a2017-09-21 13:59:14 +0100774 break;
775 case 2:
Manuel Bottini327225d2021-04-13 13:09:30 +0100776 convolver_1x1<float, float, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Pablo Telloc09314a2017-09-21 13:59:14 +0100777 break;
778 case 3:
Manuel Bottini327225d2021-04-13 13:09:30 +0100779 convolver_1x1<float, float, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Pablo Telloc09314a2017-09-21 13:59:14 +0100780 break;
781 default:
782 ARM_COMPUTE_ERROR("Not implemented");
783 }
784 }
785}
786
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100787template <typename T1, typename T2>
788inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
Manuel Bottini327225d2021-04-13 13:09:30 +0100789 const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100790{
791 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
792 switch(conv_stride_x)
793 {
794 case 1:
Manuel Bottini327225d2021-04-13 13:09:30 +0100795 convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100796 break;
797 case 2:
Manuel Bottini327225d2021-04-13 13:09:30 +0100798 convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100799 break;
800 case 3:
Manuel Bottini327225d2021-04-13 13:09:30 +0100801 convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100802 break;
803 default:
804 ARM_COMPUTE_ERROR("Not implemented");
805 }
806}
Pablo Tello06da39d2017-08-10 15:10:40 +0100807
808template <typename T1, typename T2>
809inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
Manuel Bottini327225d2021-04-13 13:09:30 +0100810 const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
Pablo Tello06da39d2017-08-10 15:10:40 +0100811{
812 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
813 switch(conv_stride_x)
814 {
815 case 1:
Manuel Bottini327225d2021-04-13 13:09:30 +0100816 convolver_5x5<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Pablo Tello06da39d2017-08-10 15:10:40 +0100817 break;
818 case 2:
Manuel Bottini327225d2021-04-13 13:09:30 +0100819 convolver_5x5<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Pablo Tello06da39d2017-08-10 15:10:40 +0100820 break;
821 case 3:
Manuel Bottini327225d2021-04-13 13:09:30 +0100822 convolver_5x5<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Pablo Tello06da39d2017-08-10 15:10:40 +0100823 break;
824 default:
825 ARM_COMPUTE_ERROR("Not implemented");
826 }
827}
828
/** Validate the src/weights/dst/conv_info combination for the direct convolution kernel.
 *
 * Returns an error Status at the first violated constraint, an empty Status otherwise.
 * Check order matters for error reporting, so it must not be rearranged.
 */
Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
    ARM_COMPUTE_RETURN_ERROR_ON(src->data_layout() == DataLayout::UNKNOWN);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights);

    // Resolve W/H/C dimension indices for the tensor's data layout (NCHW vs NHWC)
    const DataLayout data_layout = src->data_layout();
    const int        width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
    const int        height_idx  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
    const int        channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
    // Weights must match the input channel count and be square (1x1, 3x3, 5x5, ...)
    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(channel_idx) != src->dimension(channel_idx));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != weights->dimension(height_idx));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
    // NHWC is only implemented for F32; F16 is limited to kernel widths <= 3
    ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && src->data_type() != DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(width_idx) > 3) && (src->data_type() == DataType::F16));

    // Checks performed when output is configured
    if(dst->total_size() != 0)
    {
        TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *weights, conv_info);

        DataType data_type = src->data_type();

        // The configured output must match the inferred convolution shape and the input data type
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), output_shape);
        ARM_COMPUTE_RETURN_ERROR_ON(dst->data_type() != data_type);
    }

    return Status{};
}
862
/** Compute the execution window and border/padding requirements for the kernel.
 *
 * For NCHW, the per-iteration read/write counts depend on kernel size and data type;
 * the output parameters (num_weight_elems_read_per_row, num_elems_read_per_iteration,
 * num_elems_written_per_iteration, border_size) are filled in accordingly and tensor
 * paddings are updated. For NHWC, a padding-free window over dst is returned.
 *
 * @return A pair of the validation Status and the configured execution window.
 */
std::pair<Status, Window> validate_and_configure_window(ITensorInfo *src, ITensorInfo *weights, ITensorInfo *dst, const PadStrideInfo &conv_info, unsigned int &num_weight_elems_read_per_row,
                                                        unsigned int &num_elems_read_per_iteration, unsigned int &num_elems_written_per_iteration, BorderSize &border_size)
{
    ARM_COMPUTE_ERROR_ON(src->data_layout() == DataLayout::UNKNOWN);

    const DataLayout data_layout = src->data_layout();
    const int        width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);

    // Calculate right and bottom border
    unsigned int kernel_size   = weights->dimension(width_idx);
    const int    conv_stride_x = std::get<0>(conv_info.stride());
    const int    conv_stride_y = std::get<1>(conv_info.stride());
    const int    input_width   = src->dimension(width_idx);

    Window win{};
    bool   window_changed = false;

    if(data_layout == DataLayout::NCHW)
    {
        // Select the vectorization factors for each supported kernel size / data type
        switch(kernel_size)
        {
            case 1:
            {
                switch(src->data_type())
                {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    case DataType::F16:
                        num_elems_written_per_iteration = 8;
                        break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                    case DataType::F32:
                        // The small-tensor fast path processes 8 elements per iteration
                        if(run_optim_small_tensor_info(src))
                        {
                            num_elems_written_per_iteration = 8;
                        }
                        else
                        {
                            num_elems_written_per_iteration = 4;
                        }
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported.");
                        break;
                }
                num_weight_elems_read_per_row = kernel_size;
                num_elems_read_per_iteration  = conv_stride_x * num_elems_written_per_iteration;
                break;
            }
            case 3:
                switch(src->data_type())
                {
                    case DataType::F32:
                        num_weight_elems_read_per_row   = 4 + kernel_size - 1;
                        num_elems_read_per_iteration    = 12;
                        num_elems_written_per_iteration = 16 >> conv_stride_x;
                        break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    case DataType::F16:
                        num_weight_elems_read_per_row   = 8 + kernel_size - 1;
                        num_elems_read_per_iteration    = 24;
                        num_elems_written_per_iteration = 32 >> conv_stride_x;
                        break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported.");
                        break;
                }
                break;
            case 5:
            {
                switch(src->data_type())
                {
                    case DataType::F32:
                        num_weight_elems_read_per_row   = 4 + kernel_size - 1;
                        num_elems_read_per_iteration    = 12;
                        num_elems_written_per_iteration = 16 >> conv_stride_x;
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported.");
                        break;
                }
            }
            break;
            default:
            {
                ARM_COMPUTE_ERROR("Not implemented");
                break;
            }
        }

        // Calculate right pad: how far the vectorized reads overrun the input width
        int start_x       = kernel_size / 2 - static_cast<int>(conv_info.pad_left());
        int end_x         = ceil_to_multiple(static_cast<int>(dst->dimension(0)), num_elems_written_per_iteration) * conv_stride_x;
        int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width;

        // Calculate border
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();
        const unsigned int conv_pad_right  = std::max(upper_bound_w, 0);
        const unsigned int conv_pad_bottom = conv_info.pad_bottom();

        border_size.left   = conv_pad_left;
        border_size.top    = conv_pad_top;
        border_size.right  = conv_pad_right;
        border_size.bottom = conv_pad_bottom;

        // Configure window
        win = calculate_max_window(*dst, Steps(num_elems_written_per_iteration));

        AccessWindowRectangle input_access(src, -conv_pad_left, -conv_pad_top,
                                           num_elems_read_per_iteration, kernel_size,
                                           conv_stride_x, conv_stride_y);
        AccessWindowStatic     weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size);
        AccessWindowHorizontal output_access(dst, 0, num_elems_written_per_iteration);
        window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
        output_access.set_valid_region(win, ValidRegion(Coordinates(), dst->tensor_shape()));
    }
    else
    {
        // Configure window NHWC without any padding
        win = calculate_max_window(*dst, Steps());
    }

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
Manuel Bottini87350f42020-09-15 13:03:34 +0100989
Michalis Spyroub55f8e82021-07-22 11:23:11 +0100990bool have_zero_x_internal_padding(ITensorInfo *src, const ITensorInfo *weights)
Manuel Bottini87350f42020-09-15 13:03:34 +0100991{
Manuel Bottini327225d2021-04-13 13:09:30 +0100992 return (src->padding().left == 0 && weights->padding().left == 0 && src->padding().right == 0 && weights->padding().right == 0);
Manuel Bottini87350f42020-09-15 13:03:34 +0100993}
994
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100995} // namespace
996
/** Optimized NHWC direct convolution.
 *
 * Precondition: src and weights have no internal padding in the channel (X)
 * dimension, so each WC plane can be traversed as one contiguous element run
 * (see have_zero_x_internal_padding()).
 *
 * @tparam T Element type (vectorized with 128-bit NEON registers).
 *
 * @param[in]  window  Execution window over the output (one point per NHW position).
 * @param[in]  src     Input tensor.
 * @param[in]  weights Weights tensor; its 4th dimension maps to output channels.
 * @param[out] dst     Output tensor.
 */
template <typename T>
void CpuDirectConv2dKernel::convolve_nhwc_optimized(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst)
{
    // This function assumes that input and weights have no padding in the channel dimension

    // Declare useful types
    using vtype       = wrapper::traits::neon_bitvector<T, wrapper::traits::BitWidth::W128>;
    using vector_type = typename vtype::type;
    using tag_type    = typename vtype::tag_type;

    // Scalar quantities (strides expressed in elements, not bytes)
    const int element_size   = src->info()->element_size();
    const int input_stride_w = src->info()->strides_in_bytes().y() / element_size;
    const int input_stride_h = src->info()->strides_in_bytes().z() / element_size;
    const int input_stride_n = src->info()->strides_in_bytes()[3] / element_size;
    const int input_dim_w    = src->info()->dimension(1);
    const int input_dim_h    = src->info()->dimension(2);

    const int output_stride_c = dst->info()->strides_in_bytes().x();

    const unsigned int kernel_stride_w = weights->info()->strides_in_bytes().y() / element_size;
    const unsigned int kernel_stride_h = weights->info()->strides_in_bytes().z() / element_size;
    const int          kernel_dim_w    = weights->info()->dimension(1);
    const int          kernel_dim_h    = weights->info()->dimension(2);

    const int conv_pad_top  = _conv_info.pad_top();
    const int conv_pad_left = _conv_info.pad_left();
    const int conv_stride_w = std::get<0>(_conv_info.stride());
    const int conv_stride_h = std::get<1>(_conv_info.stride());

    // Setup input window for the output iterator
    Window window_out = window;
    window_out.set(Window::DimX, Window::Dimension(0, 1, 1));

    // Setup input window for the weights iterator
    Window window_w = calculate_max_window(*weights->info(), Steps());
    window_w.set(Window::DimX, Window::Dimension(0, 1, 1));
    window_w.set(Window::DimY, Window::Dimension(0, 1, 1));
    window_w.set(Window::DimZ, Window::Dimension(0, 1, 1));

    Iterator out(dst, window_out);
    Iterator wei(weights, window_w);

    constexpr int num_elems_read_per_iteration = 16 / sizeof(T);
    /*
     * This implementation parallelizes the full WC plane of input and weights by
     * treating them as series of elements. So for example, for a 3x3 weights and
     * floating point vector operations of 4 elements per time, the first 3
     * channel elements of the first row would be taken and additionally the first
     * element of the second row. The 9 elements in each single WC weight plane
     * would require 2 4-element vector operations and a last single element operation.
     *
     * This works since when we create the input vector to multiply with the weights,
     * the exact required elements are loaded in the same order. Therefore the
     * multiplication works on the correct input/weight elements.
     */
    execute_window_loop(window_out, [&](const Coordinates & id)
    {
        /*
         * In here we create theoretical indexes which then we validate for both
         * inputs and weights.
         * As a reminder, this loop takes each output point in NHW; C is treated
         * in the weights loop.
         */
        // We are computing the theoretical input starting points
        const int in_w_start_t = static_cast<int>(id.y()) * conv_stride_w - conv_pad_left;
        const int in_h_start_t = static_cast<int>(id.z()) * conv_stride_h - conv_pad_top;
        const int in_w_end_t   = in_w_start_t + kernel_dim_w;
        const int in_h_end_t   = in_h_start_t + kernel_dim_h;

        // We are computing the valid initial and ending input points by checking the borders
        const int in_w_start = std::max(in_w_start_t, 0);
        const int in_h_start = std::max(in_h_start_t, 0);
        const int in_w_end   = std::min(in_w_end_t, input_dim_w);
        const int in_h_end   = std::min(in_h_end_t, input_dim_h);

        // We use the input points to select the valid weight points to use
        const int index_wc_start = (in_w_start - in_w_start_t) * kernel_stride_w;
        const int index_h_start  = in_h_start - in_h_start_t;
        const int index_wc_end   = (kernel_dim_w - (in_w_end_t - in_w_end)) * kernel_stride_w;
        const int index_h_end    = kernel_dim_h - (in_h_end_t - in_h_end);

        execute_window_loop(window_w, [&](const Coordinates & id_w)
        {
            /*
             * This is the loop in the weights, and it goes along N (the batches)
             * As a reminder, the batches of the weights are translated into the
             * channels of the output
             */
            const T *in_ptr_row = reinterpret_cast<const T *>(src->buffer() + src->info()->offset_first_element_in_bytes())
                                  + id[3] * input_stride_n + in_w_start * input_stride_w + in_h_start * input_stride_h;
            const T *weights_ptr_row = reinterpret_cast<const T *>(wei.ptr()) + index_h_start * kernel_stride_h;
            uint8_t *out_ptr         = out.ptr() + id_w[3] * output_stride_c;

            T out_temp = static_cast<T>(0);
            // Walk valid kernel rows; both pointers advance one full row per iteration
            for(int index_h = index_h_start; index_h < index_h_end; ++index_h, in_ptr_row += input_stride_h, weights_ptr_row += kernel_stride_h)
            {
                const T *in_ptr_mover = in_ptr_row;
                int index_wc = index_wc_start;
                vector_type out_temp_vec = wrapper::vdup_n(static_cast<T>(0), tag_type());
                // Vectorized multiply-accumulate over the contiguous WC run
                for(; index_wc <= index_wc_end - num_elems_read_per_iteration; index_wc += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration)
                {
                    const auto src_vec = wrapper::vloadq(in_ptr_mover);
                    const auto w_vec   = wrapper::vloadq(weights_ptr_row + index_wc);
                    out_temp_vec       = wrapper::vmla(out_temp_vec, w_vec, src_vec);
                }
                out_temp += vreduce(out_temp_vec);
                // Scalar tail for the leftover elements
                for(; index_wc < index_wc_end; ++index_wc, ++in_ptr_mover)
                {
                    const auto src_val = *(in_ptr_mover);
                    const auto w_val   = *(weights_ptr_row + index_wc);
                    out_temp += src_val * w_val;
                }
            }
            *(reinterpret_cast<T *>(out_ptr)) = out_temp;
        },
        wei);
    },
    out);
}
1117
template <typename T>
void CpuDirectConv2dKernel::convolve_nhwc(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst)
{
    // General NHWC direct convolution: handles borders and arbitrary x-padding by
    // clamping the input region per output point and walking src/weights with
    // explicit element strides. One output value (one channel of one output
    // pixel) is produced per (output-window, weights-window) iteration pair.
    //
    // @param window  Output-space execution window (DimX collapsed below).
    // @param src     Input tensor, NHWC: dim0 = C, dim1 = W, dim2 = H, dim3 = N.
    // @param weights Weights tensor, NHWC: dim3 iterates output channels.
    // @param dst     Output tensor; written as scalar T per (pixel, channel).

    // Declare useful types
    using vtype       = wrapper::traits::neon_bitvector<T, wrapper::traits::BitWidth::W128>;
    using vector_type = typename vtype::type;
    using tag_type    = typename vtype::tag_type;

    // Scalar quantities
    // Strides are converted from bytes to elements so pointer arithmetic below
    // can be done on typed T* pointers.
    const int element_size   = src->info()->element_size();
    const int input_stride_w = src->info()->strides_in_bytes().y() / element_size;
    const int input_stride_h = src->info()->strides_in_bytes().z() / element_size;
    const int input_stride_n = src->info()->strides_in_bytes()[3] / element_size;
    const int input_dim_w    = src->info()->dimension(1);
    const int input_dim_h    = src->info()->dimension(2);

    // Output channel stride kept in bytes: out.ptr() is uint8_t*.
    const int output_stride_c = dst->info()->strides_in_bytes().x();

    const unsigned int kernel_stride_w = weights->info()->strides_in_bytes().y() / element_size;
    const unsigned int kernel_stride_h = weights->info()->strides_in_bytes().z() / element_size;
    const int          kernel_dim_w    = weights->info()->dimension(1);
    const int          kernel_dim_h    = weights->info()->dimension(2);

    const int conv_pad_top  = _conv_info.pad_top();
    const int conv_pad_left = _conv_info.pad_left();
    const int conv_stride_w = std::get<0>(_conv_info.stride());
    const int conv_stride_h = std::get<1>(_conv_info.stride());

    // Setup input window for the output iterator
    // DimX (channels) is collapsed: all output channels for a pixel are produced
    // by the inner weights-window loop instead.
    Window window_out = window;
    window_out.set(Window::DimX, Window::Dimension(0, 1, 1));

    // Setup input window for the weights iterator
    // Only dim3 (the weights batch == output channel) is iterated; the kernel's
    // own W/H/C are traversed manually inside the loops.
    Window window_w = calculate_max_window(*weights->info(), Steps());
    window_w.set(Window::DimX, Window::Dimension(0, 1, 1));
    window_w.set(Window::DimY, Window::Dimension(0, 1, 1));
    window_w.set(Window::DimZ, Window::Dimension(0, 1, 1));

    Iterator out(dst, window_out);
    Iterator wei(weights, window_w);

    // Channels processed per vector step: a full 128-bit NEON register of T.
    constexpr int num_elems_read_per_iteration = 16 / sizeof(T);

    execute_window_loop(window_out, [&](const Coordinates & id)
    {
        // We are computing the theoretical starting input starting points
        // (may be negative / past the edge when the receptive field overlaps a border)
        const int in_w_start_t = static_cast<int>(id.y()) * conv_stride_w - conv_pad_left;
        const int in_h_start_t = static_cast<int>(id.z()) * conv_stride_h - conv_pad_top;
        const int in_w_end_t   = in_w_start_t + kernel_dim_w;
        const int in_h_end_t   = in_h_start_t + kernel_dim_h;

        // We are computing the valid initial and ending input points by checking the borders
        const int in_w_start = std::max(in_w_start_t, 0);
        const int in_h_start = std::max(in_h_start_t, 0);
        const int in_w_end   = std::min(in_w_end_t, input_dim_w);
        const int in_h_end   = std::min(in_h_end_t, input_dim_h);

        // We use the input points to select the valid weight points to use
        // (clipped input region maps 1:1 onto a clipped weight region, so
        // out-of-border taps are simply skipped — implicit zero padding).
        const int wei_w_start = in_w_start - in_w_start_t;
        const int wei_h_start = in_h_start - in_h_start_t;
        const int wei_w_end   = kernel_dim_w - (in_w_end_t - in_w_end);
        const int wei_h_end   = kernel_dim_h - (in_h_end_t - in_h_end);

        const int index_c_end         = weights->info()->dimension(0);
        // Base input pointer for this batch (id[3] selects N).
        const T *const in_ptr_start   = reinterpret_cast<const T *>(src->buffer() + src->info()->offset_first_element_in_bytes()) + id[3] * input_stride_n;

        execute_window_loop(window_w, [&](const Coordinates & id_w)
        {
            /*
             * This is the loop in the weights, and it goes along N (the batches)
             * As a reminder, the batches of the weights are translated into the
             * channels of the output
             */
            const T *const weights_ptr_start = reinterpret_cast<const T *>(wei.ptr());
            uint8_t       *out_ptr           = out.ptr() + id_w[3] * output_stride_c;

            // Accumulate the full dot product for one output value.
            T out_temp = static_cast<T>(0);
            for(int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end; ++index_wei_h, ++index_in_h)
            {
                const T *const in_ptr_row      = in_ptr_start + index_in_h * input_stride_h;
                const T *const weights_ptr_row = weights_ptr_start + index_wei_h * kernel_stride_h;
                for(int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end; ++index_wei_w, ++index_in_w)
                {
                    const T *in_ptr_mover      = in_ptr_row + index_in_w * input_stride_w;
                    const T *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w;
                    int      index_c           = 0;
                    vector_type out_temp_vec   = wrapper::vdup_n(static_cast<T>(0), tag_type());
                    // Vectorized multiply-accumulate over the channel dimension...
                    for(; index_c <= index_c_end - num_elems_read_per_iteration; index_c += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration, weights_ptr_mover += num_elems_read_per_iteration)
                    {
                        const auto src_vec = wrapper::vloadq(in_ptr_mover);
                        const auto w_vec   = wrapper::vloadq(weights_ptr_mover);
                        out_temp_vec       = wrapper::vmla(out_temp_vec, w_vec, src_vec);
                    }
                    out_temp += vreduce(out_temp_vec);
                    // ...then a scalar tail for the channels that don't fill a register.
                    for(; index_c < index_c_end; ++index_c, ++in_ptr_mover, ++weights_ptr_mover)
                    {
                        const auto src_val = *(in_ptr_mover);
                        const auto w_val   = *(weights_ptr_mover);
                        out_temp += src_val * w_val;
                    }
                }
            }
            *(reinterpret_cast<T *>(out_ptr)) = out_temp;
        },
        wei);
    },
    out);
}
1221
Manuel Bottinib4bb6a02021-05-24 16:01:32 +01001222BorderSize CpuDirectConv2dKernel::border_size() const
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001223{
1224 return _border_size;
1225}
1226
Manuel Bottinib4bb6a02021-05-24 16:01:32 +01001227void CpuDirectConv2dKernel::configure(ITensorInfo *src, ITensorInfo *weights, ITensorInfo *dst, const PadStrideInfo &conv_info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001228{
Manuel Bottini327225d2021-04-13 13:09:30 +01001229 ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001230
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001231 _conv_info = conv_info;
Manuel Bottini327225d2021-04-13 13:09:30 +01001232 _data_layout = src->data_layout();
1233 _kernel_size = weights->dimension(get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH));
Michalis Spyrou621965e2018-01-08 17:11:26 +00001234
1235 const unsigned int conv_pad_left = conv_info.pad_left();
1236 const unsigned int conv_pad_top = conv_info.pad_top();
1237 const unsigned int conv_pad_right = conv_info.pad_right();
1238 const unsigned int conv_pad_bottom = conv_info.pad_bottom();
Manuel Bottinica62c6f2021-03-23 11:50:34 +00001239 if(_data_layout == DataLayout::NCHW)
Manuel Bottini87350f42020-09-15 13:03:34 +01001240 {
1241 _border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);
1242 }
1243 else
1244 {
1245 _border_size = BorderSize(0);
1246 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001247
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001248 // Get convolved dimensions
Manuel Bottini327225d2021-04-13 13:09:30 +01001249 TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *weights, conv_info);
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001250
Manuel Bottini327225d2021-04-13 13:09:30 +01001251 DataType data_type = src->data_type();
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001252
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001253 // Output auto inizialitation if not yet initialized
Manuel Bottini327225d2021-04-13 13:09:30 +01001254 auto_init_if_empty(*dst, output_shape, 1, data_type);
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001255
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001256 // Perform validation step
Manuel Bottini327225d2021-04-13 13:09:30 +01001257 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, dst, conv_info));
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001258
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001259 // Configure kernel window
Manuel Bottini327225d2021-04-13 13:09:30 +01001260 auto win_config = validate_and_configure_window(src, weights, dst, conv_info, _num_weight_elems_read_per_row,
Georgios Pinitas0223a782017-12-12 11:44:44 +00001261 _num_elems_read_per_iteration, _num_elems_written_per_iteration, _border_size);
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001262 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
Manuel Bottini327225d2021-04-13 13:09:30 +01001263 ICpuKernel::configure(win_config.second);
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001264}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001265
Manuel Bottinib4bb6a02021-05-24 16:01:32 +01001266Status CpuDirectConv2dKernel::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const PadStrideInfo &conv_info)
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001267{
1268 unsigned int num_weight_elems_read_per_row = 0;
1269 unsigned int num_elems_read_per_iteration = 0;
1270 unsigned int num_elems_written_per_iteration = 0;
Georgios Pinitas15997872018-02-19 13:58:22 +00001271 BorderSize border_size = {};
Manuel Bottini327225d2021-04-13 13:09:30 +01001272 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, dst, conv_info));
1273 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(src->clone().get(),
Georgios Pinitas0223a782017-12-12 11:44:44 +00001274 weights->clone().get(),
Manuel Bottini327225d2021-04-13 13:09:30 +01001275 dst->clone().get(),
Georgios Pinitas0223a782017-12-12 11:44:44 +00001276 conv_info,
1277 num_weight_elems_read_per_row,
1278 num_elems_read_per_iteration,
1279 num_elems_written_per_iteration,
1280 border_size)
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001281 .first);
Georgios Pinitas898a8062017-09-12 19:19:12 +01001282
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001283 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001284}
1285
Manuel Bottinib4bb6a02021-05-24 16:01:32 +01001286void CpuDirectConv2dKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001287{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001288 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001289 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
Manuel Bottini327225d2021-04-13 13:09:30 +01001290 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001291
Manuel Bottini327225d2021-04-13 13:09:30 +01001292 auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0);
1293 auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
1294 auto dst = tensors.get_tensor(TensorType::ACL_DST);
1295 const int kernel_size = weights->info()->dimension(get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001296
Manuel Bottinica62c6f2021-03-23 11:50:34 +00001297 if(_data_layout == DataLayout::NCHW)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001298 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001299 switch(kernel_size)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001300 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001301 case 1:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001302 {
Manuel Bottini327225d2021-04-13 13:09:30 +01001303 switch(src->info()->data_type())
Giorgio Arenac0f54432018-03-16 14:02:34 +00001304 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001305 case DataType::F32:
Manuel Bottini327225d2021-04-13 13:09:30 +01001306 convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info);
Giorgio Arenac0f54432018-03-16 14:02:34 +00001307 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001308#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +00001309 case DataType::F16:
Manuel Bottini327225d2021-04-13 13:09:30 +01001310 convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info);
Giorgio Arenac0f54432018-03-16 14:02:34 +00001311 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001312#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +00001313 default:
1314 ARM_COMPUTE_ERROR("Data type not supported");
1315 break;
1316 }
1317 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001318 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001319 case 3:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001320 {
Manuel Bottini327225d2021-04-13 13:09:30 +01001321 switch(src->info()->data_type())
Giorgio Arenac0f54432018-03-16 14:02:34 +00001322 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001323 case DataType::F32:
Manuel Bottini327225d2021-04-13 13:09:30 +01001324 convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info);
Giorgio Arenac0f54432018-03-16 14:02:34 +00001325 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001326#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +00001327 case DataType::F16:
Manuel Bottini327225d2021-04-13 13:09:30 +01001328 convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info);
Giorgio Arenac0f54432018-03-16 14:02:34 +00001329 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001330#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +00001331 default:
1332 ARM_COMPUTE_ERROR("Data type not supported");
1333 break;
1334 }
1335 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001336 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001337 case 5:
Pablo Tello06da39d2017-08-10 15:10:40 +01001338 {
Manuel Bottini327225d2021-04-13 13:09:30 +01001339 switch(src->info()->data_type())
Giorgio Arenac0f54432018-03-16 14:02:34 +00001340 {
1341 case DataType::F32:
Manuel Bottini327225d2021-04-13 13:09:30 +01001342 convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info);
Giorgio Arenac0f54432018-03-16 14:02:34 +00001343 break;
1344 default:
1345 ARM_COMPUTE_ERROR("Data type not supported");
1346 break;
1347 }
1348 break;
Pablo Tello06da39d2017-08-10 15:10:40 +01001349 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001350 default:
1351 {
1352 ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
1353 break;
1354 }
1355 }
1356 }
1357 else
1358 {
Manuel Bottini327225d2021-04-13 13:09:30 +01001359 switch(src->info()->data_type())
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001360 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001361 case DataType::F32:
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001362 {
Manuel Bottini327225d2021-04-13 13:09:30 +01001363 if(have_zero_x_internal_padding(src->info(), weights->info()))
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001364 {
Manuel Bottini327225d2021-04-13 13:09:30 +01001365 convolve_nhwc_optimized<float>(window, src, weights, dst);
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001366 }
1367 else
1368 {
Manuel Bottini327225d2021-04-13 13:09:30 +01001369 convolve_nhwc<float>(window, src, weights, dst);
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001370 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001371 break;
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001372 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001373 default:
1374 ARM_COMPUTE_ERROR("Data type not supported");
1375 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001376 }
1377 }
1378}
Manuel Bottinib4bb6a02021-05-24 16:01:32 +01001379const char *CpuDirectConv2dKernel::name() const
Manuel Bottini327225d2021-04-13 13:09:30 +01001380{
1381 return "CpuDirectConvolutionLayerKernel";
1382}
1383} // namespace kernels
1384} // namespace cpu
Sheri Zhangac6499a2021-02-10 15:32:38 +00001385} // namespace arm_compute