blob: 162c4b1ace81f9bbc18dc5a555aa2bc00a610b71 [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Michalis Spyrou621965e2018-01-08 17:11:26 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"
Georgios Pinitas4074c992018-01-30 18:13:46 +000025#include "arm_compute/core/NEON/kernels/detail/NEDirectConvolutionDetail.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
27#include "arm_compute/core/AccessWindowStatic.h"
Anthony Barbiereaefd002018-07-20 17:49:35 +010028#include "arm_compute/core/CPP/Validate.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010029#include "arm_compute/core/Error.h"
30#include "arm_compute/core/Helpers.h"
31#include "arm_compute/core/IAccessWindow.h"
32#include "arm_compute/core/ITensor.h"
33#include "arm_compute/core/NEON/NEFixedPoint.h"
34#include "arm_compute/core/Types.h"
Gian Marco Iodice5cb4d6a2017-08-08 10:53:00 +010035#include "arm_compute/core/Utils.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010036#include "arm_compute/core/Validate.h"
Giorgio Arenac0f54432018-03-16 14:02:34 +000037#include "arm_compute/core/utils/misc/ShapeCalculator.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010038
Michalis Spyrou201c37c2018-10-25 17:25:54 +010039#include "arm_compute/core/NEON/wrapper/wrapper.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010040#include <algorithm>
41#include <arm_neon.h>
42
43using namespace arm_compute;
Michalis Spyrou7362f0d2017-10-18 17:58:22 +010044using namespace arm_compute::detail;
Anthony Barbier6ff3b192017-09-04 18:44:23 +010045
46namespace
47{
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +000048#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** Stride-aware load of 8 fp16 values.
 *
 * The stride-2 and stride-3 variants use de-interleaving loads so the
 * returned vector holds every 2nd / 3rd element of the source.
 */
template <unsigned int stridex>
float16x8_t internal_vld1q(const float16_t *in);

// Stride 1: plain contiguous load.
template <>
float16x8_t internal_vld1q<1>(const float16_t *in)
{
    return vld1q_f16(in);
}

// Stride 2: de-interleave pairs, keep the even-indexed elements.
template <>
float16x8_t internal_vld1q<2>(const float16_t *in)
{
    const float16x8x2_t deinterleaved = vld2q_f16(in);
    return deinterleaved.val[0];
}

// Stride 3: de-interleave triplets, keep every third element.
template <>
float16x8_t internal_vld1q<3>(const float16_t *in)
{
    const float16x8x3_t deinterleaved = vld3q_f16(in);
    return deinterleaved.val[0];
}
71
// Broadcast a single fp16 scalar to all 8 lanes of a vector.
inline float16x8_t internal_vdupq_n(float16_t value)
{
    return vdupq_n_f16(value);
}
76
// Store 8 fp16 lanes to memory at @p out.
inline void internal_vst1q(float16_t *out, const float16x8_t &vec)
{
    vst1q_f16(out, vec);
}
81
/** Lane-wise fp16 multiply: returns x * y.
 *
 * Marked inline for consistency with the sibling internal_vmlal helper
 * (both live in this anonymous namespace and are called from templated
 * convolver hot loops).
 */
inline float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y)
{
    return vmulq_f16(x, y);
}
86
// Multiply-accumulate for fp16: computes x + (y * z).
// Implemented as an explicit multiply followed by an add.
inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z)
{
    const float16x8_t product = vmulq_f16(y, z);
    return vaddq_f16(x, product);
}
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +000091#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tello0d176142017-07-06 16:43:14 +010092
/** Stride-aware load of 4 fp32 values.
 *
 * The stride-2 and stride-3 variants use de-interleaving loads so the
 * returned vector holds every 2nd / 3rd element of the source.
 */
template <unsigned int stridex>
float32x4_t internal_vld1q(const float *in);

// Stride 1: plain contiguous load.
template <>
float32x4_t internal_vld1q<1>(const float *in)
{
    return vld1q_f32(in);
}

// Stride 2: de-interleave pairs, keep the even-indexed elements.
template <>
float32x4_t internal_vld1q<2>(const float *in)
{
    const float32x4x2_t deinterleaved = vld2q_f32(in);
    return deinterleaved.val[0];
}

// Stride 3: de-interleave triplets, keep every third element.
template <>
float32x4_t internal_vld1q<3>(const float *in)
{
    const float32x4x3_t deinterleaved = vld3q_f32(in);
    return deinterleaved.val[0];
}
115
// Broadcast a single fp32 scalar to all 4 lanes of a vector.
inline float32x4_t internal_vdupq_n(float value)
{
    return vdupq_n_f32(value);
}
120
// Store 4 fp32 lanes to memory at @p out.
inline void internal_vst1q(float *out, const float32x4_t &vec)
{
    vst1q_f32(out, vec);
}
125
/** Lane-wise fp32 multiply: returns x * y.
 *
 * Marked inline for consistency with the sibling internal_vmlal helper
 * (both live in this anonymous namespace and are called from templated
 * convolver hot loops).
 */
