/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"

#include <algorithm>
#include <arm_neon.h>

using namespace arm_compute;

namespace
{
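// The internal_* helpers below overload the vector load/store/multiply/multiply-accumulate operations for each
// supported element type (QS8, QS16, F16, F32), so that the convolver class templates further down can be written
// once for all data types. The internal_vld1q<stridex> specializations use de-interleaving loads (vld2/vld3) and
// keep only val[0], which yields a vector of input elements sampled every 2nd or 3rd position for strides 2 and 3.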
template <unsigned int stridex>
qint16x8_t internal_vld1q(const qint16_t *in);

template <>
qint16x8_t internal_vld1q<1>(const qint16_t *in)
{
    return vld1q_qs16(in);
}

template <>
qint16x8_t internal_vld1q<2>(const qint16_t *in)
{
    const int16x8x2_t tmp = vld2q_s16(in);
    return tmp.val[0];
}

template <>
qint16x8_t internal_vld1q<3>(const qint16_t *in)
{
    const int16x8x3_t tmp = vld3q_s16(in);
    return tmp.val[0];
}

inline qint16x8_t internal_vdupq_n(qint16_t v)
{
    return vdupq_n_qs16(v);
}

#ifdef ARM_COMPUTE_ENABLE_FP16
template <unsigned int stridex>
float16x8_t internal_vld1q(const float16_t *in);

template <>
float16x8_t internal_vld1q<1>(const float16_t *in)
{
    return vld1q_f16(in);
}

template <>
float16x8_t internal_vld1q<2>(const float16_t *in)
{
    const float16x8x2_t tmp = vld2q_f16(in);
    return tmp.val[0];
}

template <>
float16x8_t internal_vld1q<3>(const float16_t *in)
{
    const float16x8x3_t tmp = vld3q_f16(in);
    return tmp.val[0];
}

inline float16x8_t internal_vdupq_n(float16_t v)
{
    return vdupq_n_f16(v);
}

inline void internal_vst1q(float16_t *p, const float16x8_t &v)
{
    vst1q_f16(p, v);
}

float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f16(x, y);
}

inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vaddq_f16(x, vmulq_f16(y, z));
}
#endif /* ARM_COMPUTE_ENABLE_FP16 */

template <unsigned int stridex>
float32x4_t internal_vld1q(const float *in);

template <>
float32x4_t internal_vld1q<1>(const float *in)
{
    return vld1q_f32(in);
}

template <>
float32x4_t internal_vld1q<2>(const float *in)
{
    const float32x4x2_t tmp = vld2q_f32(in);
    return tmp.val[0];
}

template <>
float32x4_t internal_vld1q<3>(const float *in)
{
    const float32x4x3_t tmp = vld3q_f32(in);
    return tmp.val[0];
}

inline float32x4_t internal_vdupq_n(float v)
{
    return vdupq_n_f32(v);
}

inline void internal_vst1q(float *p, const float32x4_t &v)
{
    vst1q_f32(p, v);
}

float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f32(x, y);
}

inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmlaq_f32(x, y, z);
}

template <unsigned int stridex>
qint8x8_t internal_vld1q(const qint8_t *in);

template <>
qint8x8_t internal_vld1q<1>(const qint8_t *in)
{
    return vld1_qs8(in);
}

template <>
qint8x8_t internal_vld1q<2>(const qint8_t *in)
{
    const qint8x8x2_t tmp = vld2_s8(in);
    return tmp.val[0];
}

template <>
qint8x8_t internal_vld1q<3>(const qint8_t *in)
{
    const qint8x8x3_t tmp = vld3_s8(in);
    return tmp.val[0];
}

inline qint8x8_t internal_vdupq_n(qint8_t v)
{
    return vdup_n_qs8(v);
}

inline qint16x8_t internal_vmull(const qint8x8_t &x, const qint8x8_t &y, int fixed_point_position)
{
    return vmull_qs8(x, y, fixed_point_position);
}

inline qint16x8_t internal_vmlal(const qint16x8_t &x, const qint8x8_t &y, const qint8x8_t &z, int fixed_point_position)
{
    return vqmlal_qs8(x, y, z, fixed_point_position);
}

inline void internal_vst1q(qint16_t *p, const qint16x8_t &v)
{
    vst1q_qs16(p, v);
}

inline void internal_vst1q(int *p, const qint32x4x2_t &v)
{
    vst1q_s32(p, v.val[0]);
    vst1q_s32(p + 4, v.val[1]);
}

template <unsigned int stridex>
qint32x4x2_t internal_vld1q(const qint32_t *in);

template <>
qint32x4x2_t internal_vld1q<1>(const qint32_t *in)
{
    const qint32x4x2_t r =
    {
        {
            vld1q_s32(in),
            vld1q_s32(in + 4)
        }
    };
    return r;
}

inline qint32x4x2_t internal_vmull(const qint16x8_t &x, const qint16x8_t &y, int fixed_point_position)
{
    const qint32x4x2_t r =
    {
        {
            vmull_qs16(vget_low_s16(x), vget_low_s16(y), fixed_point_position),
            vmull_qs16(vget_high_s16(x), vget_high_s16(y), fixed_point_position),
        }
    };
    return r;
}

inline qint32x4x2_t internal_vmlal(const qint32x4x2_t &x, const qint16x8_t &y, const qint16x8_t &z, int fixed_point_position)
{
    const qint32x4x2_t r =
    {
        {
            vqmlal_qs16(x.val[0], vget_low_s16(y), vget_low_s16(z), fixed_point_position),
            vqmlal_qs16(x.val[1], vget_high_s16(y), vget_high_s16(z), fixed_point_position)
        }
    };
    return r;
}

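// Direct convolution for 1x1 kernels: every kernel element scales a whole input plane, so each output plane is
// built by multiplying input plane 0 by its weight (Step 1) and then accumulating the contributions of the
// remaining kernel_depth - 1 planes on top of it (Step 2).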
template <typename T1, typename T2, unsigned int stridex>
class convolver_1x1
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        const int input_stride_y  = input->info()->strides_in_bytes().y();
        const int input_stride_z  = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w        = output->info()->dimension(0);
        const int output_h        = output->info()->dimension(1);
        const int range_z         = window.z().end() - window.z().start();
        const int kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const int fixed_point_position   = input->info()->fixed_point_position();

        // Setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // Setup input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window   window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            /*
                For a detailed explanation on how the algorithm works refer to template <> class convolver_3x3<1>
            */
            const uint8_t *input_ptr = in.ptr();
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            for(int oz = 0; oz < range_z; ++oz)
            {
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk    = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto      in_val    = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
                        auto      p_out     = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk    = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto      in_val    = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
                        auto      p_out     = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }
            }
        },
        in, out);
    }
};

#ifdef ARM_COMPUTE_ENABLE_FP16
inline float16x8x3_t load_matrix_row(const float16_t *ptr)
{
    /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    const float16x8x3_t r =
    {
        {
            vld1q_dup_f16(ptr),
            vld1q_dup_f16(1 + ptr),
            vld1q_dup_f16(2 + ptr)
        }
    };
    return r;
}

template <unsigned int stridex>
float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                           int fixed_point_position);

template <>
float16x8x2_t convolve_3x3<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                              int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);

    const float16x8x3_t vtop =
    {
        {
            vld1q_f16(in_top),
            vld1q_f16(in_top + 8),
            vld1q_f16(in_top + 16)
        }
    };
    const float16x8x3_t vmid =
    {
        {
            vld1q_f16(in_mid),
            vld1q_f16(in_mid + 8),
            vld1q_f16(in_mid + 16)
        }
    };
    const float16x8x3_t vlow =
    {
        {
            vld1q_f16(in_low),
            vld1q_f16(in_low + 8),
            vld1q_f16(in_low + 16)
        }
    };
    float16x8x2_t out =
    {
        {
            vmulq_f16(vtop.val[0], m0.val[0]),
            vmulq_f16(vtop.val[1], m0.val[0])
        }
    };
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));
    return out;
}

template <>
inline float16x8x2_t convolve_3x3<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                     int fixed_point_position)
{
    float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
    out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 2);
    out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float16x8x2_t convolve_3x3<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                     int fixed_point_position)
{
    float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0]        = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
    return out;
}

template <unsigned int stridex>
void store_results(float16_t *buffer, const float16x8x2_t &values);

template <>
void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
    vst1q_f16(buffer + 8, values.val[1]);
}

template <>
void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
}

template <>
void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vget_low_f16(values.val[0]));
}

template <unsigned int stridex>
void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

template <>
void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
    vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
}

#endif /* ARM_COMPUTE_ENABLE_FP16 */

inline float32x4x3_t load_matrix_row(const float *ptr)
{
    const float32x4x3_t r =
    {
        {
            vld1q_dup_f32(ptr),
            vld1q_dup_f32(1 + ptr),
            vld1q_dup_f32(2 + ptr)
        }
    };
    return r;
}
inline qint8x8x3_t load_matrix_row(const qint8_t *ptr)
{
    /* ptr is a pointer to a row in a 3x3 matrix, the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    const qint8x8x3_t r =
    {
        {
            vld1_dup_qs8(ptr),
            vld1_dup_qs8(1 + ptr),
            vld1_dup_qs8(2 + ptr)
        }
    };
    return r;
}

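// convolve_3x3 computes the 3x3 convolution of up to 8 consecutive output pixels for one input plane: each input
// row is loaded as three consecutive vectors and the shifted views needed for the second and third kernel columns
// are formed with vext, so the nine multiply-accumulates per output come from register permutations rather than
// additional unaligned loads.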
template <unsigned int stridex>
float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position);

template <>
inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);

    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + 4),
            vld1q_f32(in_top + 8)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + 4),
            vld1q_f32(in_mid + 8)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + 4),
            vld1q_f32(in_low + 8)
        }
    };
    float32x4x2_t out =
    {
        {
            vmulq_f32(vtop.val[0], m0.val[0]),
            vmulq_f32(vtop.val[1], m0.val[0])
        }
    };
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
    return out;
}

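// For strides 2 and 3 the results are computed as for stride 1 and the wanted lanes (every 2nd or 3rd output) are
// then compacted into the low lanes of val[0], which is all that store_results<2>/<3> write out.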
template <>
inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0]        = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}

template <unsigned int stridex>
qint16x8x2_t convolve_3x3(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position);

template <>
inline qint16x8x2_t convolve_3x3<1>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);

    const qint8x8x3_t vtop =
    {
        {
            vld1_qs8(in_top),
            vld1_qs8(in_top + 8),
            vld1_qs8(in_top + 16)
        }
    };
    const qint8x8x3_t vmid =
    {
        {
            vld1_qs8(in_mid),
            vld1_qs8(in_mid + 8),
            vld1_qs8(in_mid + 16)
        }
    };
    const qint8x8x3_t vlow =
    {
        {
            vld1_qs8(in_low),
            vld1_qs8(in_low + 8),
            vld1_qs8(in_low + 16)
        }
    };
    qint16x8x2_t out =
    {
        {
            vmull_qs8(vtop.val[0], m0.val[0], fixed_point_position),
            vmull_qs8(vtop.val[1], m0.val[0], fixed_point_position)
        }
    };
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vtop.val[0], vtop.val[1], 1), m0.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vtop.val[0], vtop.val[1], 2), m0.val[2], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vmid.val[0], m1.val[0], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vmid.val[0], vmid.val[1], 1), m1.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vmid.val[0], vmid.val[1], 2), m1.val[2], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vlow.val[0], m2.val[0], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vlow.val[0], vlow.val[1], 1), m2.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vlow.val[0], vlow.val[1], 2), m2.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vtop.val[1], vtop.val[2], 1), m0.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vtop.val[1], vtop.val[2], 2), m0.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vmid.val[1], m1.val[0], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vmid.val[1], vmid.val[2], 1), m1.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vmid.val[1], vmid.val[2], 2), m1.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vlow.val[1], m2.val[0], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vlow.val[1], vlow.val[2], 1), m2.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vlow.val[1], vlow.val[2], 2), m2.val[2], fixed_point_position);
    return out;
}

template <>
inline qint16x8x2_t convolve_3x3<2>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    qint16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0]       = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 2), out.val[0], 1);
    out.val[0]       = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 4), out.val[0], 2);
    out.val[0]       = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 6), out.val[0], 3);
    out.val[0]       = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 0), out.val[0], 4);
    out.val[0]       = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 2), out.val[0], 5);
    out.val[0]       = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 4), out.val[0], 6);
    out.val[0]       = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 6), out.val[0], 7);
    return out;
}

template <>
inline qint16x8x2_t convolve_3x3<3>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    qint16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0]       = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 3), out.val[0], 1);
    out.val[0]       = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 6), out.val[0], 2);
    out.val[0]       = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 1), out.val[0], 3);
    return out;
}

template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);

template <>
void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
    vst1q_f32(buffer + 4, values.val[1]);
}

template <>
void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
}

template <>
void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vget_low_f32(values.val[0]));
}

template <unsigned int stridex>
void store_results(qint16_t *buffer, const qint16x8x2_t &values);

template <>
void store_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, values.val[0]);
    vst1q_qs16(buffer + 8, values.val[1]);
}

template <>
void store_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, values.val[0]);
}

template <>
void store_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1_qs16(buffer, vget_low_s16(values.val[0]));
}

template <unsigned int stridex>
void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
    vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
}

template <>
void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}

template <unsigned int stridex>
void accumulate_results(qint16_t *buffer, const qint16x8x2_t &values);

template <>
void accumulate_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
    vst1q_qs16(buffer + 8, vqaddq_qs16(vld1q_qs16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1_qs16(buffer, vqadd_qs16(vld1_qs16(buffer), vget_low_s16(values.val[0])));
}

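// Number of input elements the row pointers must advance by per iteration: stridex input columns are consumed for
// every output element written.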
template <unsigned int stridex>
int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration);

template <>
int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration;
}

template <>
int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration << 1;
}

template <>
int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration * 3;
}

template <typename T1, typename T2, unsigned int stridex>
class convolver_3x3
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        const int input_stride_x  = input->info()->strides_in_bytes().x();
        const int input_stride_y  = input->info()->strides_in_bytes().y();
        const int input_stride_z  = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w        = output->info()->dimension(0);
        const int output_h        = output->info()->dimension(1);
        const int num_planes_z    = window.z().end() - window.z().start();
        const int delta_input     = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int kernel_depth    = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_x    = std::get<0>(conv_info.pad());
        const unsigned int conv_pad_y    = std::get<1>(conv_info.pad());
        const int fixed_point_position   = input->info()->fixed_point_position();

        // Setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // Setup input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_x * input_stride_x - conv_pad_y * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            /*
                Each thread executing this kernel computes one or more planes of the output volume.

                Let's say the 3rd dimension of the output volume is 32, the first thread will compute the output for Z = [0,7], the second thread will compute the output for Z = [8,15],
                the third thread [16,23] and the fourth thread [24,31].

                The algorithm outer loop iterates over Z, P, Y, X where P is the depth/3rd dimension of each kernel. This order is not arbitrary, the main benefit of it
                is that we set up the NEON registers containing the kernel's values only once and then compute each XY using the preloaded registers, as opposed to doing this for every XY value.

                The algorithm does not require allocating any additional memory and computes the results directly in-place in two stages:
                    1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
                    2) Convolve the remaining planes and accumulate the results in the output's plane which has been initialized in step 1.
            */

            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset    = id.z() + oz;
                uint8_t *p_out_base  = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto vk_r0    = load_matrix_row(ptr_k_r0);
                    const auto vk_r1    = load_matrix_row(ptr_k_r1);
                    const auto vk_r2    = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto vk_r0    = load_matrix_row(ptr_k_r0);
                    const auto vk_r1    = load_matrix_row(ptr_k_r1);
                    const auto vk_r2    = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};

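// Free-function dispatchers: select the statically specialized convolver from the runtime stride along x.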
template <typename T1, typename T2>
inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <typename T1, typename T2>
inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}
} // namespace

NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
    : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_elems_read_per_iteration(0), _num_elems_written_per_iteration(0)
{
}

BorderSize NEDirectConvolutionLayerKernel::border_size() const
{
    return _border_size;
}

void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::QS16, DataType::F32);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QS8, DataType::F16, DataType::QS16, DataType::F32);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS16, DataType::F16, DataType::QS32, DataType::F32);
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 1 && (std::get<0>(conv_info.pad()) || std::get<1>(conv_info.pad())),
                             "Pad > 0 not supported for 1x1 weights");
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 3 && (std::get<0>(conv_info.pad()) > 1 || std::get<1>(conv_info.pad()) > 1),
                             "Pad > 1 not supported for 3x3 weights");
    ARM_COMPUTE_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");

    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    const unsigned int conv_pad_x    = std::get<0>(conv_info.pad());
    const unsigned int conv_pad_y    = std::get<1>(conv_info.pad());

    _input       = input;
    _weights     = weights;
    _output      = output;
    _conv_info   = conv_info;
    _kernel_size = weights->info()->dimension(0);
    _border_size = BorderSize(conv_pad_y, conv_pad_x);

    Window win = calculate_max_window(*output->info());

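    // The number of elements written (and therefore read) per iteration depends on the kernel size, the data type
    // and the stride; it defines the window step and the padding requested from the surrounding tensors.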
    switch(_kernel_size)
    {
        case 1:
        {
            switch(input->info()->data_type())
            {
#ifdef ARM_COMPUTE_ENABLE_FP16
                case DataType::F16:
#endif /* ARM_COMPUTE_ENABLE_FP16 */
                case DataType::QS8:
                case DataType::QS16:
                    _num_elems_written_per_iteration = 8;
                    break;
                case DataType::F32:
                    _num_elems_written_per_iteration = 4;
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported.");
                    break;
            }

            _num_elems_read_per_iteration = conv_stride_x * _num_elems_written_per_iteration;
            win                           = calculate_max_window(*output->info(), Steps(_num_elems_written_per_iteration));
            AccessWindowHorizontal input_access(input->info(), 0, _num_elems_read_per_iteration);
            AccessWindowHorizontal output_access(output->info(), 0, _num_elems_written_per_iteration);
            update_window_and_padding(win, input_access, output_access);
            output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
            break;
        }
        case 3:
        {
            if(input->info()->data_type() == DataType::F32)
            {
                _num_elems_read_per_iteration    = 12;
                _num_elems_written_per_iteration = 16 >> conv_stride_x;
            }
            else
            {
                _num_elems_read_per_iteration    = 24;
                _num_elems_written_per_iteration = 32 >> conv_stride_x;
            }

            // Calculate right and bottom border
            const unsigned int conv_stride_y = std::get<1>(_conv_info.stride());
            const int          input_width   = input->info()->dimension(0);
            const int          input_height  = input->info()->dimension(1);
            const int          upper_bound_w = ceil_to_multiple(((output->info()->dimension(0) - 1) * conv_stride_x + _kernel_size), _num_elems_read_per_iteration) - conv_pad_x - input_width;
            const int          upper_bound_h = ((output->info()->dimension(1) - 1) * conv_stride_y - conv_pad_y + _kernel_size) - input_height;
            _border_size.right               = std::max(upper_bound_w, static_cast<int>(_kernel_size));
            _border_size.bottom              = std::max(upper_bound_h, static_cast<int>(_kernel_size));

            // Create window and update padding
            win = calculate_max_window(*output->info(), Steps(_num_elems_written_per_iteration));
            AccessWindowStatic     input_access(input->info(), -conv_pad_x, -conv_pad_y, input_width + _border_size.right, input_height + _border_size.bottom);
            AccessWindowStatic     weights_access(weights->info(), 0, 0, _kernel_size, _kernel_size);
            AccessWindowHorizontal output_access(output->info(), 0, _num_elems_written_per_iteration);
            update_window_and_padding(win, input_access, weights_access, output_access);
            output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("Not implemented");
            break;
        }
    }

    INEKernel::configure(win);
}

void NEDirectConvolutionLayerKernel::run(const Window &window)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

    const int kernel_size = _weights->info()->dimension(0);

    switch(kernel_size)
    {
        case 1:
        {
            switch(_input->info()->data_type())
            {
                case DataType::QS8:
                    convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::QS16:
                    convolve_1x1<qint16_t, qint32_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::F32:
                    convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#ifdef ARM_COMPUTE_ENABLE_FP16
                case DataType::F16:
                    convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#endif /* ARM_COMPUTE_ENABLE_FP16 */
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        case 3:
        {
            switch(_input->info()->data_type())
            {
                case DataType::QS8:
                    convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::F32:
                    convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#ifdef ARM_COMPUTE_ENABLE_FP16
                case DataType::F16:
                    convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#endif /* ARM_COMPUTE_ENABLE_FP16 */
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("Only kernel sizes 1x1 and 3x3 are supported.");
            break;
        }
    }
}