blob: 68de9803eb83bfcadfe33960e652f45493ee1660 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Sheri Zhangac6499a2021-02-10 15:32:38 +00002 * Copyright (c) 2017-2021 Arm Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Georgios Pinitas7891a732021-08-20 21:39:25 +010024#include "src/cpu/kernels/CpuDirectConv2dKernel.h"
Georgios Pinitasddb93bb2020-10-02 16:38:59 +010025
26#include "src/core/NEON/kernels/detail/NEDirectConvolutionDetail.h"
27#include "src/core/NEON/wrapper/wrapper.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010028
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/Error.h"
30#include "arm_compute/core/Helpers.h"
31#include "arm_compute/core/IAccessWindow.h"
32#include "arm_compute/core/ITensor.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010033#include "arm_compute/core/Types.h"
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +010034#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010035#include "arm_compute/core/Validate.h"
Giorgio Arenac0f54432018-03-16 14:02:34 +000036#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010037#include "src/core/AccessWindowStatic.h"
38#include "src/core/CPP/Validate.h"
Georgios Pinitasddb93bb2020-10-02 16:38:59 +010039#include "src/core/NEON/NEFixedPoint.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010040#include "src/core/helpers/AutoConfiguration.h"
41#include "src/core/helpers/WindowHelpers.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010042
43#include <algorithm>
Anthony Barbier6ff3b192017-09-04 18:44:23 +010044
Michalis Spyrou7362f0d2017-10-18 17:58:22 +010045using namespace arm_compute::detail;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010046
Manuel Bottini87350f42020-09-15 13:03:34 +010047namespace arm_compute
48{
Manuel Bottini327225d2021-04-13 13:09:30 +010049namespace cpu
50{
51namespace kernels
52{
Anthony Barbier6ff3b192017-09-04 18:44:23 +010053namespace
54{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Load eight half-precision values starting at @p in, keeping one element out of every @p stridex.
 *
 * Only stridex = 1, 2 and 3 are specialized; any other value has no definition.
 */
template <unsigned int stridex>
float16x8_t internal_vld1q(const float16_t *in);

template <>
float16x8_t internal_vld1q<1>(const float16_t *in)
{
    // Contiguous elements in[0..7].
    const float16x8_t contiguous = vld1q_f16(in);
    return contiguous;
}

template <>
float16x8_t internal_vld1q<2>(const float16_t *in)
{
    // De-interleave 16 elements and keep the even ones: in[0], in[2], ..., in[14].
    return vld2q_f16(in).val[0];
}

template <>
float16x8_t internal_vld1q<3>(const float16_t *in)
{
    // De-interleave 24 elements and keep every third one: in[0], in[3], ..., in[21].
    return vld3q_f16(in).val[0];
}

/** Broadcast @p v to all eight lanes. */
inline float16x8_t internal_vdupq_n(float16_t v)
{
    const float16x8_t splat = vdupq_n_f16(v);
    return splat;
}

/** Store the eight lanes of @p v to @p p. */
inline void internal_vst1q(float16_t *p, const float16x8_t &v)
{
    vst1q_f16(p, v);
}

/** Lane-wise product x * y. */
float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y)
{
    const float16x8_t product = vmulq_f16(x, y);
    return product;
}

/** Lane-wise x + y * z, computed as a separate multiply followed by an add. */
inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z)
{
    const float16x8_t product = vmulq_f16(y, z);
    return vaddq_f16(x, product);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tello0d176142017-07-06 16:43:14 +010099
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100100template <unsigned int stridex>
101float32x4_t internal_vld1q(const float *in);
102
103template <>
104float32x4_t internal_vld1q<1>(const float *in)
105{
106 return vld1q_f32(in);
107}
108
109template <>
110float32x4_t internal_vld1q<2>(const float *in)
111{
112 const float32x4x2_t tmp = vld2q_f32(in);
113 return tmp.val[0];
114}
115
116template <>
117float32x4_t internal_vld1q<3>(const float *in)
118{
119 const float32x4x3_t tmp = vld3q_f32(in);
120 return tmp.val[0];
121}
122
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100123inline float32x4_t internal_vdupq_n(float v)
124{
125 return vdupq_n_f32(v);
126}
127
128inline void internal_vst1q(float *p, const float32x4_t &v)
129{
130 vst1q_f32(p, v);
131}
132
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100133float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y)
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100134{
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100135 return vmulq_f32(x, y);
136}
137
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100138inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z)
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100139{
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100140 return vmlaq_f32(x, y, z);
141}
142
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000143constexpr int small_tensor_size_optim = 8;
144inline bool run_optim_small_tensor_info(const ITensorInfo *t)
145{
146 return t->dimension(Window::DimX) <= small_tensor_size_optim && t->dimension(Window::DimY) <= small_tensor_size_optim;
147}
148
Pablo Telloc09314a2017-09-21 13:59:14 +0100149inline bool run_optim_small_tensor(const ITensor *t)
150{
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000151 return run_optim_small_tensor_info(t->info());
Pablo Telloc09314a2017-09-21 13:59:14 +0100152}
153
154// Optimized convolver for 1x1 kernels used only where input width and height are both <= 8
155// For big Z as in Input=7x7x832, this implementation is faster than the general code becuase it doesn't need to
Michele Di Giorgio33f41fa2021-03-09 14:09:08 +0000156// store intermidiate results in memory. Temporary results are stored in SIMD registers directly and then written to the output buffer.
Pablo Telloc09314a2017-09-21 13:59:14 +0100157template <unsigned int stridex>
158class convolver_w1x1_i8x8_f32
159{
160public:
Manuel Bottini327225d2021-04-13 13:09:30 +0100161 static void convolve(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
Pablo Telloc09314a2017-09-21 13:59:14 +0100162 {
Manuel Bottini327225d2021-04-13 13:09:30 +0100163 ARM_COMPUTE_ERROR_ON(src->info()->dimension(Window::DimX) > small_tensor_size_optim);
164 ARM_COMPUTE_ERROR_ON(src->info()->dimension(Window::DimY) > small_tensor_size_optim);
Pablo Telloc09314a2017-09-21 13:59:14 +0100165
Manuel Bottini327225d2021-04-13 13:09:30 +0100166 const int input_stride_x = src->info()->strides_in_bytes().x();
167 const int input_stride_y = src->info()->strides_in_bytes().y();
168 const int input_stride_z = src->info()->strides_in_bytes().z();
169 const int output_stride_y = dst->info()->strides_in_bytes().y();
170 const int output_stride_z = dst->info()->strides_in_bytes().z();
Pablo Telloc09314a2017-09-21 13:59:14 +0100171 const int kernel_stride_z = weights->info()->strides_in_bytes().z();
172 const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
Manuel Bottini327225d2021-04-13 13:09:30 +0100173 const int output_h = dst->info()->dimension(1);
Pablo Telloc09314a2017-09-21 13:59:14 +0100174 const int range_z = window.z().end() - window.z().start();
175 const int kernel_depth = weights->info()->dimension(Window::DimZ);
176 const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
Georgios Pinitas15997872018-02-19 13:58:22 +0000177 const unsigned int conv_pad_left = conv_info.pad_left();
178 const unsigned int conv_pad_top = conv_info.pad_top();
Pablo Telloc09314a2017-09-21 13:59:14 +0100179
180 // setup output window for the iterator
181 Window window_out = window;
Manuel Bottini327225d2021-04-13 13:09:30 +0100182 window_out.set(Window::DimX, Window::Dimension(0, dst->info()->dimension(Window::DimX), dst->info()->dimension(Window::DimX)));
183 window_out.set(Window::DimY, Window::Dimension(0, dst->info()->dimension(Window::DimY), dst->info()->dimension(Window::DimY)));
Pablo Telloc09314a2017-09-21 13:59:14 +0100184 window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));
185
186 // setup input window for the iterator
187 Window window_in = window;
188 // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
189 window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
190 window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
191 window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
192
193 Window window_k = calculate_max_window(*weights->info(), Steps(1u));
Manuel Bottini327225d2021-04-13 13:09:30 +0100194 Iterator out(dst, window_out);
195 Iterator in(src, window_in);
Pablo Telloc09314a2017-09-21 13:59:14 +0100196 Iterator k(weights, window_k);
197
198 const uint8_t *k_ptr = k.ptr();
199
200 execute_window_loop(window_out, [&](const Coordinates & id)
201 {
Michalis Spyroubcfd09a2019-05-01 13:03:59 +0100202 const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
203 uint8_t *out_ptr = out.ptr();
204 int ih = 0;
205 int oh = 0;
206 std::array<float32x4_t, 8> accum0 = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
207 std::array<float32x4_t, 8> accum1 = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
Pablo Telloc09314a2017-09-21 13:59:14 +0100208 for(int oz = 0; oz < range_z; ++oz)
209 {
210 accum0[0] = accum0[1] = accum0[2] = accum0[3] = accum0[4] = accum0[5] = accum0[6] = accum0[7] = vdupq_n_f32(0.f);
211 accum1[0] = accum1[1] = accum1[2] = accum1[3] = accum1[4] = accum1[5] = accum1[6] = accum1[7] = vdupq_n_f32(0.f);
212 auto p_out_base = out_ptr + oz * output_stride_z;
213 for(int p = 0; p < kernel_depth; ++p)
214 {
215 const auto k_val = reinterpret_cast<const float *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
216 const auto vk0 = internal_vdupq_n(*k_val);
217 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
218 {
219 const int offset_xy = ih * input_stride_y;
220 auto in_val = reinterpret_cast<const float *>(input_ptr + p * input_stride_z + offset_xy);
221 auto v_in0 = internal_vld1q<stridex>(in_val);
222 auto v_in1 = internal_vld1q<stridex>(in_val + 4);
223 accum0[oh] = vmlaq_f32(accum0[oh], vk0, v_in0);
224 accum1[oh] = vmlaq_f32(accum1[oh], vk0, v_in1);
225 }
226 }
227 for(oh = 0; oh < output_h; ++oh)
228 {
229 auto p_out = reinterpret_cast<float *>(p_out_base + oh * output_stride_y);
230 vst1q_f32(p_out, accum0[oh]);
231 vst1q_f32(p_out + 4, accum1[oh]);
232 }
233 }
234 },
235 in, out);
236 }
237};
238
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100239template <typename T1, typename T2, unsigned int stridex>
240class convolver_1x1
241{
242public:
243 static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
Manuel Bottini327225d2021-04-13 13:09:30 +0100244 const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100245 {
Manuel Bottini327225d2021-04-13 13:09:30 +0100246 const int input_stride_x = src->info()->strides_in_bytes().x();
247 const int input_stride_y = src->info()->strides_in_bytes().y();
248 const int input_stride_z = src->info()->strides_in_bytes().z();
249 const int output_stride_y = dst->info()->strides_in_bytes().y();
250 const int output_stride_z = dst->info()->strides_in_bytes().z();
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100251 const int kernel_stride_z = weights->info()->strides_in_bytes().z();
252 const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
Manuel Bottini327225d2021-04-13 13:09:30 +0100253 const int output_w = dst->info()->dimension(0);
254 const int output_h = dst->info()->dimension(1);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100255 const int range_z = window.z().end() - window.z().start();
256 const int kernel_depth = weights->info()->dimension(Window::DimZ);
257 const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
258 const unsigned int conv_pad_left = conv_info.pad_left();
259 const unsigned int conv_pad_top = conv_info.pad_top();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100260
261 // setup output window for the iterator
262 Window window_out = window;
Manuel Bottini327225d2021-04-13 13:09:30 +0100263 window_out.set(Window::DimX, Window::Dimension(0, dst->info()->dimension(Window::DimX), dst->info()->dimension(Window::DimX)));
264 window_out.set(Window::DimY, Window::Dimension(0, dst->info()->dimension(Window::DimY), dst->info()->dimension(Window::DimY)));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100265 window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));
266
267 // setup input window for the iterator
268 Window window_in = window;
269 // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
270 window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
271 window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
272 window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
273
Pablo Tellof87cc7f2017-07-26 10:28:40 +0100274 Window window_k = calculate_max_window(*weights->info(), Steps(1u));
Manuel Bottini327225d2021-04-13 13:09:30 +0100275 Iterator out(dst, window_out);
276 Iterator in(src, window_in);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100277 Iterator k(weights, window_k);
278
279 const uint8_t *k_ptr = k.ptr();
280
281 execute_window_loop(window_out, [&](const Coordinates & id)
282 {
283 /*
284 For a detailed explanation on how the algorithm works refer to template <> class convolver_3x3<1>
285 */
Georgios Pinitas15997872018-02-19 13:58:22 +0000286 const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100287 uint8_t *out_ptr = out.ptr();
288 int ih = 0;
289 int oh = 0;
290 for(int oz = 0; oz < range_z; ++oz)
291 {
292 auto p_out_base = out_ptr + oz * output_stride_z;
293 // Step 1
294 {
295 const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
296 const auto vk = internal_vdupq_n(*k_val);
297 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
298 {
299 const int offset_xy = ih * input_stride_y;
300 auto in_val = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
301 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
302 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
303 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100304 internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val)));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100305 }
306 }
307 }
Pablo Telloc09314a2017-09-21 13:59:14 +0100308
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100309 // Step 2
310 for(int p = 1; p < kernel_depth; ++p)
311 {
312 const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
313 const auto vk = internal_vdupq_n(*k_val);
314 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
315 {
316 const int offset_xy = ih * input_stride_y;
317 auto in_val = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
318 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
319 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
320 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100321 internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val)));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100322 }
323 }
324 }
325 }
326 },
327 in, out);
328 }
329};
330
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100331template <unsigned int stridex>
Pablo Tello06da39d2017-08-10 15:10:40 +0100332float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100333 const float *m0, const float *m1, const float *m2, const float *m3, const float *m4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100334
335inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
336{
337 const float32x4x3_t m00 =
338 {
339 {
340 vld1q_dup_f32(m0),
341 vld1q_dup_f32(m1),
342 vld1q_dup_f32(m2)
343 }
344 };
345 return m00;
346}
347
348inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4)
349{
350 const float32x4x2_t m00 =
351 {
352 {
353 vld1q_dup_f32(m3),
354 vld1q_dup_f32(m4)
355 }
356 };
357 return m00;
358}
359
360inline float32x4x3_t load_input(const float *const in)
361{
362 const float32x4x3_t vin =
363 {
364 {
365 vld1q_f32(in),
366 vld1q_f32(in + 4),
367 vld1q_f32(in + 8)
368 }
369 };
370 return vin;
371}
372
373template <>
374inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100375 const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
Pablo Tello06da39d2017-08-10 15:10:40 +0100376{
Pablo Tello06da39d2017-08-10 15:10:40 +0100377 const float32x4x3_t vin0 = load_input(in_0);
378 const float32x4x3_t vin1 = load_input(in_1);
379 const float32x4x3_t vin2 = load_input(in_2);
380 const float32x4x3_t vin3 = load_input(in_3);
381 const float32x4x3_t vin4 = load_input(in_4);
382 const float32x4x3_t m00 = load_matrix_hi(m0, 1 + m0, 2 + m0);
383 const float32x4x2_t m01 = load_matrix_lo(3 + m0, 4 + m0);
384 const float32x4x3_t m10 = load_matrix_hi(m1, 1 + m1, 2 + m1);
385 const float32x4x2_t m11 = load_matrix_lo(3 + m1, 4 + m1);
386 const float32x4x3_t m20 = load_matrix_hi(m2, 1 + m2, 2 + m2);
387 const float32x4x2_t m21 = load_matrix_lo(3 + m2, 4 + m2);
388 const float32x4x3_t m30 = load_matrix_hi(m3, 1 + m3, 2 + m3);
389 const float32x4x2_t m31 = load_matrix_lo(3 + m3, 4 + m3);
390 const float32x4x3_t m40 = load_matrix_hi(m4, 1 + m4, 2 + m4);
391 const float32x4x2_t m41 = load_matrix_lo(3 + m4, 4 + m4);
392
393 float32x4x2_t out =
394 {
395 {
396 vmulq_f32(vin0.val[0], m00.val[0]),
397 vmulq_f32(vin0.val[1], m00.val[0])
398 }
399 };
400
401 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]);
402 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]);
403 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]);
404 out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]);
405
406 out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]);
407 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]);
408 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]);
409 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 3), m11.val[0]);
410 out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], m11.val[1]);
411
412 out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]);
413 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]);
414 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]);
415 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]);
416 out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]);
417
418 out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]);
419 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]);
420 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]);
421 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]);
422 out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]);
423
424 out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]);
425 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]);
426 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]);
427 out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]);
428 out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]);
429
430 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]);
431 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]);
432 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]);
433 out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]);
434
435 out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]);
436 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]);
437 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]);
438 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]);
439 out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]);
440
441 out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]);
442 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]);
443 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]);
444 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]);
445 out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]);
446
447 out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]);
448 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]);
449 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]);
450 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]);
451 out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]);
452
453 out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]);
454 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]);
455 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]);
456 out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]);
457 out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]);
458
459 return out;
460}
461
462template <>
463inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100464 const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
Pablo Tello06da39d2017-08-10 15:10:40 +0100465{
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100466 float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100467 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
468 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
469 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
470 return out;
471}
472
473template <>
474inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100475 const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
Pablo Tello06da39d2017-08-10 15:10:40 +0100476{
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100477 float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100478 out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
479 return out;
480}
481
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100482template <typename T1, typename T2, unsigned int stridex>
483class convolver_3x3
484{
485public:
486 static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
Manuel Bottini327225d2021-04-13 13:09:30 +0100487 const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100488 {
489 ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
Manuel Bottini327225d2021-04-13 13:09:30 +0100490 const int input_stride_x = src->info()->strides_in_bytes().x();
491 const int input_stride_y = src->info()->strides_in_bytes().y();
492 const int input_stride_z = src->info()->strides_in_bytes().z();
493 const int output_stride_y = dst->info()->strides_in_bytes().y();
494 const int output_stride_z = dst->info()->strides_in_bytes().z();
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100495 const int kernel_stride_x = weights->info()->strides_in_bytes().x();
496 const int kernel_stride_y = weights->info()->strides_in_bytes().y();
497 const int kernel_stride_z = weights->info()->strides_in_bytes().z();
498 const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
Manuel Bottini327225d2021-04-13 13:09:30 +0100499 const int output_w = dst->info()->dimension(0);
500 const int output_h = dst->info()->dimension(1);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100501 const int num_planes_z = window.z().end() - window.z().start();
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000502 const int delta_input = get_input_num_elems_processed(num_elems_written_per_iteration, stridex);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100503 const int kernel_depth = weights->info()->dimension(Window::DimZ);
504 const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
505 const unsigned int conv_pad_left = conv_info.pad_left();
506 const unsigned int conv_pad_top = conv_info.pad_top();
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100507
508 // setup output window for the iterator
509 Window window_out = window;
Manuel Bottini327225d2021-04-13 13:09:30 +0100510 window_out.set(Window::DimX, Window::Dimension(0, dst->info()->dimension(Window::DimX), dst->info()->dimension(Window::DimX)));
511 window_out.set(Window::DimY, Window::Dimension(0, dst->info()->dimension(Window::DimY), dst->info()->dimension(Window::DimY)));
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100512 window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));
513
514 // setup input window for the iterator
515 Window window_in = window;
516 // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
517 window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
518 window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
519 window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
520
521 Window window_k = calculate_max_window(*weights->info(), Steps(1u));
522
Manuel Bottini327225d2021-04-13 13:09:30 +0100523 Iterator out(dst, window_out);
524 Iterator in(src, window_in);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100525 Iterator k(weights, window_k);
526
527 const uint8_t *k_ptr = k.ptr();
528
529 execute_window_loop(window_out, [&](const Coordinates & id)
530 {
Georgios Pinitas15997872018-02-19 13:58:22 +0000531 const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100532 uint8_t *out_ptr = out.ptr();
533 int ih = 0;
534 int oh = 0;
535 /*
536 Each thread executing this kernel computes one or more output's volume planes.
537
538 Let's say the 3rd dimension of the output volume is 32, the first thread will compute the output for Z = [0,7], the second thread will compute the output for Z = [8,15],
539 the third thread [16,24] and the fourth thread [25,31].
540
541 The algorithm outer loop iterates over Z, P, Y, X where P is the depth/3rd dimension of each kernel. This order is not arbitrary, the main benefit of this
Anthony Barbiere5007472017-10-27 15:01:44 +0100542 is that we setup the neon registers containing the kernel's values only once and then compute each XY using the preloaded registers as opposed as doing this for every XY value.
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100543
544 The algorithm does not require allocating any additional memory amd computes the results directly in-place in two stages:
545 1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
546 2) Convolve the remaining planes and accumulate the results in the output's plane which has been initialized in step 1.
547 */
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100548 for(int oz = 0; oz < num_planes_z; ++oz)
549 {
Pablo Tello0d176142017-07-06 16:43:14 +0100550 const int zoffset = id.z() + oz;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100551 uint8_t *p_out_base = out_ptr + oz * output_stride_z;
552 // Step 1
553 {
Pablo Tello0d176142017-07-06 16:43:14 +0100554 const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
555 const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
556 const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100557 const auto vk_r0 = load_matrix_row(ptr_k_r0);
558 const auto vk_r1 = load_matrix_row(ptr_k_r1);
559 const auto vk_r2 = load_matrix_row(ptr_k_r2);
560 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
561 {
562 auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
563 auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
564 auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
565 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
566 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
567 in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
568 {
Georgios Pinitasa26e1662020-03-04 15:31:25 +0000569 convolve_3x3<false>(in_top, in_mid, in_low, p_out, vk_r0, vk_r1, vk_r2, stridex);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100570 }
571 }
572 }
573 // Step 2
574 for(int p = 1; p < kernel_depth; ++p)
575 {
Pablo Tello06da39d2017-08-10 15:10:40 +0100576 const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w;
577 const uint8_t *input_base = input_ptr + p * input_stride_z;
578 const auto ptr_k_r0 = reinterpret_cast<const T1 *>(ptr_k_base);
579 const auto ptr_k_r1 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y);
580 const auto ptr_k_r2 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y * 2);
581 const auto vk_r0 = load_matrix_row(ptr_k_r0);
582 const auto vk_r1 = load_matrix_row(ptr_k_r1);
583 const auto vk_r2 = load_matrix_row(ptr_k_r2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100584 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
585 {
Pablo Tello06da39d2017-08-10 15:10:40 +0100586 auto in_top = reinterpret_cast<const T1 *>(input_base + (ih + 0) * input_stride_y);
587 auto in_mid = reinterpret_cast<const T1 *>(input_base + (ih + 1) * input_stride_y);
588 auto in_low = reinterpret_cast<const T1 *>(input_base + (ih + 2) * input_stride_y);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100589 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
590 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
591 in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
592 {
Georgios Pinitasa26e1662020-03-04 15:31:25 +0000593 convolve_3x3<true>(in_top, in_mid, in_low, p_out, vk_r0, vk_r1, vk_r2, stridex);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100594 }
595 }
596 }
597 }
598 },
599 in, out);
600 }
601};
602
Pablo Tello06da39d2017-08-10 15:10:40 +0100603template <typename T1, typename T2, unsigned int stridex>
604class convolver_5x5
605{
606public:
607 static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
Manuel Bottini327225d2021-04-13 13:09:30 +0100608 const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
Pablo Tello06da39d2017-08-10 15:10:40 +0100609 {
610 ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
Manuel Bottini327225d2021-04-13 13:09:30 +0100611 const int input_stride_x = src->info()->strides_in_bytes().x();
612 const int input_stride_y = src->info()->strides_in_bytes().y();
613 const int input_stride_z = src->info()->strides_in_bytes().z();
614 const int output_stride_y = dst->info()->strides_in_bytes().y();
615 const int output_stride_z = dst->info()->strides_in_bytes().z();
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100616 const int kernel_stride_x = weights->info()->strides_in_bytes().x();
617 const int kernel_stride_y = weights->info()->strides_in_bytes().y();
618 const int kernel_stride_z = weights->info()->strides_in_bytes().z();
619 const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
Manuel Bottini327225d2021-04-13 13:09:30 +0100620 const int output_w = dst->info()->dimension(0);
621 const int output_h = dst->info()->dimension(1);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100622 const int num_planes_z = window.z().end() - window.z().start();
Michele Di Giorgio13ec5f02020-01-02 12:11:13 +0000623 const int delta_input = get_input_num_elems_processed(num_elems_written_per_iteration, stridex);
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100624 const int kernel_depth = weights->info()->dimension(Window::DimZ);
625 const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
626 const unsigned int conv_pad_left = conv_info.pad_left();
627 const unsigned int conv_pad_top = conv_info.pad_top();
Pablo Tello06da39d2017-08-10 15:10:40 +0100628
629 // setup output window for the iterator
630 Window window_out = window;
Manuel Bottini327225d2021-04-13 13:09:30 +0100631 window_out.set(Window::DimX, Window::Dimension(0, dst->info()->dimension(Window::DimX), dst->info()->dimension(Window::DimX)));
632 window_out.set(Window::DimY, Window::Dimension(0, dst->info()->dimension(Window::DimY), dst->info()->dimension(Window::DimY)));
Pablo Tello06da39d2017-08-10 15:10:40 +0100633 window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));
634
635 // setup input window for the iterator
636 Window window_in = window;
637 // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
638 window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
639 window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
640 window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
641
642 Window window_k = calculate_max_window(*weights->info(), Steps(1u));
643
Manuel Bottini327225d2021-04-13 13:09:30 +0100644 Iterator out(dst, window_out);
645 Iterator in(src, window_in);
Pablo Tello06da39d2017-08-10 15:10:40 +0100646 Iterator k(weights, window_k);
647
648 const uint8_t *k_ptr = k.ptr();
649
650 execute_window_loop(window_out, [&](const Coordinates & id)
651 {
Georgios Pinitas15997872018-02-19 13:58:22 +0000652 const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
Pablo Tello06da39d2017-08-10 15:10:40 +0100653 uint8_t *out_ptr = out.ptr();
654 int ih = 0;
655 int oh = 0;
656 for(int oz = 0; oz < num_planes_z; ++oz)
657 {
658 const int zoffset = id.z() + oz;
659 uint8_t *p_out_base = out_ptr + oz * output_stride_z;
660 // Step 1
661 {
662 const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
663 const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
664 const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
665 const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
666 const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
667 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
668 {
669 auto in_0 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
670 auto in_1 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
671 auto in_2 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
672 auto in_3 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
673 auto in_4 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
674 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
675 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
676 in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
677 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100678 auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100679 store_results<stridex>(p_out, vres);
680 }
681 }
682 }
683 // Step 2
684 for(int p = 1; p < kernel_depth; ++p)
685 {
686 const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
687 const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
688 const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
689 const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
690 const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
691
692 for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
693 {
694 auto in_0 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
695 auto in_1 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
696 auto in_2 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
697 auto in_3 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
698 auto in_4 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
699 auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
700 for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
701 in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
702 {
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100703 auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100704 accumulate_results<stridex>(p_out, vres);
705 }
706 }
707 }
708 }
709 },
710 in, out);
711 }
712};
713
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100714template <typename T1, typename T2>
715inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
Manuel Bottini327225d2021-04-13 13:09:30 +0100716 const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100717{
718 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
719 switch(conv_stride_x)
720 {
721 case 1:
Manuel Bottini327225d2021-04-13 13:09:30 +0100722 convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100723 break;
724 case 2:
Manuel Bottini327225d2021-04-13 13:09:30 +0100725 convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100726 break;
727 case 3:
Manuel Bottini327225d2021-04-13 13:09:30 +0100728 convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100729 break;
730 default:
731 ARM_COMPUTE_ERROR("Not implemented");
732 }
733}
734
Pablo Telloc09314a2017-09-21 13:59:14 +0100735template <>
736inline void convolve_1x1<float, float>(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
Manuel Bottini327225d2021-04-13 13:09:30 +0100737 const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
Pablo Telloc09314a2017-09-21 13:59:14 +0100738{
739 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
Manuel Bottini327225d2021-04-13 13:09:30 +0100740 if(run_optim_small_tensor(src))
Pablo Telloc09314a2017-09-21 13:59:14 +0100741 {
742 switch(conv_stride_x)
743 {
744 case 1:
Manuel Bottini327225d2021-04-13 13:09:30 +0100745 convolver_w1x1_i8x8_f32<1>::convolve(window, src, weights, dst, conv_info);
Pablo Telloc09314a2017-09-21 13:59:14 +0100746 break;
747 case 2:
Manuel Bottini327225d2021-04-13 13:09:30 +0100748 convolver_w1x1_i8x8_f32<2>::convolve(window, src, weights, dst, conv_info);
Pablo Telloc09314a2017-09-21 13:59:14 +0100749 break;
750 case 3:
Manuel Bottini327225d2021-04-13 13:09:30 +0100751 convolver_w1x1_i8x8_f32<3>::convolve(window, src, weights, dst, conv_info);
Pablo Telloc09314a2017-09-21 13:59:14 +0100752 break;
753 default:
754 ARM_COMPUTE_ERROR("Not implemented");
755 }
756 }
757 else
758 {
759 switch(conv_stride_x)
760 {
761 case 1:
Manuel Bottini327225d2021-04-13 13:09:30 +0100762 convolver_1x1<float, float, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Pablo Telloc09314a2017-09-21 13:59:14 +0100763 break;
764 case 2:
Manuel Bottini327225d2021-04-13 13:09:30 +0100765 convolver_1x1<float, float, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Pablo Telloc09314a2017-09-21 13:59:14 +0100766 break;
767 case 3:
Manuel Bottini327225d2021-04-13 13:09:30 +0100768 convolver_1x1<float, float, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Pablo Telloc09314a2017-09-21 13:59:14 +0100769 break;
770 default:
771 ARM_COMPUTE_ERROR("Not implemented");
772 }
773 }
774}
775
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100776template <typename T1, typename T2>
777inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
Manuel Bottini327225d2021-04-13 13:09:30 +0100778 const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100779{
780 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
781 switch(conv_stride_x)
782 {
783 case 1:
Manuel Bottini327225d2021-04-13 13:09:30 +0100784 convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100785 break;
786 case 2:
Manuel Bottini327225d2021-04-13 13:09:30 +0100787 convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100788 break;
789 case 3:
Manuel Bottini327225d2021-04-13 13:09:30 +0100790 convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100791 break;
792 default:
793 ARM_COMPUTE_ERROR("Not implemented");
794 }
795}
Pablo Tello06da39d2017-08-10 15:10:40 +0100796
797template <typename T1, typename T2>
798inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
Manuel Bottini327225d2021-04-13 13:09:30 +0100799 const ITensor *src, const ITensor *weights, ITensor *dst, const PadStrideInfo &conv_info)
Pablo Tello06da39d2017-08-10 15:10:40 +0100800{
801 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
802 switch(conv_stride_x)
803 {
804 case 1:
Manuel Bottini327225d2021-04-13 13:09:30 +0100805 convolver_5x5<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Pablo Tello06da39d2017-08-10 15:10:40 +0100806 break;
807 case 2:
Manuel Bottini327225d2021-04-13 13:09:30 +0100808 convolver_5x5<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Pablo Tello06da39d2017-08-10 15:10:40 +0100809 break;
810 case 3:
Manuel Bottini327225d2021-04-13 13:09:30 +0100811 convolver_5x5<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, src, weights, dst, conv_info);
Pablo Tello06da39d2017-08-10 15:10:40 +0100812 break;
813 default:
814 ARM_COMPUTE_ERROR("Not implemented");
815 }
816}
817
Manuel Bottini327225d2021-04-13 13:09:30 +0100818Status validate_arguments(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const PadStrideInfo &conv_info)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000819{
Manuel Bottini327225d2021-04-13 13:09:30 +0100820 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, weights, dst);
821 ARM_COMPUTE_RETURN_ERROR_ON(src->data_layout() == DataLayout::UNKNOWN);
822 ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src);
823 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::F16, DataType::F32);
824 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000825
Manuel Bottini327225d2021-04-13 13:09:30 +0100826 const DataLayout data_layout = src->data_layout();
Giorgio Arenac0f54432018-03-16 14:02:34 +0000827 const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
828 const int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
829 const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
830
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000831 ARM_COMPUTE_RETURN_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
Manuel Bottini327225d2021-04-13 13:09:30 +0100832 ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(channel_idx) != src->dimension(channel_idx));
Giorgio Arenac0f54432018-03-16 14:02:34 +0000833 ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != weights->dimension(height_idx));
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000834 ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
Manuel Bottini327225d2021-04-13 13:09:30 +0100835 ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && src->data_type() != DataType::F32);
836 ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(width_idx) > 3) && (src->data_type() == DataType::F16));
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000837
838 // Checks performed when output is configured
Manuel Bottini327225d2021-04-13 13:09:30 +0100839 if(dst->total_size() != 0)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000840 {
Manuel Bottini327225d2021-04-13 13:09:30 +0100841 TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *weights, conv_info);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000842
Manuel Bottini327225d2021-04-13 13:09:30 +0100843 DataType data_type = src->data_type();
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000844
Manuel Bottini327225d2021-04-13 13:09:30 +0100845 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), output_shape);
846 ARM_COMPUTE_RETURN_ERROR_ON(dst->data_type() != data_type);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000847 }
848
849 return Status{};
850}
851
Manuel Bottini327225d2021-04-13 13:09:30 +0100852std::pair<Status, Window> validate_and_configure_window(ITensorInfo *src, ITensorInfo *weights, ITensorInfo *dst, const PadStrideInfo &conv_info, unsigned int &num_weight_elems_read_per_row,
Georgios Pinitas0223a782017-12-12 11:44:44 +0000853 unsigned int &num_elems_read_per_iteration, unsigned int &num_elems_written_per_iteration, BorderSize &border_size)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000854{
Manuel Bottini327225d2021-04-13 13:09:30 +0100855 ARM_COMPUTE_ERROR_ON(src->data_layout() == DataLayout::UNKNOWN);
Giorgio Arenac0f54432018-03-16 14:02:34 +0000856
Manuel Bottini327225d2021-04-13 13:09:30 +0100857 const DataLayout data_layout = src->data_layout();
Giorgio Arenac0f54432018-03-16 14:02:34 +0000858 const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
859
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000860 // Calculate right and bottom border
Giorgio Arenac0f54432018-03-16 14:02:34 +0000861 unsigned int kernel_size = weights->dimension(width_idx);
Georgios Pinitas1d6d2112018-02-05 17:40:12 +0000862 const int conv_stride_x = std::get<0>(conv_info.stride());
Georgios Pinitas1a03d762018-02-21 14:47:09 +0000863 const int conv_stride_y = std::get<1>(conv_info.stride());
Manuel Bottini327225d2021-04-13 13:09:30 +0100864 const int input_width = src->dimension(width_idx);
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000865
Giorgio Arenac0f54432018-03-16 14:02:34 +0000866 Window win{};
867 bool window_changed = false;
868
869 if(data_layout == DataLayout::NCHW)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000870 {
Giorgio Arenac0f54432018-03-16 14:02:34 +0000871 switch(kernel_size)
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000872 {
Giorgio Arenac0f54432018-03-16 14:02:34 +0000873 case 1:
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000874 {
Manuel Bottini327225d2021-04-13 13:09:30 +0100875 switch(src->data_type())
Giorgio Arenac0f54432018-03-16 14:02:34 +0000876 {
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000877#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +0000878 case DataType::F16:
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000879 num_elems_written_per_iteration = 8;
Giorgio Arenac0f54432018-03-16 14:02:34 +0000880 break;
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100881#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +0000882 case DataType::F32:
Manuel Bottini327225d2021-04-13 13:09:30 +0100883 if(run_optim_small_tensor_info(src))
Giorgio Arenac0f54432018-03-16 14:02:34 +0000884 {
885 num_elems_written_per_iteration = 8;
886 }
887 else
888 {
889 num_elems_written_per_iteration = 4;
890 }
891 break;
892 default:
893 ARM_COMPUTE_ERROR("Data type not supported.");
894 break;
895 }
896 num_weight_elems_read_per_row = kernel_size;
897 num_elems_read_per_iteration = conv_stride_x * num_elems_written_per_iteration;
898 break;
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000899 }
Giorgio Arenac0f54432018-03-16 14:02:34 +0000900 case 3:
Manuel Bottini327225d2021-04-13 13:09:30 +0100901 switch(src->data_type())
Giorgio Arenac0f54432018-03-16 14:02:34 +0000902 {
903 case DataType::F32:
904 num_weight_elems_read_per_row = 4 + kernel_size - 1;
905 num_elems_read_per_iteration = 12;
906 num_elems_written_per_iteration = 16 >> conv_stride_x;
907 break;
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000908#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +0000909 case DataType::F16:
Giorgio Arenac0f54432018-03-16 14:02:34 +0000910 num_weight_elems_read_per_row = 8 + kernel_size - 1;
911 num_elems_read_per_iteration = 24;
912 num_elems_written_per_iteration = 32 >> conv_stride_x;
913 break;
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100914#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +0000915 default:
916 ARM_COMPUTE_ERROR("Data type not supported.");
917 break;
918 }
Gian Marco Iodice41acb762018-08-23 10:25:06 +0100919 break;
920 case 5:
921 {
Manuel Bottini327225d2021-04-13 13:09:30 +0100922 switch(src->data_type())
Gian Marco Iodice41acb762018-08-23 10:25:06 +0100923 {
924 case DataType::F32:
925 num_weight_elems_read_per_row = 4 + kernel_size - 1;
926 num_elems_read_per_iteration = 12;
927 num_elems_written_per_iteration = 16 >> conv_stride_x;
928 break;
929 default:
930 ARM_COMPUTE_ERROR("Data type not supported.");
931 break;
932 }
Giorgio Arenac0f54432018-03-16 14:02:34 +0000933 }
934 break;
935 default:
936 {
937 ARM_COMPUTE_ERROR("Not implemented");
938 break;
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000939 }
940 }
Giorgio Arenac0f54432018-03-16 14:02:34 +0000941
942 // Calculate right pad
943 int start_x = kernel_size / 2 - static_cast<int>(conv_info.pad_left());
Manuel Bottini327225d2021-04-13 13:09:30 +0100944 int end_x = ceil_to_multiple(static_cast<int>(dst->dimension(0)), num_elems_written_per_iteration) * conv_stride_x;
Giorgio Arenac0f54432018-03-16 14:02:34 +0000945 int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width;
946
947 // Calculate border
948 const unsigned int conv_pad_left = conv_info.pad_left();
949 const unsigned int conv_pad_top = conv_info.pad_top();
950 const unsigned int conv_pad_right = std::max(upper_bound_w, 0);
951 const unsigned int conv_pad_bottom = conv_info.pad_bottom();
952
953 border_size.left = conv_pad_left;
954 border_size.top = conv_pad_top;
955 border_size.right = conv_pad_right;
956 border_size.bottom = conv_pad_bottom;
957
958 // Configure window
Manuel Bottini327225d2021-04-13 13:09:30 +0100959 win = calculate_max_window(*dst, Steps(num_elems_written_per_iteration));
Giorgio Arenac0f54432018-03-16 14:02:34 +0000960
Manuel Bottini327225d2021-04-13 13:09:30 +0100961 AccessWindowRectangle input_access(src, -conv_pad_left, -conv_pad_top,
Giorgio Arenac0f54432018-03-16 14:02:34 +0000962 num_elems_read_per_iteration, kernel_size,
963 conv_stride_x, conv_stride_y);
964 AccessWindowStatic weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size);
Manuel Bottini327225d2021-04-13 13:09:30 +0100965 AccessWindowHorizontal output_access(dst, 0, num_elems_written_per_iteration);
Giorgio Arenac0f54432018-03-16 14:02:34 +0000966 window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
Manuel Bottini327225d2021-04-13 13:09:30 +0100967 output_access.set_valid_region(win, ValidRegion(Coordinates(), dst->tensor_shape()));
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000968 }
Giorgio Arenac0f54432018-03-16 14:02:34 +0000969 else
970 {
Manuel Bottini87350f42020-09-15 13:03:34 +0100971 // Configure window NHWC without any padding
Manuel Bottini327225d2021-04-13 13:09:30 +0100972 win = calculate_max_window(*dst, Steps());
Giorgio Arenac0f54432018-03-16 14:02:34 +0000973 }
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000974
975 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
976 return std::make_pair(err, win);
977}
Manuel Bottini87350f42020-09-15 13:03:34 +0100978
Michalis Spyroub55f8e82021-07-22 11:23:11 +0100979bool have_zero_x_internal_padding(ITensorInfo *src, const ITensorInfo *weights)
Manuel Bottini87350f42020-09-15 13:03:34 +0100980{
Manuel Bottini327225d2021-04-13 13:09:30 +0100981 return (src->padding().left == 0 && weights->padding().left == 0 && src->padding().right == 0 && weights->padding().right == 0);
Manuel Bottini87350f42020-09-15 13:03:34 +0100982}
983
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100984} // namespace
985
Manuel Bottini87350f42020-09-15 13:03:34 +0100986template <typename T>
Manuel Bottinib4bb6a02021-05-24 16:01:32 +0100987void CpuDirectConv2dKernel::convolve_nhwc_optimized(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst)
Manuel Bottini87350f42020-09-15 13:03:34 +0100988{
989 // This function assumes that input and weights have not padding in channel
990
991 // Declare useful types
992 using vtype = wrapper::traits::neon_bitvector<T, wrapper::traits::BitWidth::W128>;
993 using vector_type = typename vtype::type;
994 using tag_type = typename vtype::tag_type;
995
996 // Scalar quantities
Manuel Bottini327225d2021-04-13 13:09:30 +0100997 const int element_size = src->info()->element_size();
998 const int input_stride_w = src->info()->strides_in_bytes().y() / element_size;
999 const int input_stride_h = src->info()->strides_in_bytes().z() / element_size;
1000 const int input_stride_n = src->info()->strides_in_bytes()[3] / element_size;
1001 const int input_dim_w = src->info()->dimension(1);
1002 const int input_dim_h = src->info()->dimension(2);
Manuel Bottini87350f42020-09-15 13:03:34 +01001003
Manuel Bottini327225d2021-04-13 13:09:30 +01001004 const int output_stride_c = dst->info()->strides_in_bytes().x();
Manuel Bottini87350f42020-09-15 13:03:34 +01001005
Manuel Bottini327225d2021-04-13 13:09:30 +01001006 const unsigned int kernel_stride_w = weights->info()->strides_in_bytes().y() / element_size;
1007 const unsigned int kernel_stride_h = weights->info()->strides_in_bytes().z() / element_size;
1008 const int kernel_dim_w = weights->info()->dimension(1);
1009 const int kernel_dim_h = weights->info()->dimension(2);
Manuel Bottini87350f42020-09-15 13:03:34 +01001010
1011 const int conv_pad_top = _conv_info.pad_top();
1012 const int conv_pad_left = _conv_info.pad_left();
1013 const int conv_stride_w = std::get<0>(_conv_info.stride());
1014 const int conv_stride_h = std::get<1>(_conv_info.stride());
1015
1016 // Setup input window for the output iterator
1017 Window window_out = window;
1018 window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
1019
1020 // Setup input window for the weights iterator
Manuel Bottini327225d2021-04-13 13:09:30 +01001021 Window window_w = calculate_max_window(*weights->info(), Steps());
Manuel Bottini87350f42020-09-15 13:03:34 +01001022 window_w.set(Window::DimX, Window::Dimension(0, 1, 1));
1023 window_w.set(Window::DimY, Window::Dimension(0, 1, 1));
1024 window_w.set(Window::DimZ, Window::Dimension(0, 1, 1));
1025
Manuel Bottini327225d2021-04-13 13:09:30 +01001026 Iterator out(dst, window_out);
1027 Iterator wei(weights, window_w);
Manuel Bottini87350f42020-09-15 13:03:34 +01001028
1029 constexpr int num_elems_read_per_iteration = 16 / sizeof(T);
1030 /*
1031 * This implementation parallelize the full WC plane of input and weights by
1032 * treating them as series of elements. So for example, a 3x3 weights and
1033 * floating point vector operations of 4 elements per time, the first 3
1034 * channel elements of the first row would be taken and additionally the first
1035 * element of the second row. The 9 elements in each single WC weight plane
1036 * would require 2 4-element vector operations and a last single element operation.
1037 *
1038 * This works since when we create the input vector to multiply with the weights,
1039 * the exact required elements are loaded in the same order. Therefore the
1040 * multiplication works on the correct input/weight elements.
1041 */
1042 execute_window_loop(window_out, [&](const Coordinates & id)
1043 {
1044 /*
1045 * In here we create theoretical indexes which then we validate for both
1046 * inputs and weights.
1047 * As a reminder, this loop take each output point in NHW, C is treated
1048 * in the weights loop.
1049 */
1050 // We are computing the theoretical starting input starting points
1051 const int in_w_start_t = static_cast<int>(id.y()) * conv_stride_w - conv_pad_left;
1052 const int in_h_start_t = static_cast<int>(id.z()) * conv_stride_h - conv_pad_top;
1053 const int in_w_end_t = in_w_start_t + kernel_dim_w;
1054 const int in_h_end_t = in_h_start_t + kernel_dim_h;
1055
1056 // We are computing the valid initial and ending input points by checking the borders
1057 const int in_w_start = std::max(in_w_start_t, 0);
1058 const int in_h_start = std::max(in_h_start_t, 0);
1059 const int in_w_end = std::min(in_w_end_t, input_dim_w);
1060 const int in_h_end = std::min(in_h_end_t, input_dim_h);
1061
1062 // We use the input points to select the valid weight points to use
1063 const int index_wc_start = (in_w_start - in_w_start_t) * kernel_stride_w;
1064 const int index_h_start = in_h_start - in_h_start_t;
1065 const int index_wc_end = (kernel_dim_w - (in_w_end_t - in_w_end)) * kernel_stride_w;
1066 const int index_h_end = kernel_dim_h - (in_h_end_t - in_h_end);
1067
1068 execute_window_loop(window_w, [&](const Coordinates & id_w)
1069 {
1070 /*
1071 * This is the loop in the weights, and it goes along N (the batches)
1072 * As a reminder, the batches of the weights are translated into the
1073 * channels of the output
1074 */
Manuel Bottini327225d2021-04-13 13:09:30 +01001075 const T *in_ptr_row = reinterpret_cast<const T *>(src->buffer() + src->info()->offset_first_element_in_bytes())
Manuel Bottini87350f42020-09-15 13:03:34 +01001076 + id[3] * input_stride_n + in_w_start * input_stride_w + in_h_start * input_stride_h;
1077 const T *weights_ptr_row = reinterpret_cast<const T *>(wei.ptr()) + index_h_start * kernel_stride_h;
1078 uint8_t *out_ptr = out.ptr() + id_w[3] * output_stride_c;
1079
1080 T out_temp = static_cast<T>(0);
1081 for(int index_h = index_h_start; index_h < index_h_end; ++index_h, in_ptr_row += input_stride_h, weights_ptr_row += kernel_stride_h)
1082 {
1083 const T *in_ptr_mover = in_ptr_row;
1084 int index_wc = index_wc_start;
1085 vector_type out_temp_vec = wrapper::vdup_n(static_cast<T>(0), tag_type());
1086 for(; index_wc <= index_wc_end - num_elems_read_per_iteration; index_wc += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration)
1087 {
1088 const auto src_vec = wrapper::vloadq(in_ptr_mover);
1089 const auto w_vec = wrapper::vloadq(weights_ptr_row + index_wc);
1090 out_temp_vec = wrapper::vmla(out_temp_vec, w_vec, src_vec);
1091 }
1092 out_temp += vreduce(out_temp_vec);
1093 for(; index_wc < index_wc_end; ++index_wc, ++in_ptr_mover)
1094 {
1095 const auto src_val = *(in_ptr_mover);
1096 const auto w_val = *(weights_ptr_row + index_wc);
1097 out_temp += src_val * w_val;
1098 }
1099 }
1100 *(reinterpret_cast<T *>(out_ptr)) = out_temp;
1101 },
1102 wei);
1103 },
1104 out);
1105}
1106
1107template <typename T>
Manuel Bottinib4bb6a02021-05-24 16:01:32 +01001108void CpuDirectConv2dKernel::convolve_nhwc(const Window &window, const ITensor *src, const ITensor *weights, ITensor *dst)
Manuel Bottini87350f42020-09-15 13:03:34 +01001109{
1110 // Declare useful types
1111 using vtype = wrapper::traits::neon_bitvector<T, wrapper::traits::BitWidth::W128>;
1112 using vector_type = typename vtype::type;
1113 using tag_type = typename vtype::tag_type;
1114
1115 // Scalar quantities
Manuel Bottini327225d2021-04-13 13:09:30 +01001116 const int element_size = src->info()->element_size();
1117 const int input_stride_w = src->info()->strides_in_bytes().y() / element_size;
1118 const int input_stride_h = src->info()->strides_in_bytes().z() / element_size;
1119 const int input_stride_n = src->info()->strides_in_bytes()[3] / element_size;
1120 const int input_dim_w = src->info()->dimension(1);
1121 const int input_dim_h = src->info()->dimension(2);
Manuel Bottini87350f42020-09-15 13:03:34 +01001122
Manuel Bottini327225d2021-04-13 13:09:30 +01001123 const int output_stride_c = dst->info()->strides_in_bytes().x();
Manuel Bottini87350f42020-09-15 13:03:34 +01001124
Manuel Bottini327225d2021-04-13 13:09:30 +01001125 const unsigned int kernel_stride_w = weights->info()->strides_in_bytes().y() / element_size;
1126 const unsigned int kernel_stride_h = weights->info()->strides_in_bytes().z() / element_size;
1127 const int kernel_dim_w = weights->info()->dimension(1);
1128 const int kernel_dim_h = weights->info()->dimension(2);
Manuel Bottini87350f42020-09-15 13:03:34 +01001129
1130 const int conv_pad_top = _conv_info.pad_top();
1131 const int conv_pad_left = _conv_info.pad_left();
1132 const int conv_stride_w = std::get<0>(_conv_info.stride());
1133 const int conv_stride_h = std::get<1>(_conv_info.stride());
1134
1135 // Setup input window for the output iterator
1136 Window window_out = window;
1137 window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
1138
1139 // Setup input window for the weights iterator
Manuel Bottini327225d2021-04-13 13:09:30 +01001140 Window window_w = calculate_max_window(*weights->info(), Steps());
Manuel Bottini87350f42020-09-15 13:03:34 +01001141 window_w.set(Window::DimX, Window::Dimension(0, 1, 1));
1142 window_w.set(Window::DimY, Window::Dimension(0, 1, 1));
1143 window_w.set(Window::DimZ, Window::Dimension(0, 1, 1));
1144
Manuel Bottini327225d2021-04-13 13:09:30 +01001145 Iterator out(dst, window_out);
1146 Iterator wei(weights, window_w);
Manuel Bottini87350f42020-09-15 13:03:34 +01001147
1148 constexpr int num_elems_read_per_iteration = 16 / sizeof(T);
1149
1150 execute_window_loop(window_out, [&](const Coordinates & id)
1151 {
1152 // We are computing the theoretical starting input starting points
1153 const int in_w_start_t = static_cast<int>(id.y()) * conv_stride_w - conv_pad_left;
1154 const int in_h_start_t = static_cast<int>(id.z()) * conv_stride_h - conv_pad_top;
1155 const int in_w_end_t = in_w_start_t + kernel_dim_w;
1156 const int in_h_end_t = in_h_start_t + kernel_dim_h;
1157
1158 // We are computing the valid initial and ending input points by checking the borders
1159 const int in_w_start = std::max(in_w_start_t, 0);
1160 const int in_h_start = std::max(in_h_start_t, 0);
1161 const int in_w_end = std::min(in_w_end_t, input_dim_w);
1162 const int in_h_end = std::min(in_h_end_t, input_dim_h);
1163
1164 // We use the input points to select the valid weight points to use
1165 const int wei_w_start = in_w_start - in_w_start_t;
1166 const int wei_h_start = in_h_start - in_h_start_t;
1167 const int wei_w_end = kernel_dim_w - (in_w_end_t - in_w_end);
1168 const int wei_h_end = kernel_dim_h - (in_h_end_t - in_h_end);
1169
Manuel Bottini327225d2021-04-13 13:09:30 +01001170 const int index_c_end = weights->info()->dimension(0);
1171 const T *const in_ptr_start = reinterpret_cast<const T *>(src->buffer() + src->info()->offset_first_element_in_bytes()) + id[3] * input_stride_n;
Manuel Bottini87350f42020-09-15 13:03:34 +01001172
1173 execute_window_loop(window_w, [&](const Coordinates & id_w)
1174 {
1175 const T *const weights_ptr_start = reinterpret_cast<const T *>(wei.ptr());
1176 uint8_t *out_ptr = out.ptr() + id_w[3] * output_stride_c;
1177
1178 T out_temp = static_cast<T>(0);
1179 for(int index_wei_h = wei_h_start, index_in_h = in_h_start; index_wei_h < wei_h_end; ++index_wei_h, ++index_in_h)
1180 {
1181 const T *const in_ptr_row = in_ptr_start + index_in_h * input_stride_h;
1182 const T *const weights_ptr_row = weights_ptr_start + index_wei_h * kernel_stride_h;
1183 for(int index_wei_w = wei_w_start, index_in_w = in_w_start; index_wei_w < wei_w_end; ++index_wei_w, ++index_in_w)
1184 {
1185 const T *in_ptr_mover = in_ptr_row + index_in_w * input_stride_w;
1186 const T *weights_ptr_mover = weights_ptr_row + index_wei_w * kernel_stride_w;
1187 int index_c = 0;
1188 vector_type out_temp_vec = wrapper::vdup_n(static_cast<T>(0), tag_type());
1189 for(; index_c <= index_c_end - num_elems_read_per_iteration; index_c += num_elems_read_per_iteration, in_ptr_mover += num_elems_read_per_iteration, weights_ptr_mover += num_elems_read_per_iteration)
1190 {
1191 const auto src_vec = wrapper::vloadq(in_ptr_mover);
1192 const auto w_vec = wrapper::vloadq(weights_ptr_mover);
1193 out_temp_vec = wrapper::vmla(out_temp_vec, w_vec, src_vec);
1194 }
1195 out_temp += vreduce(out_temp_vec);
1196 for(; index_c < index_c_end; ++index_c, ++in_ptr_mover, ++weights_ptr_mover)
1197 {
1198 const auto src_val = *(in_ptr_mover);
1199 const auto w_val = *(weights_ptr_mover);
1200 out_temp += src_val * w_val;
1201 }
1202 }
1203 }
1204 *(reinterpret_cast<T *>(out_ptr)) = out_temp;
1205 },
1206 wei);
1207 },
1208 out);
1209}
1210
Manuel Bottinib4bb6a02021-05-24 16:01:32 +01001211BorderSize CpuDirectConv2dKernel::border_size() const
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001212{
1213 return _border_size;
1214}
1215
Manuel Bottinib4bb6a02021-05-24 16:01:32 +01001216void CpuDirectConv2dKernel::configure(ITensorInfo *src, ITensorInfo *weights, ITensorInfo *dst, const PadStrideInfo &conv_info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001217{
Manuel Bottini327225d2021-04-13 13:09:30 +01001218 ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001219
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001220 _conv_info = conv_info;
Manuel Bottini327225d2021-04-13 13:09:30 +01001221 _data_layout = src->data_layout();
1222 _kernel_size = weights->dimension(get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH));
Michalis Spyrou621965e2018-01-08 17:11:26 +00001223
1224 const unsigned int conv_pad_left = conv_info.pad_left();
1225 const unsigned int conv_pad_top = conv_info.pad_top();
1226 const unsigned int conv_pad_right = conv_info.pad_right();
1227 const unsigned int conv_pad_bottom = conv_info.pad_bottom();
Manuel Bottinica62c6f2021-03-23 11:50:34 +00001228 if(_data_layout == DataLayout::NCHW)
Manuel Bottini87350f42020-09-15 13:03:34 +01001229 {
1230 _border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);
1231 }
1232 else
1233 {
1234 _border_size = BorderSize(0);
1235 }
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001236
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001237 // Get convolved dimensions
Manuel Bottini327225d2021-04-13 13:09:30 +01001238 TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*src, *weights, conv_info);
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001239
Manuel Bottini327225d2021-04-13 13:09:30 +01001240 DataType data_type = src->data_type();
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001241
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001242 // Output auto inizialitation if not yet initialized
Manuel Bottini327225d2021-04-13 13:09:30 +01001243 auto_init_if_empty(*dst, output_shape, 1, data_type);
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001244
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001245 // Perform validation step
Manuel Bottini327225d2021-04-13 13:09:30 +01001246 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src, weights, dst, conv_info));
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001247
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001248 // Configure kernel window
Manuel Bottini327225d2021-04-13 13:09:30 +01001249 auto win_config = validate_and_configure_window(src, weights, dst, conv_info, _num_weight_elems_read_per_row,
Georgios Pinitas0223a782017-12-12 11:44:44 +00001250 _num_elems_read_per_iteration, _num_elems_written_per_iteration, _border_size);
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001251 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
Manuel Bottini327225d2021-04-13 13:09:30 +01001252 ICpuKernel::configure(win_config.second);
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001253}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001254
Manuel Bottinib4bb6a02021-05-24 16:01:32 +01001255Status CpuDirectConv2dKernel::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const PadStrideInfo &conv_info)
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001256{
1257 unsigned int num_weight_elems_read_per_row = 0;
1258 unsigned int num_elems_read_per_iteration = 0;
1259 unsigned int num_elems_written_per_iteration = 0;
Georgios Pinitas15997872018-02-19 13:58:22 +00001260 BorderSize border_size = {};
Manuel Bottini327225d2021-04-13 13:09:30 +01001261 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src, weights, dst, conv_info));
1262 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(src->clone().get(),
Georgios Pinitas0223a782017-12-12 11:44:44 +00001263 weights->clone().get(),
Manuel Bottini327225d2021-04-13 13:09:30 +01001264 dst->clone().get(),
Georgios Pinitas0223a782017-12-12 11:44:44 +00001265 conv_info,
1266 num_weight_elems_read_per_row,
1267 num_elems_read_per_iteration,
1268 num_elems_written_per_iteration,
1269 border_size)
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001270 .first);
Georgios Pinitas898a8062017-09-12 19:19:12 +01001271
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001272 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001273}
1274
Manuel Bottinib4bb6a02021-05-24 16:01:32 +01001275void CpuDirectConv2dKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001276{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001277 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001278 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
Manuel Bottini327225d2021-04-13 13:09:30 +01001279 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001280
Manuel Bottini327225d2021-04-13 13:09:30 +01001281 auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0);
1282 auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
1283 auto dst = tensors.get_tensor(TensorType::ACL_DST);
1284 const int kernel_size = weights->info()->dimension(get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001285
Manuel Bottinica62c6f2021-03-23 11:50:34 +00001286 if(_data_layout == DataLayout::NCHW)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001287 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001288 switch(kernel_size)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001289 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001290 case 1:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001291 {
Manuel Bottini327225d2021-04-13 13:09:30 +01001292 switch(src->info()->data_type())
Giorgio Arenac0f54432018-03-16 14:02:34 +00001293 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001294 case DataType::F32:
Manuel Bottini327225d2021-04-13 13:09:30 +01001295 convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info);
Giorgio Arenac0f54432018-03-16 14:02:34 +00001296 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001297#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +00001298 case DataType::F16:
Manuel Bottini327225d2021-04-13 13:09:30 +01001299 convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info);
Giorgio Arenac0f54432018-03-16 14:02:34 +00001300 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001301#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +00001302 default:
1303 ARM_COMPUTE_ERROR("Data type not supported");
1304 break;
1305 }
1306 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001307 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001308 case 3:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001309 {
Manuel Bottini327225d2021-04-13 13:09:30 +01001310 switch(src->info()->data_type())
Giorgio Arenac0f54432018-03-16 14:02:34 +00001311 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001312 case DataType::F32:
Manuel Bottini327225d2021-04-13 13:09:30 +01001313 convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info);
Giorgio Arenac0f54432018-03-16 14:02:34 +00001314 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001315#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +00001316 case DataType::F16:
Manuel Bottini327225d2021-04-13 13:09:30 +01001317 convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info);
Giorgio Arenac0f54432018-03-16 14:02:34 +00001318 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001319#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +00001320 default:
1321 ARM_COMPUTE_ERROR("Data type not supported");
1322 break;
1323 }
1324 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001325 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001326 case 5:
Pablo Tello06da39d2017-08-10 15:10:40 +01001327 {
Manuel Bottini327225d2021-04-13 13:09:30 +01001328 switch(src->info()->data_type())
Giorgio Arenac0f54432018-03-16 14:02:34 +00001329 {
1330 case DataType::F32:
Manuel Bottini327225d2021-04-13 13:09:30 +01001331 convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, src, weights, dst, _conv_info);
Giorgio Arenac0f54432018-03-16 14:02:34 +00001332 break;
1333 default:
1334 ARM_COMPUTE_ERROR("Data type not supported");
1335 break;
1336 }
1337 break;
Pablo Tello06da39d2017-08-10 15:10:40 +01001338 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001339 default:
1340 {
1341 ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
1342 break;
1343 }
1344 }
1345 }
1346 else
1347 {
Manuel Bottini327225d2021-04-13 13:09:30 +01001348 switch(src->info()->data_type())
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001349 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001350 case DataType::F32:
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001351 {
Manuel Bottini327225d2021-04-13 13:09:30 +01001352 if(have_zero_x_internal_padding(src->info(), weights->info()))
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001353 {
Manuel Bottini327225d2021-04-13 13:09:30 +01001354 convolve_nhwc_optimized<float>(window, src, weights, dst);
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001355 }
1356 else
1357 {
Manuel Bottini327225d2021-04-13 13:09:30 +01001358 convolve_nhwc<float>(window, src, weights, dst);
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001359 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001360 break;
Gian Marco Iodice95f93612019-06-13 15:58:32 +01001361 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001362 default:
1363 ARM_COMPUTE_ERROR("Data type not supported");
1364 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001365 }
1366 }
1367}
Manuel Bottinib4bb6a02021-05-24 16:01:32 +01001368const char *CpuDirectConv2dKernel::name() const
Manuel Bottini327225d2021-04-13 13:09:30 +01001369{
1370 return "CpuDirectConvolutionLayerKernel";
1371}
1372} // namespace kernels
1373} // namespace cpu
Sheri Zhangac6499a2021-02-10 15:32:38 +00001374} // namespace arm_compute