inline float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y)
{
    return vmulq_f32(x, y);
}
130
// Multiply-accumulate for fp32: computes x + (y * z) with the
// NEON multiply-accumulate intrinsic.
inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z)
{
    return vmlaq_f32(x, y, z);
}
135
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000136constexpr int small_tensor_size_optim = 8;
137inline bool run_optim_small_tensor_info(const ITensorInfo *t)
138{
139 return t->dimension(Window::DimX) <= small_tensor_size_optim && t->dimension(Window::DimY) <= small_tensor_size_optim;
140}
141
Pablo Telloc09314a2017-09-21 13:59:14 +0100142inline bool run_optim_small_tensor(const ITensor *t)
143{
Michalis Spyrouafa5d812017-11-30 14:25:57 +0000144 return run_optim_small_tensor_info(t->info());
Pablo Telloc09314a2017-09-21 13:59:14 +0100145}
146
// Optimized convolver for 1x1 kernels used only where input width and height are both <= 8
// For big Z as in Input=7x7x832, this implementation is faster than the general code because it doesn't need to
// store intermediate results in memory. Temporary results are stored in NEON registers directly and then written to the output buffer.
/** Small-tensor 1x1 fp32 convolver.
 *
 * Keeps one full output plane (up to 8 rows of 8 floats = two float32x4_t
 * per row) entirely in NEON registers while accumulating over the whole
 * kernel depth, then writes each plane out once.
 *
 * @tparam stridex Convolution stride along X (selects the internal_vld1q variant).
 */
template <unsigned int stridex>
class convolver_w1x1_i8x8_f32
{
public:
    /** Run the convolution over the region described by @p window.
     *
     * @param window    Execution window (Z range selects the output planes handled by this thread).
     * @param input     Input tensor (read).
     * @param weights   Weights tensor; dimension 3 indexes the output feature map.
     * @param output    Output tensor (written).
     * @param conv_info Stride/padding configuration.
     */
    static void convolve(const Window &window, const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        // This path is only valid when the whole spatial plane fits in registers.
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimX) > small_tensor_size_optim);
        ARM_COMPUTE_ERROR_ON(input->info()->dimension(Window::DimY) > small_tensor_size_optim);

        // All strides below are in bytes; pointer arithmetic is done on uint8_t*.
        const int          input_stride_x  = input->info()->strides_in_bytes().x();
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_y = output->info()->strides_in_bytes().y();
        const int          output_stride_z = output->info()->strides_in_bytes().z();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int          output_h        = output->info()->dimension(1);
        const int          range_z         = window.z().end() - window.z().start();
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window   window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            // Rewind the input pointer to the (virtual) padded origin.
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            // Two accumulator banks: accum0 holds output columns [0,4), accum1 holds [4,8),
            // one entry per output row.
            float32x4_t accum0[small_tensor_size_optim] = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            float32x4_t accum1[small_tensor_size_optim] = { vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0), vdupq_n_f32(0) };
            for(int oz = 0; oz < range_z; ++oz)
            {
                // Reset the register accumulators for this output plane.
                accum0[0] = accum0[1] = accum0[2] = accum0[3] = accum0[4] = accum0[5] = accum0[6] = accum0[7] = vdupq_n_f32(0.f);
                accum1[0] = accum1[1] = accum1[2] = accum1[3] = accum1[4] = accum1[5] = accum1[6] = accum1[7] = vdupq_n_f32(0.f);
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Accumulate over the whole kernel depth before writing anything out.
                for(int p = 0; p < kernel_depth; ++p)
                {
                    // 1x1 kernel: a single scalar weight per (p, output-map) pair, broadcast to all lanes.
                    const auto k_val = reinterpret_cast<const float *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk0   = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto      in_val    = reinterpret_cast<const float *>(input_ptr + p * input_stride_z + offset_xy);
                        auto      v_in0     = internal_vld1q<stridex>(in_val);
                        auto      v_in1     = internal_vld1q<stridex>(in_val + 4);
                        accum0[oh] = vmlaq_f32(accum0[oh], vk0, v_in0);
                        accum1[oh] = vmlaq_f32(accum1[oh], vk0, v_in1);
                    }
                }
                // Flush the finished plane from registers to memory, row by row.
                for(oh = 0; oh < output_h; ++oh)
                {
                    auto p_out = reinterpret_cast<float *>(p_out_base + oh * output_stride_y);
                    vst1q_f32(p_out, accum0[oh]);
                    vst1q_f32(p_out + 4, accum1[oh]);
                }
            }
        },
        in, out);
    }
};
231
/** General 1x1 convolver (NCHW).
 *
 * Processes one or more output Z-planes per thread. For each plane it first
 * initializes the output with kernel-depth slice 0 (a multiply), then
 * accumulates the remaining slices with read-modify-write passes.
 *
 * @tparam T1      Input/weights element type.
 * @tparam T2      Output element type.
 * @tparam stridex Convolution stride along X (selects the internal_vld1q variant).
 */
template <typename T1, typename T2, unsigned int stridex>
class convolver_1x1
{
public:
    /** Run the convolution over the region described by @p window.
     *
     * @param window                          Execution window (Z range selects the planes handled by this thread).
     * @param num_elems_read_per_iteration    Input elements consumed per inner-loop iteration.
     * @param num_elems_written_per_iteration Output elements produced per inner-loop iteration.
     * @param input                           Input tensor (read).
     * @param weights                         Weights tensor; dimension 3 indexes the output feature map.
     * @param output                          Output tensor (written).
     * @param conv_info                       Stride/padding configuration.
     */
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        // All strides below are in bytes; pointer arithmetic is done on uint8_t*.
        const int          input_stride_x  = input->info()->strides_in_bytes().x();
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_y = output->info()->strides_in_bytes().y();
        const int          output_stride_z = output->info()->strides_in_bytes().z();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int          output_w        = output->info()->dimension(0);
        const int          output_h        = output->info()->dimension(1);
        const int          range_z         = window.z().end() - window.z().start();
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();

        // setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window   window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            /*
                For a detailed explanation on how the algorithm works refer to template <> class convolver_3x3<1>
            */
            // Rewind the input pointer to the (virtual) padded origin.
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            for(int oz = 0; oz < range_z; ++oz)
            {
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Step 1: depth slice 0 initializes the output plane with a plain multiply
                // (no read of the uninitialized output buffer).
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk    = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto      in_val    = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
                        auto      p_out     = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val)));
                        }
                    }
                }

                // Step 2: remaining depth slices accumulate into the output plane
                // (load previous value, multiply-accumulate, store back).
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk    = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto      in_val    = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
                        auto      p_out     = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            // The output is always re-read contiguously (stride 1), hence internal_vld1q<1>.
                            internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val)));
                        }
                    }
                }
            }
        },
        in, out);
    }
};
323
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000324#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Pablo Tello0d176142017-07-06 16:43:14 +0100325
/** Add a pair of fp16 result vectors into the output buffer in place.
 *
 * The number of lanes actually accumulated shrinks with the stride:
 * 16 lanes for stride 1, 8 for stride 2, and 4 for stride 3.
 */
template <unsigned int stridex>
void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

// Stride 1: accumulate both 8-lane halves (16 outputs).
template <>
void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    const float16x8_t lo = vaddq_f16(vld1q_f16(buffer), values.val[0]);
    const float16x8_t hi = vaddq_f16(vld1q_f16(buffer + 8), values.val[1]);
    vst1q_f16(buffer, lo);
    vst1q_f16(buffer + 8, hi);
}

// Stride 2: only the first 8-lane half carries valid outputs.
template <>
void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    const float16x8_t sum = vaddq_f16(vld1q_f16(buffer), values.val[0]);
    vst1q_f16(buffer, sum);
}

// Stride 3: only the low 4 lanes of the first half carry valid outputs.
template <>
void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    const float16x4_t sum = vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0]));
    vst1_f16(buffer, sum);
}
347
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +0000348#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Pablo Tello0d176142017-07-06 16:43:14 +0100349
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100350template <unsigned int stridex>
Pablo Tello06da39d2017-08-10 15:10:40 +0100351float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
Vidhya Sudhan Loganathan7485d5a2018-07-04 09:34:00 +0100352 const float *m0, const float *m1, const float *m2, const float *m3, const float *m4);
Pablo Tello06da39d2017-08-10 15:10:40 +0100353
// Broadcast the first three 5x5 kernel-row coefficients, one vector each.
inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
{
    float32x4x3_t coeffs;
    coeffs.val[0] = vld1q_dup_f32(m0);
    coeffs.val[1] = vld1q_dup_f32(m1);
    coeffs.val[2] = vld1q_dup_f32(m2);
    return coeffs;
}
366
// Broadcast the last two 5x5 kernel-row coefficients, one vector each.
inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4)
{
    float32x4x2_t coeffs;
    coeffs.val[0] = vld1q_dup_f32(m3);
    coeffs.val[1] = vld1q_dup_f32(m4);
    return coeffs;
}
378
// Load 12 consecutive fp32 input values as three 4-lane vectors.
inline float32x4x3_t load_input(const float *const in)
{
    float32x4x3_t row;
    row.val[0] = vld1q_f32(in);
    row.val[1] = vld1q_f32(in + 4);
    row.val[2] = vld1q_f32(in + 8);
    return row;
}
391
/** 5x5 fp32 convolution of one output row segment, stride 1.
 *
 * Consumes 12 input values per row (loaded as three 4-lane vectors) from the
 * five input rows in_0..in_4 and the five kernel rows m0..m4, producing 8
 * outputs in a float32x4x2_t. Shifted input windows are formed with
 * vextq_f32 instead of unaligned re-loads.
 */
template <>
inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
    const float32x4x3_t vin0 = load_input(in_0);
    const float32x4x3_t vin1 = load_input(in_1);
    const float32x4x3_t vin2 = load_input(in_2);
    const float32x4x3_t vin3 = load_input(in_3);
    const float32x4x3_t vin4 = load_input(in_4);
    // mX0 holds kernel-row X columns 0..2 broadcast; mX1 holds columns 3..4.
    const float32x4x3_t m00  = load_matrix_hi(m0, 1 + m0, 2 + m0);
    const float32x4x2_t m01  = load_matrix_lo(3 + m0, 4 + m0);
    const float32x4x3_t m10  = load_matrix_hi(m1, 1 + m1, 2 + m1);
    const float32x4x2_t m11  = load_matrix_lo(3 + m1, 4 + m1);
    const float32x4x3_t m20  = load_matrix_hi(m2, 1 + m2, 2 + m2);
    const float32x4x2_t m21  = load_matrix_lo(3 + m2, 4 + m2);
    const float32x4x3_t m30  = load_matrix_hi(m3, 1 + m3, 2 + m3);
    const float32x4x2_t m31  = load_matrix_lo(3 + m3, 4 + m3);
    const float32x4x3_t m40  = load_matrix_hi(m4, 1 + m4, 2 + m4);
    const float32x4x2_t m41  = load_matrix_lo(3 + m4, 4 + m4);

    // Initialize both output halves with the row-0, column-0 tap.
    float32x4x2_t out =
    {
        {
            vmulq_f32(vin0.val[0], m00.val[0]),
            vmulq_f32(vin0.val[1], m00.val[0])
        }
    };

    // First output half (outputs 0..3): remaining taps of row 0.
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]);

    // Row 1.
    out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 3), m11.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], m11.val[1]);

    // Row 2.
    out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]);

    // Row 3.
    out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]);

    // Row 4.
    out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]);

    // Second output half (outputs 4..7): same tap pattern shifted by 4 inputs.
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]);

    return out;
}
480
/** 5x5 fp32 convolution, stride 2.
 *
 * Runs the stride-1 kernel and then compacts the valid (even-indexed)
 * results into the first output vector by lane shuffling.
 */
