/*
 * Copyright (c) 2017-2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"
Georgios Pinitas4074c992018-01-30 18:13:46 +000025#include "arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
27#include "arm_compute/core/AccessWindowStatic.h"
Anthony Barbiereaefd002018-07-20 17:49:35 +010028#include "arm_compute/core/CPP/Validate.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/Error.h"
30#include "arm_compute/core/Helpers.h"
31#include "arm_compute/core/IAccessWindow.h"
32#include "arm_compute/core/ITensor.h"
33#include "arm_compute/core/NEON/NEFixedPoint.h"
34#include "arm_compute/core/Types.h"
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +010035#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010036#include "arm_compute/core/Validate.h"
Giorgio Arenac0f54432018-03-16 14:02:34 +000037#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010038
39#include <algorithm>
40#include <arm_neon.h>
41
42using namespace arm_compute;
Michalis Spyrou7362f0d2017-10-18 17:58:22 +010043using namespace arm_compute::detail;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010044
namespace
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
template <unsigned int stridex>
float16x8_t internal_vld1q(const float16_t *in);

template <>
float16x8_t internal_vld1q<1>(const float16_t *in)
{
    return vld1q_f16(in);
}

template <>
float16x8_t internal_vld1q<2>(const float16_t *in)
{
    const float16x8x2_t tmp = vld2q_f16(in);
    return tmp.val[0];
}

template <>
float16x8_t internal_vld1q<3>(const float16_t *in)
{
    const float16x8x3_t tmp = vld3q_f16(in);
    return tmp.val[0];
}

inline float16x8_t internal_vdupq_n(float16_t v)
{
    return vdupq_n_f16(v);
}

inline void internal_vst1q(float16_t *p, const float16x8_t &v)
{
    vst1q_f16(p, v);
}

float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y)
{
    return vmulq_f16(x, y);
}

inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z)
{
    return vaddq_f16(x, vmulq_f16(y, z));
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

template <unsigned int stridex>
float32x4_t internal_vld1q(const float *in);

template <>
float32x4_t internal_vld1q<1>(const float *in)
{
    return vld1q_f32(in);
}

template <>
float32x4_t internal_vld1q<2>(const float *in)
{
    const float32x4x2_t tmp = vld2q_f32(in);
    return tmp.val[0];
}

template <>
float32x4_t internal_vld1q<3>(const float *in)
{
    const float32x4x3_t tmp = vld3q_f32(in);
    return tmp.val[0];
}

inline float32x4_t internal_vdupq_n(float v)
{
    return vdupq_n_f32(v);
}

inline void internal_vst1q(float *p, const float32x4_t &v)
{
    vst1q_f32(p, v);
}

float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y)
{
    return vmulq_f32(x, y);
}

inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z)
{
    return vmlaq_f32(x, y, z);
}

constexpr int small_tensor_size_optim = 8;
inline bool run_optim_small_tensor_info(const ITensorInfo *t)
{
    return t->dimension(Window::DimX) <= small_tensor_size_optim && t->dimension(Window::DimY) <= small_tensor_size_optim;
}

inline bool run_optim_small_tensor(const ITensor *t)
{
    return run_optim_small_tensor_info(t->info());
}

// Optimized convolver for 1x1 kernels, used only where the input width and height are both <= 8.
// For a large Z (e.g. an input of 7x7x832) this implementation is faster than the general code because it does not need to
// store intermediate results in memory. Temporary results are kept directly in NEON registers and only then written to the output buffer.
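//
// A rough sketch of the accumulation pattern used below (illustrative pseudo-code, not part of
// the kernel): each of the up-to-8 output rows keeps two float32x4_t accumulators, one for
// columns 0-3 and one for columns 4-7, so the whole output plane stays in registers:
//
//   for(p = 0; p < kernel_depth; ++p)        // depth of the 1x1 kernel
//       for(oh = 0; oh < output_h; ++oh)     // one output row at a time
//       {
//           accum0[oh] += k[p] * in[p][row oh, cols 0..3];
//           accum1[oh] += k[p] * in[p][row oh, cols 4..7];
//       }
//   // the accumulators are written back to the output buffer only once per output plane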
template <unsigned int stridex>
class convolver_w1x1_i8x8_f32
{
public:
    static void convolve(const Window &window, const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimX) > small_tensor_size_optim);
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimY) > small_tensor_size_optim);

        const int          input_stride_x  = input->info()->strides_in_bytes().x();
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_y = output->info()->strides_in_bytes().y();
        const int          output_stride_z = output->info()->strides_in_bytes().z();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int          output_h        = output->info()->dimension(1);
        const int          range_z         = window.z().end() - window.z().start();
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window   window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            float32x4_t    accum0[small_tensor_size_optim] = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            float32x4_t    accum1[small_tensor_size_optim] = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            for(int oz = 0; oz < range_z; ++oz)
            {
                accum0[0] = accum0[1] = accum0[2] = accum0[3] = accum0[4] = accum0[5] = accum0[6] = accum0[7] = vdupq_n_f32(0.f);
                accum1[0] = accum1[1] = accum1[2] = accum1[3] = accum1[4] = accum1[5] = accum1[6] = accum1[7] = vdupq_n_f32(0.f);
                auto p_out_base = out_ptr + oz * output_stride_z;
                for(int p = 0; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const float *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk0   = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto      in_val    = reinterpret_cast<const float *>(input_ptr + p * input_stride_z + offset_xy);
                        auto      v_in0     = internal_vld1q<stridex>(in_val);
                        auto      v_in1     = internal_vld1q<stridex>(in_val + 4);
                        accum0[oh]          = vmlaq_f32(accum0[oh], vk0, v_in0);
                        accum1[oh]          = vmlaq_f32(accum1[oh], vk0, v_in1);
                    }
                }
                for(oh = 0; oh < output_h; ++oh)
                {
                    auto p_out = reinterpret_cast<float *>(p_out_base + oh * output_stride_y);
                    vst1q_f32(p_out, accum0[oh]);
                    vst1q_f32(p_out + 4, accum1[oh]);
                }
            }
        },
        in, out);
    }
};

template <typename T1, typename T2, unsigned int stridex>
class convolver_1x1
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        const int          input_stride_x  = input->info()->strides_in_bytes().x();
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_y = output->info()->strides_in_bytes().y();
        const int          output_stride_z = output->info()->strides_in_bytes().z();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int          output_w        = output->info()->dimension(0);
        const int          output_h        = output->info()->dimension(1);
        const int          range_z         = window.z().end() - window.z().start();
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window   window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            /*
                For a detailed explanation on how the algorithm works refer to template <> class convolver_3x3<1>
            */
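            // In short: Step 1 initialises each output plane with input plane 0 multiplied by the
            // corresponding kernel element, and Step 2 accumulates the remaining kernel_depth - 1
            // planes on top of it (internal_vmlal reads p_out back and adds the new product).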
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            for(int oz = 0; oz < range_z; ++oz)
            {
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk    = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto      in_val    = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
                        auto      p_out     = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val)));
                        }
                    }
                }

                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk    = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto      in_val    = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
                        auto      p_out     = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val)));
                        }
                    }
                }
            }
        },
        in, out);
    }
};

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

template <unsigned int stridex>
void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

template <>
void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
    vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
}

#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

template <unsigned int stridex>
float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                           const float *m0, const float *m1, const float *m2, const float *m3, const float *m4);

inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
{
    const float32x4x3_t m00 =
    {
        {
            vld1q_dup_f32(m0),
            vld1q_dup_f32(m1),
            vld1q_dup_f32(m2)
        }
    };
    return m00;
}

inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4)
{
    const float32x4x2_t m00 =
    {
        {
            vld1q_dup_f32(m3),
            vld1q_dup_f32(m4)
        }
    };
    return m00;
}

inline float32x4x3_t load_input(const float *const in)
{
    const float32x4x3_t vin =
    {
        {
            vld1q_f32(in),
            vld1q_f32(in + 4),
            vld1q_f32(in + 8)
        }
    };
    return vin;
}

template <>
inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
    const float32x4x3_t vin0 = load_input(in_0);
    const float32x4x3_t vin1 = load_input(in_1);
    const float32x4x3_t vin2 = load_input(in_2);
    const float32x4x3_t vin3 = load_input(in_3);
    const float32x4x3_t vin4 = load_input(in_4);
    const float32x4x3_t m00  = load_matrix_hi(m0, 1 + m0, 2 + m0);
    const float32x4x2_t m01  = load_matrix_lo(3 + m0, 4 + m0);
    const float32x4x3_t m10  = load_matrix_hi(m1, 1 + m1, 2 + m1);
    const float32x4x2_t m11  = load_matrix_lo(3 + m1, 4 + m1);
    const float32x4x3_t m20  = load_matrix_hi(m2, 1 + m2, 2 + m2);
    const float32x4x2_t m21  = load_matrix_lo(3 + m2, 4 + m2);
    const float32x4x3_t m30  = load_matrix_hi(m3, 1 + m3, 2 + m3);
    const float32x4x2_t m31  = load_matrix_lo(3 + m3, 4 + m3);
    const float32x4x3_t m40  = load_matrix_hi(m4, 1 + m4, 2 + m4);
    const float32x4x2_t m41  = load_matrix_lo(3 + m4, 4 + m4);

    float32x4x2_t out =
    {
        {
            vmulq_f32(vin0.val[0], m00.val[0]),
            vmulq_f32(vin0.val[1], m00.val[0])
        }
    };

    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 3), m11.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], m11.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]);

    return out;
}

template <>
inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);
    out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);
    out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}

template <unsigned int stridex>
void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
    vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
}

template <>
void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}

template <typename T1>
class convolver_nhwc
{
public:
    static void convolve(const Window &window, int kernel_size, unsigned int num_elems_read_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        const int          input_width     = input->info()->dimension(0);
        const int          input_depth     = input->info()->dimension(2);
        const int          input_stride_x  = input->info()->strides_in_bytes().x();
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_x = output->info()->strides_in_bytes().x();
        const int          kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int          kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          conv_pad_top    = conv_info.pad_top();
        const unsigned int conv_stride_x   = std::get<0>(conv_info.stride());
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const T1           zero            = 0;

        // Setup window for the input iterator
        Window window_in = window;
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        // Setup window for the output iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, 1, 1));

        // Setup window for the weights iterator
        Window window_k = calculate_max_window(*weights->info(), Steps());
        window_k.set(Window::DimX, Window::Dimension(0, 1, 1));
        window_k.set(Window::DimY, Window::Dimension(0, 1, 1));
        window_k.set(Window::DimZ, Window::Dimension(0, 1, 1));
        window_k.set(3, Window::Dimension(0, weights->info()->dimension(3), 1));

        Iterator in(input, window_in);
        Iterator out(output, window_out);
        Iterator k(weights, window_k);

        execute_window_loop(window_k, [&](const Coordinates & id_k)
        {
            execute_window_loop(window_out, [&](const Coordinates & id)
            {
                const auto in_y = static_cast<int>(id.y() * conv_stride_x - conv_info.pad_left());
                const auto in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top);

                const uint8_t *in_ptr  = in.ptr() + in_y * input_stride_y + in_z * input_stride_z;
                uint8_t       *out_ptr = out.ptr() + id_k[3] * output_stride_x;

                T1 out_val = 0;

                auto in_addr_base0 = in_ptr;
                auto we_addr_base0 = k.ptr();

                for(int z = 0; z < kernel_size; ++z, in_addr_base0 += input_stride_z, we_addr_base0 += kernel_stride_z)
                {
                    const int in_z = id.z() * conv_stride_y + z - conv_pad_top;

                    if(in_z >= 0 && in_z < input_depth) // If false, pad top/bottom
                    {
                        auto in_addr_base1 = in_addr_base0;
                        auto we_addr_base1 = we_addr_base0;

                        for(int y = 0; y < kernel_size; ++y, in_addr_base1 += input_stride_y, we_addr_base1 += kernel_stride_y)
                        {
                            auto out_values = internal_vdupq_n(zero);

                            int x           = 0;
                            int no_leftover = input_width - num_elems_read_per_iteration;

                            for(; x < no_leftover; x += num_elems_read_per_iteration)
                            {
                                const auto in_addr   = reinterpret_cast<const T1 *>(in_addr_base1 + x * input_stride_x);
                                const auto in_values = internal_vld1q<1>(in_addr);

                                const auto we_addr   = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
                                const auto we_values = internal_vld1q<1>(we_addr);

                                out_values = internal_vmlal(out_values, in_values, we_values);
                            }

                            out_val += out_values[0];
                            out_val += out_values[1];
                            out_val += out_values[2];
                            out_val += out_values[3];

                            // Leftover
                            for(; x < input_width; ++x)
                            {
                                const auto in_addr  = reinterpret_cast<const T1 *>(in_addr_base1 + x * input_stride_x);
                                const auto in_value = *(in_addr);

                                const auto we_addr  = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
                                const auto we_value = *(we_addr);

                                out_val += in_value * we_value;
                            }
                        }
                    }
                }

                *(reinterpret_cast<T1 *>(out_ptr)) = out_val;
            },
            in, out);
        },
        k);
    }
};

template <typename T1, typename T2, unsigned int stridex>
class convolver_3x3
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        const int          input_stride_x  = input->info()->strides_in_bytes().x();
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_y = output->info()->strides_in_bytes().y();
        const int          output_stride_z = output->info()->strides_in_bytes().z();
        const int          kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int          kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int          output_w        = output->info()->dimension(0);
        const int          output_h        = output->info()->dimension(1);
        const int          num_planes_z    = window.z().end() - window.z().start();
        const int          delta_input     = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            /*
                Each thread executing this kernel computes one or more planes of the output volume.

                Say the 3rd dimension of the output volume is 32 and four threads are used: the first thread computes the output for Z = [0,7], the second thread for Z = [8,15],
                the third thread for Z = [16,23] and the fourth thread for Z = [24,31].

                The outer loop of the algorithm iterates over Z, P, Y, X, where P is the depth/3rd dimension of each kernel. This order is not arbitrary: its main benefit is
                that the NEON registers holding the kernel's values are set up only once, and each XY is then computed using the preloaded registers, as opposed to doing this for every XY value.

                The algorithm does not require allocating any additional memory and computes the results directly in-place in two stages:
                1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
                2) Convolve the remaining planes and accumulate the results in the output plane which has been initialized in step 1.
            */
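            // Put differently, for each output plane z handled by this thread:
            //   out(x, y, z) = sum_p [ in(.., .., p) (*) w(.., .., p, z) ](x, y)
            // where (*) is the 3x3 spatial convolution and p runs over the kernel_depth input planes.
            // Step 1 below writes the p = 0 term, Step 2 accumulates the remaining terms in-place.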
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset    = id.z() + oz;
                uint8_t *p_out_base  = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto vk_r0    = load_matrix_row(ptr_k_r0);
                    const auto vk_r1    = load_matrix_row(ptr_k_r1);
                    const auto vk_r2    = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w;
                    const uint8_t *input_base = input_ptr + p * input_stride_z;
                    const auto     ptr_k_r0   = reinterpret_cast<const T1 *>(ptr_k_base);
                    const auto     ptr_k_r1   = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y);
                    const auto     ptr_k_r2   = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y * 2);
                    const auto     vk_r0      = load_matrix_row(ptr_k_r0);
                    const auto     vk_r1      = load_matrix_row(ptr_k_r1);
                    const auto     vk_r2      = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_base + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_base + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_base + (ih + 2) * input_stride_y);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};

template <typename T1, typename T2, unsigned int stridex>
class convolver_5x5
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        const int          input_stride_x  = input->info()->strides_in_bytes().x();
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_y = output->info()->strides_in_bytes().y();
        const int          output_stride_z = output->info()->strides_in_bytes().z();
        const int          kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int          kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int          output_w        = output->info()->dimension(0);
        const int          output_h        = output->info()->dimension(1);
        const int          num_planes_z    = window.z().end() - window.z().start();
        const int          delta_input     = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset    = id.z() + oz;
                uint8_t *p_out_base  = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);

                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};

template <typename T1, typename T2>
inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <>
inline void convolve_1x1<float, float>(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                                       const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    if(run_optim_small_tensor(input))
    {
        switch(conv_stride_x)
        {
            case 1:
                convolver_w1x1_i8x8_f32<1>::convolve(window, input, weights, output, conv_info);
                break;
            case 2:
                convolver_w1x1_i8x8_f32<2>::convolve(window, input, weights, output, conv_info);
                break;
            case 3:
                convolver_w1x1_i8x8_f32<3>::convolve(window, input, weights, output, conv_info);
                break;
            default:
                ARM_COMPUTE_ERROR("Not implemented");
        }
    }
    else
    {
        switch(conv_stride_x)
        {
            case 1:
                convolver_1x1<float, float, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            case 2:
                convolver_1x1<float, float, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            case 3:
                convolver_1x1<float, float, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
                break;
            default:
                ARM_COMPUTE_ERROR("Not implemented");
        }
    }
}

template <typename T1, typename T2>
inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <typename T1, typename T2>
inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_5x5<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_5x5<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_5x5<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
    ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);

    const DataLayout data_layout = input->data_layout();
    const int        width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
    const int        height_idx  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
    const int        channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(channel_idx) != input->dimension(channel_idx));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != weights->dimension(height_idx));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
    ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && input->data_type() != DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(width_idx) > 3) && (input->data_type() == DataType::F16));

    // Checks performed when output is configured
    if(output->total_size() != 0)
    {
        TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info);

        DataType data_type = input->data_type();

        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
        ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != data_type);
    }

    return Status{};
}
1005
1006std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int &num_weight_elems_read_per_row,
Georgios Pinitas0223a782017-12-12 11:44:44 +00001007 unsigned int &num_elems_read_per_iteration, unsigned int &num_elems_written_per_iteration, BorderSize &border_size)
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001008{
Giorgio Arenac0f54432018-03-16 14:02:34 +00001009 ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
1010
1011 const DataLayout data_layout = input->data_layout();
1012 const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
1013
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001014 // Calculate right and bottom border
Giorgio Arenac0f54432018-03-16 14:02:34 +00001015 unsigned int kernel_size = weights->dimension(width_idx);
Georgios Pinitas1d6d2112018-02-05 17:40:12 +00001016 const int conv_stride_x = std::get<0>(conv_info.stride());
Georgios Pinitas1a03d762018-02-21 14:47:09 +00001017 const int conv_stride_y = std::get<1>(conv_info.stride());
Giorgio Arenac0f54432018-03-16 14:02:34 +00001018 const int input_width = input->dimension(width_idx);
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001019
Giorgio Arenac0f54432018-03-16 14:02:34 +00001020 Window win{};
1021 bool window_changed = false;
1022
1023 if(data_layout == DataLayout::NCHW)
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001024 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001025 switch(kernel_size)
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001026 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001027 case 1:
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001028 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001029 switch(input->data_type())
1030 {
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001031#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +00001032 case DataType::F16:
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001033 num_elems_written_per_iteration = 8;
Giorgio Arenac0f54432018-03-16 14:02:34 +00001034 break;
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +01001035#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +00001036 case DataType::F32:
1037 if(run_optim_small_tensor_info(input))
1038 {
1039 num_elems_written_per_iteration = 8;
1040 }
1041 else
1042 {
1043 num_elems_written_per_iteration = 4;
1044 }
1045 break;
1046 default:
1047 ARM_COMPUTE_ERROR("Data type not supported.");
1048 break;
1049 }
1050 num_weight_elems_read_per_row = kernel_size;
1051 num_elems_read_per_iteration = conv_stride_x * num_elems_written_per_iteration;
1052 break;
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001053 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001054 case 3:
Giorgio Arenac0f54432018-03-16 14:02:34 +00001055 switch(input->data_type())
1056 {
1057 case DataType::F32:
1058 num_weight_elems_read_per_row = 4 + kernel_size - 1;
1059 num_elems_read_per_iteration = 12;
1060 num_elems_written_per_iteration = 16 >> conv_stride_x;
1061 break;
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001062#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +00001063 case DataType::F16:
Giorgio Arenac0f54432018-03-16 14:02:34 +00001064 num_weight_elems_read_per_row = 8 + kernel_size - 1;
1065 num_elems_read_per_iteration = 24;
1066 num_elems_written_per_iteration = 32 >> conv_stride_x;
1067 break;
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +01001068#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +00001069 default:
1070 ARM_COMPUTE_ERROR("Data type not supported.");
1071 break;
1072 }
Gian Marco Iodice41acb762018-08-23 10:25:06 +01001073 break;
1074 case 5:
1075 {
1076 switch(input->data_type())
1077 {
1078 case DataType::F32:
1079 num_weight_elems_read_per_row = 4 + kernel_size - 1;
1080 num_elems_read_per_iteration = 12;
1081 num_elems_written_per_iteration = 16 >> conv_stride_x;
1082 break;
1083 default:
1084 ARM_COMPUTE_ERROR("Data type not supported.");
1085 break;
1086 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001087 }
1088 break;
1089 default:
1090 {
1091 ARM_COMPUTE_ERROR("Not implemented");
1092 break;
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001093 }
1094 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001095
1096 // Calculate right pad
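        // start_x is the offset of the first read relative to the input, end_x the input extent
        // covered by the output width rounded up to the write step; any excess of the read-step
        // aligned span over the input width becomes the right border below.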
1097 int start_x = kernel_size / 2 - static_cast<int>(conv_info.pad_left());
1098 int end_x = ceil_to_multiple(static_cast<int>(output->dimension(0)), num_elems_written_per_iteration) * conv_stride_x;
1099 int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width;
1100
1101 // Calculate border
1102 const unsigned int conv_pad_left = conv_info.pad_left();
1103 const unsigned int conv_pad_top = conv_info.pad_top();
1104 const unsigned int conv_pad_right = std::max(upper_bound_w, 0);
1105 const unsigned int conv_pad_bottom = conv_info.pad_bottom();
1106
1107 border_size.left = conv_pad_left;
1108 border_size.top = conv_pad_top;
1109 border_size.right = conv_pad_right;
1110 border_size.bottom = conv_pad_bottom;
1111
1112 // Configure window
1113 win = calculate_max_window(*output, Steps(num_elems_written_per_iteration));
1114
1115 AccessWindowRectangle input_access(input, -conv_pad_left, -conv_pad_top,
1116 num_elems_read_per_iteration, kernel_size,
1117 conv_stride_x, conv_stride_y);
1118 AccessWindowStatic weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size);
1119 AccessWindowHorizontal output_access(output, 0, num_elems_written_per_iteration);
1120 window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
1121 output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001122 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001123 else
1124 {
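        // In NHWC the tensor's dimension 1 is the image width, so the convolution's left/right
        // padding maps onto the top/bottom of the border and of the access windows.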
1125 border_size.left = 0;
1126 border_size.top = conv_info.pad_left();
1127 border_size.right = 0;
1128 border_size.bottom = conv_info.pad_right();
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001129
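        // Read one 16-byte NEON register per iteration: 4 elements for F32, 8 for F16.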
Giorgio Arenac0f54432018-03-16 14:02:34 +00001130 num_elems_read_per_iteration = 16 / element_size_from_data_type(input->data_type());
Georgios Pinitas1d6d2112018-02-05 17:40:12 +00001131
Giorgio Arenac0f54432018-03-16 14:02:34 +00001132 win = calculate_max_window(*output, Steps());
Michalis Spyrou621965e2018-01-08 17:11:26 +00001133
Giorgio Arenac0f54432018-03-16 14:02:34 +00001134 AccessWindowRectangle input_access(input, 0, -border_size.top, num_elems_read_per_iteration, kernel_size, 1.f, conv_stride_x);
1135 AccessWindowRectangle weights_access(weights, 0, 0, num_elems_read_per_iteration, kernel_size);
1136 window_changed = update_window_and_padding(win, input_access, weights_access);
1137 }
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001138
1139 Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
1140 return std::make_pair(err, win);
1141}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001142} // namespace
1143
1144NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
Georgios Pinitas898a8062017-09-12 19:19:12 +01001145 : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0),
1146 _num_elems_written_per_iteration(0)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001147{
1148}
1149
1150BorderSize NEDirectConvolutionLayerKernel::border_size() const
1151{
1152 return _border_size;
1153}
1154
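// Usage sketch (illustrative only; the tensors and their allocation are assumed to be set up elsewhere):
//   NEDirectConvolutionLayerKernel kernel;
//   kernel.configure(&src, &weights, &dst, PadStrideInfo(1, 1, 0, 0));
//   NEScheduler::get().schedule(&kernel, Window::DimY);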
1155void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
1156{
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001157 ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001158
1159 _input = input;
1160 _weights = weights;
1161 _output = output;
1162 _conv_info = conv_info;
Giorgio Arenac0f54432018-03-16 14:02:34 +00001163 _kernel_size = weights->info()->dimension(get_data_layout_dimension_index(weights->info()->data_layout(), DataLayoutDimension::WIDTH));
Michalis Spyrou621965e2018-01-08 17:11:26 +00001164
1165 const unsigned int conv_pad_left = conv_info.pad_left();
1166 const unsigned int conv_pad_top = conv_info.pad_top();
1167 const unsigned int conv_pad_right = conv_info.pad_right();
1168 const unsigned int conv_pad_bottom = conv_info.pad_bottom();
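    // Note: BorderSize is constructed as (top, right, bottom, left).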
1169 _border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001170
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001171 // Get convolved dimensions
Giorgio Arenac0f54432018-03-16 14:02:34 +00001172 TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input->info(), *weights->info(), conv_info);
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001173
1174 DataType data_type = input->info()->data_type();
1175
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001176 // Output auto initialization if not yet initialized
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +01001177 auto_init_if_empty(*output->info(), output_shape, 1, data_type);
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001178
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001179 // Perform validation step
1180 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), output->info(), conv_info));
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +01001181
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001182 // Configure kernel window
1183 auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, _num_weight_elems_read_per_row,
Georgios Pinitas0223a782017-12-12 11:44:44 +00001184 _num_elems_read_per_iteration, _num_elems_written_per_iteration, _border_size);
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001185 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
1186 INEKernel::configure(win_config.second);
1187}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001188
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001189Status NEDirectConvolutionLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
1190{
1191 unsigned int num_weight_elems_read_per_row = 0;
1192 unsigned int num_elems_read_per_iteration = 0;
1193 unsigned int num_elems_written_per_iteration = 0;
Georgios Pinitas15997872018-02-19 13:58:22 +00001194 BorderSize border_size = {};
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001195 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, weights, output, conv_info));
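    // Configure the window on clones so that validation leaves the caller's tensor infos untouched.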
Georgios Pinitas0223a782017-12-12 11:44:44 +00001196 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
1197 weights->clone().get(),
1198 output->clone().get(),
1199 conv_info,
1200 num_weight_elems_read_per_row,
1201 num_elems_read_per_iteration,
1202 num_elems_written_per_iteration,
1203 border_size)
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001204 .first);
Georgios Pinitas898a8062017-09-12 19:19:12 +01001205
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001206 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001207}
1208
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001209void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001210{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001211 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001212 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1213 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1214 ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
1215
Giorgio Arenac0f54432018-03-16 14:02:34 +00001216 const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_weights->info()->data_layout(), DataLayoutDimension::WIDTH));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001217
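    // Dispatch to the convolver that matches the data layout, kernel size and data type set at configure time.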
Giorgio Arenac0f54432018-03-16 14:02:34 +00001218 if(_input->info()->data_layout() == DataLayout::NCHW)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001219 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001220 switch(kernel_size)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001221 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001222 case 1:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001223 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001224 switch(_input->info()->data_type())
1225 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001226 case DataType::F32:
1227 convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1228 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001229#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +00001230 case DataType::F16:
1231 convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1232 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001233#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +00001234 default:
1235 ARM_COMPUTE_ERROR("Data type not supported");
1236 break;
1237 }
1238 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001239 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001240 case 3:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001241 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001242 switch(_input->info()->data_type())
1243 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001244 case DataType::F32:
1245 convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1246 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001247#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +00001248 case DataType::F16:
1249 convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1250 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001251#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +00001252 default:
1253 ARM_COMPUTE_ERROR("Data type not supported");
1254 break;
1255 }
1256 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001257 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001258 case 5:
Pablo Tello06da39d2017-08-10 15:10:40 +01001259 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001260 switch(_input->info()->data_type())
1261 {
1262 case DataType::F32:
1263 convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1264 break;
1265 default:
1266 ARM_COMPUTE_ERROR("Data type not supported");
1267 break;
1268 }
1269 break;
Pablo Tello06da39d2017-08-10 15:10:40 +01001270 }
Pablo Tello06da39d2017-08-10 15:10:40 +01001271
Giorgio Arenac0f54432018-03-16 14:02:34 +00001272 default:
1273 {
1274 ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
1275 break;
1276 }
1277 }
1278 }
1279 else
1280 {
1281 switch(_input->info()->data_type())
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001282 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001283 case DataType::F32:
1284 convolver_nhwc<float>::convolve(window, kernel_size, _num_elems_read_per_iteration, _input, _weights, _output, _conv_info);
1285 break;
1286 default:
1287 ARM_COMPUTE_ERROR("Data type not supported");
1288 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001289 }
1290 }
1291}