template <>
inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
    float32x4x2_t res = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);
    // Gather lanes 0,2 of the first half and 0,2 of the second half into res.val[0].
    res.val[0] = vsetq_lane_f32(vgetq_lane_f32(res.val[0], 2), res.val[0], 1);
    res.val[0] = vsetq_lane_f32(vgetq_lane_f32(res.val[1], 0), res.val[0], 2);
    res.val[0] = vsetq_lane_f32(vgetq_lane_f32(res.val[1], 2), res.val[0], 3);
    return res;
}
491
/** 5x5 fp32 convolution, stride 3.
 *
 * Runs the stride-1 kernel and then moves the second valid result
 * (lane 3) next to the first (lane 0).
 */
template <>
inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4)
{
    float32x4x2_t res = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4);
    res.val[0] = vsetq_lane_f32(vgetq_lane_f32(res.val[0], 3), res.val[0], 1);
    return res;
}
500
/** Add a pair of fp32 result vectors into the output buffer in place.
 *
 * The number of lanes actually accumulated shrinks with the stride:
 * 8 lanes for stride 1, 4 for stride 2, and 2 for stride 3.
 */
template <unsigned int stridex>
void accumulate_results(float *buffer, const float32x4x2_t &values);

// Stride 1: accumulate both 4-lane halves (8 outputs).
template <>
void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    const float32x4_t lo = vaddq_f32(vld1q_f32(buffer), values.val[0]);
    const float32x4_t hi = vaddq_f32(vld1q_f32(buffer + 4), values.val[1]);
    vst1q_f32(buffer, lo);
    vst1q_f32(buffer + 4, hi);
}

// Stride 2: only the first 4-lane half carries valid outputs.
template <>
void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    const float32x4_t sum = vaddq_f32(vld1q_f32(buffer), values.val[0]);
    vst1q_f32(buffer, sum);
}

// Stride 3: only the low 2 lanes of the first half carry valid outputs.
template <>
void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    const float32x2_t sum = vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0]));
    vst1_f32(buffer, sum);
}
522
/** Direct convolver for NHWC layout.
 *
 * Iterates the weights window (one output feature map at a time, via
 * dimension 3) and, for each output coordinate, reduces a full
 * kernel_size x kernel_size x input_width dot product into a single scalar.
 *
 * @tparam T1 Element type of input/weights/output.
 */
template <typename T1>
class convolver_nhwc
{
public:
    /** Run the convolution over the region described by @p window.
     *
     * @param window                       Execution window over the output.
     * @param kernel_size                  Spatial kernel extent (square kernel).
     * @param num_elems_read_per_iteration Elements consumed per vectorized inner-loop step.
     * @param input                        Input tensor (read), NHWC.
     * @param weights                      Weights tensor; dimension 3 indexes the output feature map.
     * @param output                       Output tensor (written).
     * @param conv_info                    Stride/padding configuration.
     */
    static void convolve(const Window &window, int kernel_size, unsigned int num_elems_read_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        // All strides below are in bytes; pointer arithmetic is done on uint8_t*.
        const int          input_width     = input->info()->dimension(0);
        const int          input_depth     = input->info()->dimension(2);
        const int          input_stride_x  = input->info()->strides_in_bytes().x();
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_x = output->info()->strides_in_bytes().x();
        const int          kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int          kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          conv_pad_top    = conv_info.pad_top();
        const unsigned int conv_stride_x   = std::get<0>(conv_info.stride());
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const T1           zero            = 0;

        // Setup input window for the input iterator
        Window window_in = window;
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        // Setup input window for the output iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, 1, 1));

        // Setup input window for the weights iterator: iterate only dimension 3
        // (one step per output feature map).
        Window window_k = calculate_max_window(*weights->info(), Steps());
        window_k.set(Window::DimX, Window::Dimension(0, 1, 1));
        window_k.set(Window::DimY, Window::Dimension(0, 1, 1));
        window_k.set(Window::DimZ, Window::Dimension(0, 1, 1));
        window_k.set(3, Window::Dimension(0, weights->info()->dimension(3), 1));

        Iterator in(input, window_in);
        Iterator out(output, window_out);
        Iterator k(weights, window_k);

        execute_window_loop(window_k, [&](const Coordinates & id_k)
        {
            execute_window_loop(window_out, [&](const Coordinates & id)
            {
                // Map the output coordinate back to the (possibly negative, i.e. padded)
                // input origin of its receptive field.
                const auto in_y = static_cast<int>(id.y() * conv_stride_x - conv_info.pad_left());
                const auto in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top);

                const uint8_t *in_ptr  = in.ptr() + in_y * input_stride_y + in_z * input_stride_z;
                uint8_t       *out_ptr = out.ptr() + id_k[3] * output_stride_x;

                T1 out_val = 0;

                auto in_addr_base0 = in_ptr;
                auto we_addr_base0 = k.ptr();

                for(int z = 0; z < kernel_size; ++z, in_addr_base0 += input_stride_z, we_addr_base0 += kernel_stride_z)
                {
                    // NOTE(review): this shadows the outer `in_z` computed above;
                    // here it is the absolute input z-row of kernel row `z`.
                    const int in_z = id.z() * conv_stride_y + z - conv_pad_top;

                    if(in_z >= 0 && in_z < input_depth) // If false, pad top/bottom
                    {
                        auto in_addr_base1 = in_addr_base0;
                        auto we_addr_base1 = we_addr_base0;

                        for(int y = 0; y < kernel_size; ++y, in_addr_base1 += input_stride_y, we_addr_base1 += kernel_stride_y)
                        {
                            // Vector accumulator for the bulk of the row.
                            auto out_values = internal_vdupq_n(zero);

                            int x           = 0;
                            int no_leftover = input_width - num_elems_read_per_iteration;

                            for(; x < no_leftover; x += num_elems_read_per_iteration)
                            {
                                const auto in_addr   = reinterpret_cast<const T1 *>(in_addr_base1 + x * input_stride_x);
                                const auto in_values = internal_vld1q<1>(in_addr);

                                const auto we_addr   = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
                                const auto we_values = internal_vld1q<1>(we_addr);

                                out_values = internal_vmlal(out_values, in_values, we_values);
                            }

                            // Horizontal reduction of the vector accumulator into out_val.
                            auto carry_addition = wrapper::vpadd(wrapper::vgethigh(out_values), wrapper::vgetlow(out_values));
                            carry_addition      = wrapper::vpadd(carry_addition, carry_addition);
                            out_val += wrapper::vgetlane(carry_addition, 0);

                            // Leftover
                            for(; x < input_width; ++x)
                            {
                                const auto in_addr  = reinterpret_cast<const T1 *>(in_addr_base1 + x * input_stride_x);
                                const auto in_value = *(in_addr);

                                const auto we_addr  = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
                                const auto we_value = *(we_addr);

                                out_val += in_value * we_value;
                            }
                        }
                    }
                }

                *(reinterpret_cast<T1 *>(out_ptr)) = out_val;
            },
            in, out);
        },
        k);
    }
};
633
template <typename T1, typename T2, unsigned int stridex>
class convolver_3x3
{
public:
    /** Run a 3x3 direct convolution on the NCHW input over the given output window.
     *
     * T1 is the input/weights element type, T2 the output/accumulator type and
     * stridex the compile-time convolution stride along X.
     *
     * @param[in]  window                          Output region assigned to this thread.
     * @param[in]  num_elems_read_per_iteration    Unused here (input reads are driven by delta_input).
     * @param[in]  num_elems_written_per_iteration Output elements produced per inner-loop step.
     * @param[in]  input                           Input tensor.
     * @param[in]  weights                         4D weights tensor (last dimension indexes the kernel/OFM).
     * @param[out] output                          Output tensor.
     * @param[in]  conv_info                       Convolution stride/padding information.
     */
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        // All strides are in bytes; pointer arithmetic below is done on uint8_t* and
        // only reinterpreted to T1*/T2* at the point of access.
        const int          input_stride_x  = input->info()->strides_in_bytes().x();
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_y = output->info()->strides_in_bytes().y();
        const int          output_stride_z = output->info()->strides_in_bytes().z();
        const int          kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int          kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int          output_w        = output->info()->dimension(0);
        const int          output_h        = output->info()->dimension(1);
        const int          num_planes_z    = window.z().end() - window.z().start();
        // Number of input elements to advance per inner-loop step (depends on stridex).
        const int          delta_input     = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();

        // setup output window for the iterator: collapse X/Y so each lambda call sees a full plane batch
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            // Rewind the input pointer to include the left/top padding region.
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            /*
                    Each thread executing this kernel computes one or more output's volume planes.

                    Let's say the 3rd dimension of the output volume is 32, the first thread will compute the output for Z = [0,7], the second thread will compute the output for Z = [8,15],
                    the third thread [16,23] and the fourth thread [24,31].

                    The algorithm outer loop iterates over Z, P, Y, X where P is the depth/3rd dimension of each kernel. This order is not arbitrary, the main benefit of this
                    is that we setup the neon registers containing the kernel's values only once and then compute each XY using the preloaded registers as opposed as doing this for every XY value.

                    The algorithm does not require allocating any additional memory and computes the results directly in-place in two stages:
                        1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
                        2) Convolve the remaining planes and accumulate the results in the output's plane which has been initialized in step 1.
            */
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset    = id.z() + oz;
                uint8_t  *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1: first input plane initializes the output plane (store, no accumulate).
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    // Preload the three kernel rows into registers once per plane.
                    const auto vk_r0 = load_matrix_row(ptr_k_r0);
                    const auto vk_r1 = load_matrix_row(ptr_k_r1);
                    const auto vk_r2 = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2: remaining input planes accumulate into the initialized output plane.
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w;
                    const uint8_t *input_base = input_ptr + p * input_stride_z;
                    const auto     ptr_k_r0   = reinterpret_cast<const T1 *>(ptr_k_base);
                    const auto     ptr_k_r1   = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y);
                    const auto     ptr_k_r2   = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y * 2);
                    const auto     vk_r0      = load_matrix_row(ptr_k_r0);
                    const auto     vk_r1      = load_matrix_row(ptr_k_r1);
                    const auto     vk_r2      = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_base + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_base + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_base + (ih + 2) * input_stride_y);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};
756
template <typename T1, typename T2, unsigned int stridex>
class convolver_5x5
{
public:
    /** Run a 5x5 direct convolution on the NCHW input over the given output window.
     *
     * Same two-step scheme as convolver_3x3: plane 0 initializes the output
     * (store), the remaining kernel planes accumulate into it.
     *
     * @param[in]  window                          Output region assigned to this thread.
     * @param[in]  num_elems_read_per_iteration    Unused here (input reads are driven by delta_input).
     * @param[in]  num_elems_written_per_iteration Output elements produced per inner-loop step.
     * @param[in]  input                           Input tensor.
     * @param[in]  weights                         4D weights tensor (last dimension indexes the kernel/OFM).
     * @param[out] output                          Output tensor.
     * @param[in]  conv_info                       Convolution stride/padding information.
     */
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        // Strides are in bytes; all address computation happens on uint8_t*.
        const int          input_stride_x  = input->info()->strides_in_bytes().x();
        const int          input_stride_y  = input->info()->strides_in_bytes().y();
        const int          input_stride_z  = input->info()->strides_in_bytes().z();
        const int          output_stride_y = output->info()->strides_in_bytes().y();
        const int          output_stride_z = output->info()->strides_in_bytes().z();
        const int          kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int          kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int          kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int          output_w        = output->info()->dimension(0);
        const int          output_h        = output->info()->dimension(1);
        const int          num_planes_z    = window.z().end() - window.z().start();
        // Input elements consumed per inner-loop step for the given compile-time stride.
        const int          delta_input     = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int          kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y   = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();

        // setup output window for the iterator: collapse X/Y so each lambda call covers whole planes
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // setup input window for the iterator
        Window window_in = window;
        // we just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            // Rewind the input pointer so it covers the left/top padding region.
            const uint8_t *input_ptr = in.ptr() - conv_pad_left * input_stride_x - conv_pad_top * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset    = id.z() + oz;
                uint8_t  *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1: plane 0 initializes the output plane (store, no accumulate).
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        // Five consecutive input rows feeding one output row.
                        auto in_0  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4  = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2: remaining planes accumulate into the initialized output plane.
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);

                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4  = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};
867
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100868template <typename T1, typename T2>
869inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
870 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
871{
872 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
873 switch(conv_stride_x)
874 {
875 case 1:
876 convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
877 break;
878 case 2:
879 convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
880 break;
881 case 3:
882 convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
883 break;
884 default:
885 ARM_COMPUTE_ERROR("Not implemented");
886 }
887}
888
Pablo Telloc09314a2017-09-21 13:59:14 +0100889template <>
890inline void convolve_1x1<float, float>(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
891 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
892{
893 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
894 if(run_optim_small_tensor(input))
895 {
896 switch(conv_stride_x)
897 {
898 case 1:
899 convolver_w1x1_i8x8_f32<1>::convolve(window, input, weights, output, conv_info);
900 break;
901 case 2:
902 convolver_w1x1_i8x8_f32<2>::convolve(window, input, weights, output, conv_info);
903 break;
904 case 3:
905 convolver_w1x1_i8x8_f32<3>::convolve(window, input, weights, output, conv_info);
906 break;
907 default:
908 ARM_COMPUTE_ERROR("Not implemented");
909 }
910 }
911 else
912 {
913 switch(conv_stride_x)
914 {
915 case 1:
916 convolver_1x1<float, float, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
917 break;
918 case 2:
919 convolver_1x1<float, float, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
920 break;
921 case 3:
922 convolver_1x1<float, float, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
923 break;
924 default:
925 ARM_COMPUTE_ERROR("Not implemented");
926 }
927 }
928}
929
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100930template <typename T1, typename T2>
931inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
932 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
933{
934 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
935 switch(conv_stride_x)
936 {
937 case 1:
938 convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
939 break;
940 case 2:
941 convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
942 break;
943 case 3:
944 convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
945 break;
946 default:
947 ARM_COMPUTE_ERROR("Not implemented");
948 }
949}
Pablo Tello06da39d2017-08-10 15:10:40 +0100950
951template <typename T1, typename T2>
952inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
953 const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
954{
955 const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
956 switch(conv_stride_x)
957 {
958 case 1:
959 convolver_5x5<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
960 break;
961 case 2:
962 convolver_5x5<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
963 break;
964 case 3:
965 convolver_5x5<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
966 break;
967 default:
968 ARM_COMPUTE_ERROR("Not implemented");
969 }
970}
971
/** Validate the static arguments of the direct convolution kernel.
 *
 * Check order is significant: the first failing check determines the returned
 * error, so keep new checks at the end unless they are preconditions of
 * earlier ones.
 *
 * @param[in] input     Input tensor info.
 * @param[in] weights   Weights tensor info.
 * @param[in] output    Output tensor info (checked only if already configured).
 * @param[in] conv_info Convolution stride/padding information.
 *
 * @return A Status describing the first violated constraint, or an empty Status on success.
 */
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
    ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);

    // Resolve layout-dependent dimension indices (NCHW vs NHWC).
    const DataLayout data_layout = input->data_layout();
    const int        width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
    const int        height_idx  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
    const int        channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
    // Weights must match the input channel count and be square (width == height).
    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(channel_idx) != input->dimension(channel_idx));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != weights->dimension(height_idx));
    ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
    // NHWC path only implements F32; F16 only implements kernels up to 3x3.
    ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && input->data_type() != DataType::F32);
    ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(width_idx) > 3) && (input->data_type() == DataType::F16));

    // Checks performed when output is configured
    if(output->total_size() != 0)
    {
        TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info);

        DataType data_type = input->data_type();

        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
        ARM_COMPUTE_RETURN_ERROR_ON(output->data_type() != data_type);
    }

    return Status{};
}
1005
/** Compute the execution window, per-iteration element counts and border size
 *  for the kernel, and report whether extra tensor padding was required.
 *
 * @param[in,out] input                           Input tensor info (padding may be updated).
 * @param[in,out] weights                         Weights tensor info (padding may be updated).
 * @param[in,out] output                          Output tensor info (padding may be updated).
 * @param[in]     conv_info                       Convolution stride/padding information.
 * @param[out]    num_weight_elems_read_per_row   Weights elements read per row (NCHW path).
 * @param[out]    num_elems_read_per_iteration    Input elements read per iteration.
 * @param[out]    num_elems_written_per_iteration Output elements written per iteration (NCHW path).
 * @param[out]    border_size                     Border the kernel needs around the input.
 *
 * @return Pair of (error status if padding update failed, configured window).
 */
std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int &num_weight_elems_read_per_row,
                                                        unsigned int &num_elems_read_per_iteration, unsigned int &num_elems_written_per_iteration, BorderSize &border_size)
{
    ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);

    const DataLayout data_layout = input->data_layout();
    const int        width_idx   = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);

    // Calculate right and bottom border
    unsigned int kernel_size   = weights->dimension(width_idx);
    const int    conv_stride_x = std::get<0>(conv_info.stride());
    const int    conv_stride_y = std::get<1>(conv_info.stride());
    const int    input_width   = input->dimension(width_idx);

    Window win{};
    bool   window_changed = false;

    if(data_layout == DataLayout::NCHW)
    {
        // Per-iteration element counts depend on kernel size and data type;
        // they must stay consistent with what the convolver templates consume.
        switch(kernel_size)
        {
            case 1:
            {
                switch(input->data_type())
                {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    case DataType::F16:
                        num_elems_written_per_iteration = 8;
                        break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                    case DataType::F32:
                        // Small tensors use the wider 8-element specialised path.
                        if(run_optim_small_tensor_info(input))
                        {
                            num_elems_written_per_iteration = 8;
                        }
                        else
                        {
                            num_elems_written_per_iteration = 4;
                        }
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported.");
                        break;
                }
                num_weight_elems_read_per_row = kernel_size;
                num_elems_read_per_iteration  = conv_stride_x * num_elems_written_per_iteration;
                break;
            }
            case 3:
                switch(input->data_type())
                {
                    case DataType::F32:
                        num_weight_elems_read_per_row   = 4 + kernel_size - 1;
                        num_elems_read_per_iteration    = 12;
                        num_elems_written_per_iteration = 16 >> conv_stride_x;
                        break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
                    case DataType::F16:
                        num_weight_elems_read_per_row   = 8 + kernel_size - 1;
                        num_elems_read_per_iteration    = 24;
                        num_elems_written_per_iteration = 32 >> conv_stride_x;
                        break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported.");
                        break;
                }
                break;
            case 5:
            {
                switch(input->data_type())
                {
                    case DataType::F32:
                        num_weight_elems_read_per_row   = 4 + kernel_size - 1;
                        num_elems_read_per_iteration    = 12;
                        num_elems_written_per_iteration = 16 >> conv_stride_x;
                        break;
                    default:
                        ARM_COMPUTE_ERROR("Data type not supported.");
                        break;
                }
            }
            break;
            default:
            {
                ARM_COMPUTE_ERROR("Not implemented");
                break;
            }
        }

        // Calculate right pad: how far beyond the input the vectorised reads can reach.
        int start_x       = kernel_size / 2 - static_cast<int>(conv_info.pad_left());
        int end_x         = ceil_to_multiple(static_cast<int>(output->dimension(0)), num_elems_written_per_iteration) * conv_stride_x;
        int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width;

        // Calculate border
        const unsigned int conv_pad_left   = conv_info.pad_left();
        const unsigned int conv_pad_top    = conv_info.pad_top();
        const unsigned int conv_pad_right  = std::max(upper_bound_w, 0);
        const unsigned int conv_pad_bottom = conv_info.pad_bottom();

        border_size.left   = conv_pad_left;
        border_size.top    = conv_pad_top;
        border_size.right  = conv_pad_right;
        border_size.bottom = conv_pad_bottom;

        // Configure window
        win = calculate_max_window(*output, Steps(num_elems_written_per_iteration));

        AccessWindowRectangle input_access(input, -conv_pad_left, -conv_pad_top,
                                           num_elems_read_per_iteration, kernel_size,
                                           conv_stride_x, conv_stride_y);
        AccessWindowStatic     weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size);
        AccessWindowHorizontal output_access(output, 0, num_elems_written_per_iteration);
        window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
        output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
    }
    else
    {
        // NHWC path. NOTE(review): top/bottom are filled from pad_left()/pad_right() —
        // presumably intentional because of how dimensions map in this layout; confirm
        // against the NHWC kernel implementation before changing.
        border_size.left   = 0;
        border_size.top    = conv_info.pad_left();
        border_size.right  = 0;
        border_size.bottom = conv_info.pad_right();

        num_elems_read_per_iteration = 16 / element_size_from_data_type(input->data_type());

        win = calculate_max_window(*output, Steps());

        AccessWindowRectangle input_access(input, 0, -border_size.top, num_elems_read_per_iteration, kernel_size, 1.f, conv_stride_x);
        AccessWindowRectangle weights_access(weights, 0, 0, num_elems_read_per_iteration, kernel_size);
        window_changed = update_window_and_padding(win, input_access, weights_access);
    }

    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
    return std::make_pair(err, win);
}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001142} // namespace
1143
1144NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
Georgios Pinitas898a8062017-09-12 19:19:12 +01001145 : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0),
1146 _num_elems_written_per_iteration(0)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001147{
1148}
1149
// Border computed during configure() (and refined by validate_and_configure_window).
BorderSize NEDirectConvolutionLayerKernel::border_size() const
{
    return _border_size;
}
1154
/** Configure the kernel: record the tensors, auto-initialize the output shape
 *  if empty, validate the arguments and set up the execution window.
 *
 * Order matters here: the output must be auto-initialized before
 * validate_arguments() checks it, and _border_size is first seeded from
 * conv_info then refined by validate_and_configure_window().
 */
void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);

    _input       = input;
    _weights     = weights;
    _output      = output;
    _conv_info   = conv_info;
    _kernel_size = weights->info()->dimension(get_data_layout_dimension_index(weights->info()->data_layout(), DataLayoutDimension::WIDTH));

    const unsigned int conv_pad_left   = conv_info.pad_left();
    const unsigned int conv_pad_top    = conv_info.pad_top();
    const unsigned int conv_pad_right  = conv_info.pad_right();
    const unsigned int conv_pad_bottom = conv_info.pad_bottom();
    _border_size                       = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);

    // Get convolved dimensions
    TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input->info(), *weights->info(), conv_info);

    DataType data_type = input->info()->data_type();

    // Output auto initialization if not yet initialized
    auto_init_if_empty(*output->info(), output_shape, 1, data_type);

    // Perform validation step
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), weights->info(), output->info(), conv_info));

    // Configure kernel window (also updates the per-iteration element counts and _border_size)
    auto win_config = validate_and_configure_window(input->info(), weights->info(), output->info(), conv_info, _num_weight_elems_read_per_row,
                                                    _num_elems_read_per_iteration, _num_elems_written_per_iteration, _border_size);
    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
    INEKernel::configure(win_config.second);
}
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001188
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001189Status NEDirectConvolutionLayerKernel::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
1190{
1191 unsigned int num_weight_elems_read_per_row = 0;
1192 unsigned int num_elems_read_per_iteration = 0;
1193 unsigned int num_elems_written_per_iteration = 0;
Georgios Pinitas15997872018-02-19 13:58:22 +00001194 BorderSize border_size = {};
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001195 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, weights, output, conv_info));
Georgios Pinitas0223a782017-12-12 11:44:44 +00001196 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
1197 weights->clone().get(),
1198 output->clone().get(),
1199 conv_info,
1200 num_weight_elems_read_per_row,
1201 num_elems_read_per_iteration,
1202 num_elems_written_per_iteration,
1203 border_size)
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001204 .first);
Georgios Pinitas898a8062017-09-12 19:19:12 +01001205
Michalis Spyrouafa5d812017-11-30 14:25:57 +00001206 return Status{};
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001207}
1208
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001209void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo &info)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001210{
Moritz Pflanzerc186b572017-09-07 09:48:04 +01001211 ARM_COMPUTE_UNUSED(info);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001212 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1213 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1214 ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
1215
Giorgio Arenac0f54432018-03-16 14:02:34 +00001216 const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_weights->info()->data_layout(), DataLayoutDimension::WIDTH));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001217
Giorgio Arenac0f54432018-03-16 14:02:34 +00001218 if(_input->info()->data_layout() == DataLayout::NCHW)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001219 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001220 switch(kernel_size)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001221 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001222 case 1:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001223 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001224 switch(_input->info()->data_type())
1225 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001226 case DataType::F32:
1227 convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1228 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001229#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +00001230 case DataType::F16:
1231 convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1232 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001233#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +00001234 default:
1235 ARM_COMPUTE_ERROR("Data type not supported");
1236 break;
1237 }
1238 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001239 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001240 case 3:
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001241 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001242 switch(_input->info()->data_type())
1243 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001244 case DataType::F32:
1245 convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1246 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001247#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
Giorgio Arenac0f54432018-03-16 14:02:34 +00001248 case DataType::F16:
1249 convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1250 break;
Ioan-Cristian Szabo5edbd1c2017-11-13 13:34:08 +00001251#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
Giorgio Arenac0f54432018-03-16 14:02:34 +00001252 default:
1253 ARM_COMPUTE_ERROR("Data type not supported");
1254 break;
1255 }
1256 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001257 }
Giorgio Arenac0f54432018-03-16 14:02:34 +00001258 case 5:
Pablo Tello06da39d2017-08-10 15:10:40 +01001259 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001260 switch(_input->info()->data_type())
1261 {
1262 case DataType::F32:
1263 convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
1264 break;
1265 default:
1266 ARM_COMPUTE_ERROR("Data type not supported");
1267 break;
1268 }
1269 break;
Pablo Tello06da39d2017-08-10 15:10:40 +01001270 }
Pablo Tello06da39d2017-08-10 15:10:40 +01001271
Giorgio Arenac0f54432018-03-16 14:02:34 +00001272 default:
1273 {
1274 ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
1275 break;
1276 }
1277 }
1278 }
1279 else
1280 {
1281 switch(_input->info()->data_type())
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001282 {
Giorgio Arenac0f54432018-03-16 14:02:34 +00001283 case DataType::F32:
1284 convolver_nhwc<float>::convolve(window, kernel_size, _num_elems_read_per_iteration, _input, _weights, _output, _conv_info);
1285 break;
1286 default:
1287 ARM_COMPUTE_ERROR("Data type not supported");
1288 break;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001289 }
1290 }
1291}