/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"

#include <algorithm>
#include <arm_neon.h>

using namespace arm_compute;

namespace
{
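// The internal_* helpers below wrap the NEON load/duplicate/multiply/store
// intrinsics for each supported element type (QS8, QS16, F16, F32) so that the
// convolver class templates further down can be written once, generically over
// the data type. For stridex == 2 and stridex == 3 the loads use the
// interleaving variants (vld2q/vld3q and friends) and keep only val[0], i.e.
// every 2nd or 3rd element, which is what implements the horizontal stride.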
template <unsigned int stridex>
qint16x8_t internal_vld1q(const qint16_t *in);

template <>
qint16x8_t internal_vld1q<1>(const qint16_t *in)
{
    return vld1q_qs16(in);
}

template <>
qint16x8_t internal_vld1q<2>(const qint16_t *in)
{
    const int16x8x2_t tmp = vld2q_s16(in);
    return tmp.val[0];
}

template <>
qint16x8_t internal_vld1q<3>(const qint16_t *in)
{
    const int16x8x3_t tmp = vld3q_s16(in);
    return tmp.val[0];
}

inline qint16x8_t internal_vdupq_n(qint16_t v)
{
    return vdupq_n_qs16(v);
}

#ifdef ARM_COMPUTE_ENABLE_FP16
template <unsigned int stridex>
float16x8_t internal_vld1q(const float16_t *in);

template <>
float16x8_t internal_vld1q<1>(const float16_t *in)
{
    return vld1q_f16(in);
}

template <>
float16x8_t internal_vld1q<2>(const float16_t *in)
{
    const float16x8x2_t tmp = vld2q_f16(in);
    return tmp.val[0];
}

template <>
float16x8_t internal_vld1q<3>(const float16_t *in)
{
    const float16x8x3_t tmp = vld3q_f16(in);
    return tmp.val[0];
}

inline float16x8_t internal_vdupq_n(float16_t v)
{
    return vdupq_n_f16(v);
}

inline void internal_vst1q(float16_t *p, const float16x8_t &v)
{
    vst1q_f16(p, v);
}

float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f16(x, y);
}

inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vaddq_f16(x, vmulq_f16(y, z));
}
#endif /* ARM_COMPUTE_ENABLE_FP16 */

template <unsigned int stridex>
float32x4_t internal_vld1q(const float *in);

template <>
float32x4_t internal_vld1q<1>(const float *in)
{
    return vld1q_f32(in);
}

template <>
float32x4_t internal_vld1q<2>(const float *in)
{
    const float32x4x2_t tmp = vld2q_f32(in);
    return tmp.val[0];
}

template <>
float32x4_t internal_vld1q<3>(const float *in)
{
    const float32x4x3_t tmp = vld3q_f32(in);
    return tmp.val[0];
}

inline float32x4_t internal_vdupq_n(float v)
{
    return vdupq_n_f32(v);
}

inline void internal_vst1q(float *p, const float32x4_t &v)
{
    vst1q_f32(p, v);
}

float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f32(x, y);
}

inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmlaq_f32(x, y, z);
}

template <unsigned int stridex>
qint8x8_t internal_vld1q(const qint8_t *in);

template <>
qint8x8_t internal_vld1q<1>(const qint8_t *in)
{
    return vld1_qs8(in);
}

template <>
qint8x8_t internal_vld1q<2>(const qint8_t *in)
{
    const qint8x8x2_t tmp = vld2_s8(in);
    return tmp.val[0];
}

template <>
qint8x8_t internal_vld1q<3>(const qint8_t *in)
{
    const qint8x8x3_t tmp = vld3_s8(in);
    return tmp.val[0];
}

inline qint8x8_t internal_vdupq_n(qint8_t v)
{
    return vdup_n_qs8(v);
}

inline qint16x8_t internal_vmull(const qint8x8_t &x, const qint8x8_t &y, int fixed_point_position)
{
    return vmull_qs8(x, y, fixed_point_position);
}

inline qint16x8_t internal_vmlal(const qint16x8_t &x, const qint8x8_t &y, const qint8x8_t &z, int fixed_point_position)
{
    return vqmlal_qs8(x, y, z, fixed_point_position);
}

inline void internal_vst1q(qint16_t *p, const qint16x8_t &v)
{
    vst1q_qs16(p, v);
}

inline void internal_vst1q(int *p, const qint32x4x2_t &v)
{
    vst1q_s32(p, v.val[0]);
    vst1q_s32(p + 4, v.val[1]);
}

template <unsigned int stridex>
qint32x4x2_t internal_vld1q(const qint32_t *in);

template <>
qint32x4x2_t internal_vld1q<1>(const qint32_t *in)
{
    const qint32x4x2_t r =
    {
        {
            vld1q_s32(in),
            vld1q_s32(in + 4)
        }
    };
    return r;
}

inline qint32x4x2_t internal_vmull(const qint16x8_t &x, const qint16x8_t &y, int fixed_point_position)
{
    const qint32x4x2_t r =
    {
        {
            vmull_qs16(vget_low_s16(x), vget_low_s16(y), fixed_point_position),
            vmull_qs16(vget_high_s16(x), vget_high_s16(y), fixed_point_position),
        }
    };
    return r;
}

inline qint32x4x2_t internal_vmlal(const qint32x4x2_t &x, const qint16x8_t &y, const qint16x8_t &z, int fixed_point_position)
{
    const qint32x4x2_t r =
    {
        {
            vqmlal_qs16(x.val[0], vget_low_s16(y), vget_low_s16(z), fixed_point_position),
            vqmlal_qs16(x.val[1], vget_high_s16(y), vget_high_s16(z), fixed_point_position)
        }
    };
    return r;
}

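// convolver_1x1: with a 1x1xD kernel the convolution reduces to a scalar
// multiply-accumulate across the input depth. Step 1 initializes each output
// plane with the contribution of input plane 0; Step 2 accumulates the
// remaining kernel_depth - 1 planes on top of it. For the fixed-point types
// the accumulator is wider than the input (QS8 -> QS16, QS16 -> QS32), using
// the saturating vqmlal_* based helpers above.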
template <typename T1, typename T2, unsigned int stridex>
class convolver_1x1
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int range_z = window.z().end() - window.z().start();
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const int fixed_point_position = input->info()->fixed_point_position();

        // Set up the output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // Set up the input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (> 3), so we set the first three dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            /*
                For a detailed explanation of how the algorithm works, refer to class convolver_3x3 below
            */
            const uint8_t *input_ptr = in.ptr();
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            for(int oz = 0; oz < range_z; ++oz)
            {
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }
            }
        },
        in, out);
    }
};

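// FP16 counterparts of the 3x3 row-load / convolve / store / accumulate
// helpers used by the convolver templates; the F32 and QS8 versions follow
// after the #endif.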
#ifdef ARM_COMPUTE_ENABLE_FP16
inline float16x8x3_t load_matrix_row(const float16_t *ptr)
{
    /* ptr is a pointer to a row in a 3x3 matrix; the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    const float16x8x3_t r =
    {
        {
            vld1q_dup_f16(ptr),
            vld1q_dup_f16(1 + ptr),
            vld1q_dup_f16(2 + ptr)
        }
    };
    return r;
}

template <unsigned int stridex>
float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                           int fixed_point_position);

template <>
float16x8x2_t convolve_3x3<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                              int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);

    const float16x8x3_t vtop =
    {
        {
            vld1q_f16(in_top),
            vld1q_f16(in_top + 8),
            vld1q_f16(in_top + 16)
        }
    };
    const float16x8x3_t vmid =
    {
        {
            vld1q_f16(in_mid),
            vld1q_f16(in_mid + 8),
            vld1q_f16(in_mid + 16)
        }
    };
    const float16x8x3_t vlow =
    {
        {
            vld1q_f16(in_low),
            vld1q_f16(in_low + 8),
            vld1q_f16(in_low + 16)
        }
    };
    float16x8x2_t out =
    {
        {
            vmulq_f16(vtop.val[0], m0.val[0]),
            vmulq_f16(vtop.val[1], m0.val[0])
        }
    };
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));
    return out;
}

template <>
inline float16x8x2_t convolve_3x3<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                     int fixed_point_position)
{
    float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float16x8x2_t convolve_3x3<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                     int fixed_point_position)
{
    float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
    return out;
}

template <unsigned int stridex>
void store_results(float16_t *buffer, const float16x8x2_t &values);

template <>
void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
    vst1q_f16(buffer + 8, values.val[1]);
}

template <>
void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
}

template <>
void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vget_low_f16(values.val[0]));
}

template <unsigned int stridex>
void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

template <>
void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
    vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
}

#endif /* ARM_COMPUTE_ENABLE_FP16 */

inline float32x4x3_t load_matrix_row(const float *ptr)
{
    const float32x4x3_t r =
    {
        {
            vld1q_dup_f32(ptr),
            vld1q_dup_f32(1 + ptr),
            vld1q_dup_f32(2 + ptr)
        }
    };
    return r;
}

inline qint8x8x3_t load_matrix_row(const qint8_t *ptr)
{
    /* ptr is a pointer to a row in a 3x3 matrix; the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes) */
    const qint8x8x3_t r =
    {
        {
            vld1_dup_qs8(ptr),
            vld1_dup_qs8(1 + ptr),
            vld1_dup_qs8(2 + ptr)
        }
    };
    return r;
}

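// convolve_5x5 helpers: each 5-tap kernel row is split into a "hi" part of 3
// broadcast taps (load_matrix_hi) and a "lo" part of 2 broadcast taps
// (load_matrix_lo). load_input fetches 12 consecutive input values so that
// vextq_f32 can form the shifted windows needed for the 5 taps of each row.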
template <unsigned int stridex>
float32x4x2_t convolve_5x5(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                           const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position);

inline float32x4x3_t load_matrix_hi(const float *const m0, const float *const m1, const float *const m2)
{
    const float32x4x3_t m00 =
    {
        {
            vld1q_dup_f32(m0),
            vld1q_dup_f32(m1),
            vld1q_dup_f32(m2)
        }
    };
    return m00;
}

inline float32x4x2_t load_matrix_lo(const float *const m3, const float *const m4)
{
    const float32x4x2_t m00 =
    {
        {
            vld1q_dup_f32(m3),
            vld1q_dup_f32(m4)
        }
    };
    return m00;
}

inline float32x4x3_t load_input(const float *const in)
{
    const float32x4x3_t vin =
    {
        {
            vld1q_f32(in),
            vld1q_f32(in + 4),
            vld1q_f32(in + 8)
        }
    };
    return vin;
}

template <>
inline float32x4x2_t convolve_5x5<1>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    const float32x4x3_t vin0 = load_input(in_0);
    const float32x4x3_t vin1 = load_input(in_1);
    const float32x4x3_t vin2 = load_input(in_2);
    const float32x4x3_t vin3 = load_input(in_3);
    const float32x4x3_t vin4 = load_input(in_4);
    const float32x4x3_t m00 = load_matrix_hi(m0, 1 + m0, 2 + m0);
    const float32x4x2_t m01 = load_matrix_lo(3 + m0, 4 + m0);
    const float32x4x3_t m10 = load_matrix_hi(m1, 1 + m1, 2 + m1);
    const float32x4x2_t m11 = load_matrix_lo(3 + m1, 4 + m1);
    const float32x4x3_t m20 = load_matrix_hi(m2, 1 + m2, 2 + m2);
    const float32x4x2_t m21 = load_matrix_lo(3 + m2, 4 + m2);
    const float32x4x3_t m30 = load_matrix_hi(m3, 1 + m3, 2 + m3);
    const float32x4x2_t m31 = load_matrix_lo(3 + m3, 4 + m3);
    const float32x4x3_t m40 = load_matrix_hi(m4, 1 + m4, 2 + m4);
    const float32x4x2_t m41 = load_matrix_lo(3 + m4, 4 + m4);

    float32x4x2_t out =
    {
        {
            vmulq_f32(vin0.val[0], m00.val[0]),
            vmulq_f32(vin0.val[1], m00.val[0])
        }
    };

    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 1), m00.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 2), m00.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin0.val[0], vin0.val[1], 3), m01.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin0.val[1], m01.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin1.val[0], m10.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 1), m10.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 2), m10.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin1.val[0], vin1.val[1], 3), m11.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin1.val[1], m11.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin2.val[0], m20.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 1), m20.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 2), m20.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin2.val[0], vin2.val[1], 3), m21.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin2.val[1], m21.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin3.val[0], m30.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 1), m30.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 2), m30.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin3.val[0], vin3.val[1], 3), m31.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin3.val[1], m31.val[1]);

    out.val[0] = vmlaq_f32(out.val[0], vin4.val[0], m40.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 1), m40.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 2), m40.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vin4.val[0], vin4.val[1], 3), m41.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vin4.val[1], m41.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 1), m00.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 2), m00.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin0.val[1], vin0.val[2], 3), m01.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin0.val[2], m01.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin1.val[1], m10.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 1), m10.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 2), m10.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin1.val[1], vin1.val[2], 3), m11.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin1.val[2], m11.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin2.val[1], m20.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 1), m20.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 2), m20.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin2.val[1], vin2.val[2], 3), m21.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin2.val[2], m21.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin3.val[1], m30.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 1), m30.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 2), m30.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin3.val[1], vin3.val[2], 3), m31.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin3.val[2], m31.val[1]);

    out.val[1] = vmlaq_f32(out.val[1], vin4.val[1], m40.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 1), m40.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 2), m40.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vin4.val[1], vin4.val[2], 3), m41.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vin4.val[2], m41.val[1]);

    return out;
}

template <>
inline float32x4x2_t convolve_5x5<2>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float32x4x2_t convolve_5x5<3>(const float *in_0, const float *in_1, const float *in_2, const float *in_3, const float *in_4,
                                     const float *m0, const float *m1, const float *m2, const float *m3, const float *m4, int fixed_point_position)
{
    float32x4x2_t out = convolve_5x5<1>(in_0, in_1, in_2, in_3, in_4, m0, m1, m2, m3, m4, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}

template <unsigned int stridex>
float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position);

template <>
inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);

    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + 4),
            vld1q_f32(in_top + 8)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + 4),
            vld1q_f32(in_mid + 8)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + 4),
            vld1q_f32(in_low + 8)
        }
    };
    float32x4x2_t out =
    {
        {
            vmulq_f32(vtop.val[0], m0.val[0]),
            vmulq_f32(vtop.val[1], m0.val[0])
        }
    };
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);

    out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);

    out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);

    out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
    return out;
}

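// The strided specializations below first run the stride-1 kernel and then
// compact every 2nd (stridex == 2) or every 3rd (stridex == 3) stride-1 output
// into the low lanes of val[0]; store_results<stridex> later writes back only
// the lanes that are valid for the given stride.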
template <>
inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}

template <unsigned int stridex>
qint16x8x2_t convolve_3x3(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position);

template <>
inline qint16x8x2_t convolve_3x3<1>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);

    const qint8x8x3_t vtop =
    {
        {
            vld1_qs8(in_top),
            vld1_qs8(in_top + 8),
            vld1_qs8(in_top + 16)
        }
    };
    const qint8x8x3_t vmid =
    {
        {
            vld1_qs8(in_mid),
            vld1_qs8(in_mid + 8),
            vld1_qs8(in_mid + 16)
        }
    };
    const qint8x8x3_t vlow =
    {
        {
            vld1_qs8(in_low),
            vld1_qs8(in_low + 8),
            vld1_qs8(in_low + 16)
        }
    };
    qint16x8x2_t out =
    {
        {
            vmull_qs8(vtop.val[0], m0.val[0], fixed_point_position),
            vmull_qs8(vtop.val[1], m0.val[0], fixed_point_position)
        }
    };
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vtop.val[0], vtop.val[1], 1), m0.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vtop.val[0], vtop.val[1], 2), m0.val[2], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vmid.val[0], m1.val[0], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vmid.val[0], vmid.val[1], 1), m1.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vmid.val[0], vmid.val[1], 2), m1.val[2], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vlow.val[0], m2.val[0], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vlow.val[0], vlow.val[1], 1), m2.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vlow.val[0], vlow.val[1], 2), m2.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vtop.val[1], vtop.val[2], 1), m0.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vtop.val[1], vtop.val[2], 2), m0.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vmid.val[1], m1.val[0], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vmid.val[1], vmid.val[2], 1), m1.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vmid.val[1], vmid.val[2], 2), m1.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vlow.val[1], m2.val[0], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vlow.val[1], vlow.val[2], 1), m2.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vlow.val[1], vlow.val[2], 2), m2.val[2], fixed_point_position);
    return out;
}

template <>
inline qint16x8x2_t convolve_3x3<2>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    qint16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 4), out.val[0], 2);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 6), out.val[0], 3);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 0), out.val[0], 4);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 2), out.val[0], 5);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 4), out.val[0], 6);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 6), out.val[0], 7);
    return out;
}

template <>
inline qint16x8x2_t convolve_3x3<3>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    qint16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 3), out.val[0], 1);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 6), out.val[0], 2);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 1), out.val[0], 3);
    return out;
}

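// store_results<stridex> / accumulate_results<stridex> write back exactly the
// number of outputs one inner-loop iteration produces for the given stride:
// both vectors for stride 1, one vector for stride 2, half a vector for
// stride 3. accumulate_results adds onto the values already in the buffer,
// saturating for the fixed-point types.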
template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);

template <>
void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
    vst1q_f32(buffer + 4, values.val[1]);
}

template <>
void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
}

template <>
void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vget_low_f32(values.val[0]));
}

template <unsigned int stridex>
void store_results(qint16_t *buffer, const qint16x8x2_t &values);

template <>
void store_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, values.val[0]);
    vst1q_qs16(buffer + 8, values.val[1]);
}

template <>
void store_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, values.val[0]);
}

template <>
void store_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1_qs16(buffer, vget_low_s16(values.val[0]));
}

template <unsigned int stridex>
void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
    vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
}

template <>
void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}

template <unsigned int stridex>
void accumulate_results(qint16_t *buffer, const qint16x8x2_t &values);

template <>
void accumulate_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
    vst1q_qs16(buffer + 8, vqaddq_qs16(vld1q_qs16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1_qs16(buffer, vqadd_qs16(vld1_qs16(buffer), vget_low_s16(values.val[0])));
}

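// Number of input elements consumed per inner-loop iteration: the input
// pointers advance stridex times as fast as the output pointer.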
template <unsigned int stridex>
int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration);

template <>
int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration;
}

template <>
int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration << 1;
}

template <>
int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration * 3;
}

template <typename T1, typename T2, unsigned int stridex>
class convolver_3x3
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        const int input_stride_x = input->info()->strides_in_bytes().x();
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int num_planes_z = window.z().end() - window.z().start();
        const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_x = std::get<0>(conv_info.pad());
        const unsigned int conv_pad_y = std::get<1>(conv_info.pad());
        const int fixed_point_position = input->info()->fixed_point_position();

        // Set up the output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // Set up the input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (> 3), so we set the first three dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_x * input_stride_x - conv_pad_y * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            /*
             Each thread executing this kernel computes one or more planes of the output volume.

             Say the 3rd dimension of the output volume is 32: with four threads, the first thread computes the output for Z = [0,7], the second for Z = [8,15],
             the third for Z = [16,23] and the fourth for Z = [24,31].

             The algorithm's outer loop iterates over Z, P, Y, X, where P is the depth/3rd dimension of each kernel. This order is not arbitrary: its main benefit
             is that we set up the NEON registers containing the kernel's values only once, and then compute each XY position using the preloaded registers, as opposed to doing this for every XY value.

             The algorithm does not require allocating any additional memory and computes the results directly in place, in two stages:
             1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
             2) Convolve the remaining planes and accumulate the results in the output plane, which has been initialized in step 1.
            */
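            /*
             Example with kernel_depth == 3 and stride s:
                 Step 1:  out(x, y)  = sum_{i,j} in(s*x + i, s*y + j, 0) * k(i, j, 0)
                 Step 2:  out(x, y) += sum_{i,j} in(s*x + i, s*y + j, p) * k(i, j, p)   for p = 1, 2
             i.e. each output plane is built up in place, one input/kernel plane at a time.
            */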
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset = id.z() + oz;
                uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto vk_r0 = load_matrix_row(ptr_k_r0);
                    const auto vk_r1 = load_matrix_row(ptr_k_r1);
                    const auto vk_r2 = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const uint8_t *ptr_k_base = k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w;
                    const uint8_t *input_base = input_ptr + p * input_stride_z;
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(ptr_k_base);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(ptr_k_base + kernel_stride_y * 2);
                    const auto vk_r0 = load_matrix_row(ptr_k_r0);
                    const auto vk_r1 = load_matrix_row(ptr_k_r1);
                    const auto vk_r2 = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_base + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_base + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_base + (ih + 2) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};

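// convolver_5x5 follows the same two-step, plane-by-plane scheme as
// convolver_3x3, reading five input rows per output row. The five kernel row
// pointers are passed straight to convolve_5x5, which broadcasts the taps
// itself via load_matrix_hi/load_matrix_lo.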
template <typename T1, typename T2, unsigned int stridex>
class convolver_5x5
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        const int input_stride_x = input->info()->strides_in_bytes().x();
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int num_planes_z = window.z().end() - window.z().start();
        const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_x = std::get<0>(conv_info.pad());
        const unsigned int conv_pad_y = std::get<1>(conv_info.pad());
        const int fixed_point_position = input->info()->fixed_point_position();

        // Set up the output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // Set up the input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (> 3), so we set the first three dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_x * input_stride_x - conv_pad_y * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset = id.z() + oz;
                uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4 = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r3 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 3 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r4 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 4 * kernel_stride_y + 0 * kernel_stride_x);

                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_0 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_1 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_2 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
                        auto in_3 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 3) * input_stride_y);
                        auto in_4 = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 4) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_0 += delta_input, in_1 += delta_input, in_2 += delta_input, in_3 += delta_input, in_4 += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_5x5<stridex>(in_0, in_1, in_2, in_3, in_4, ptr_k_r0, ptr_k_r1, ptr_k_r2, ptr_k_r3, ptr_k_r4, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};

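// Dispatchers: pick the convolver specialization matching the horizontal
// convolution stride. Only strides 1, 2 and 3 are implemented.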
template <typename T1, typename T2>
inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <typename T1, typename T2>
inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <typename T1, typename T2>
inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_5x5<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_5x5<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_5x5<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

} // namespace

NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
    : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_weight_elems_read_per_row(0), _num_elems_read_per_iteration(0),
      _num_elems_written_per_iteration(0)
{
}

BorderSize NEDirectConvolutionLayerKernel::border_size() const
{
    return _border_size;
}

void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::QS16, DataType::F32);
    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 1 && (std::get<0>(conv_info.pad()) || std::get<1>(conv_info.pad())),
                             "Pad > 0 not supported for 1x1 weights");
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 3 && (std::get<0>(conv_info.pad()) > 1 || std::get<1>(conv_info.pad()) > 1),
                             "Pad > 1 not supported for 3x3 weights");
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 5 && (std::get<0>(conv_info.pad()) > 2 || std::get<1>(conv_info.pad()) > 2),
                             "Pad > 2 not supported for 5x5 weights");

    ARM_COMPUTE_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
    ARM_COMPUTE_ERROR_ON(weights->info()->dimension(2) != input->info()->dimension(2));
    ARM_COMPUTE_ERROR_ON(weights->info()->dimension(0) != weights->info()->dimension(1));
    ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4);

    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    const unsigned int conv_pad_x = std::get<0>(conv_info.pad());
    const unsigned int conv_pad_y = std::get<1>(conv_info.pad());

    _input = input;
    _weights = weights;
    _output = output;
    _conv_info = conv_info;
    _kernel_size = weights->info()->dimension(0);
    _border_size = BorderSize(conv_pad_y, conv_pad_x);

    const unsigned int kernel_size = weights->info()->dimension(0);

    // Get convolved dimensions
    unsigned int output_width = 0;
    unsigned int output_height = 0;
    std::tie(output_width, output_height) = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1), kernel_size, kernel_size, conv_info);
1280
1281 TensorShape output_shape = input->info()->tensor_shape();
1282 output_shape.set(0, output_width);
1283 output_shape.set(1, output_height);
1284 output_shape.set(2, weights->info()->dimension(3));
1285
1286 DataType data_type = input->info()->data_type();
1287
1288 if(is_data_type_fixed_point(data_type))
1289 {
1290 // Promote data type in case of fixed point
1291 data_type = ((data_type == DataType::QS8) ? DataType::QS16 : DataType::QS32);
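        // e.g. QS8 promotes to QS16 and QS16 to QS32, so the multiply-accumulate
        // chain of the convolution has headroom before it saturates.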

    // Output auto-initialization if not yet initialized
    auto_init_if_empty(*output->info(), output_shape, 1, data_type, input->info()->fixed_point_position());

    ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, output->info()->data_type());

    switch(_kernel_size)
    {
        case 1:
        {
            switch(input->info()->data_type())
            {
#ifdef ARM_COMPUTE_ENABLE_FP16
                case DataType::F16:
#endif /* ARM_COMPUTE_ENABLE_FP16 */
                case DataType::QS8:
                case DataType::QS16:
                    _num_elems_written_per_iteration = 8;
                    break;
                case DataType::F32:
                    _num_elems_written_per_iteration = 4;
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported.");
                    break;
            }
            _num_weight_elems_read_per_row = kernel_size;
            _num_elems_read_per_iteration  = conv_stride_x * _num_elems_written_per_iteration;
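            // With a 1x1 kernel, producing N outputs per iteration consumes
            // conv_stride_x * N input elements: the strided paths load through the
            // de-interleaving vld2q/vld3q variants and keep only the first lane set.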
            break;
        }
        case 3:
        case 5:
        {
            switch(input->info()->data_type())
            {
                case DataType::F32:
                    _num_weight_elems_read_per_row   = 4 + _kernel_size - 1;
                    _num_elems_read_per_iteration    = 12;
                    _num_elems_written_per_iteration = 16 >> conv_stride_x;
                    break;
#ifdef ARM_COMPUTE_ENABLE_FP16
                case DataType::F16:
#endif /* ARM_COMPUTE_ENABLE_FP16 */
                case DataType::QS8:
                case DataType::QS16:
                    _num_weight_elems_read_per_row   = 8 + _kernel_size - 1;
                    _num_elems_read_per_iteration    = 24;
                    _num_elems_written_per_iteration = 32 >> conv_stride_x;
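                    // Each extra stride step halves the outputs produced per
                    // iteration: 16, 8 or 4 outputs for stride 1, 2 or 3.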
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported.");
                    break;
            }
        }
        break;
        default:
        {
            ARM_COMPUTE_ERROR("Not implemented");
            break;
        }
    }

    // Calculate right and bottom border
    const unsigned int conv_stride_y = std::get<1>(_conv_info.stride());
    const int          input_width   = input->info()->dimension(0);
    const int          input_height  = input->info()->dimension(1);
    const int          upper_bound_w = ceil_to_multiple(((output->info()->dimension(0) - 1) * conv_stride_x + _kernel_size), _num_elems_read_per_iteration) - conv_pad_x - input_width;
    const int          upper_bound_h = ((output->info()->dimension(1) - 1) * conv_stride_y - conv_pad_y + _kernel_size) - input_height;
    _border_size.right  = std::max(upper_bound_w, static_cast<int>(_kernel_size));
    _border_size.bottom = std::max(upper_bound_h, static_cast<int>(_kernel_size));
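    // Worked example (illustrative numbers): a 3x3 F32 kernel with stride_x = 1,
    // pad_x = 0 and input_width = 10 yields output width 8, so the widest read spans
    // ceil_to_multiple((8 - 1) * 1 + 3, 12) = 12 elements and upper_bound_w is
    // 12 - 0 - 10 = 2, which std::max() then raises to the kernel size of 3.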

    Window win = calculate_max_window(*output->info(), Steps(_num_elems_written_per_iteration));
    AccessWindowStatic     input_access(input->info(), -conv_pad_x, -conv_pad_y, input_width + _border_size.right, input_height + _border_size.bottom);
    AccessWindowStatic     weights_access(weights->info(), 0, 0, _num_weight_elems_read_per_row, _kernel_size);
    AccessWindowHorizontal output_access(output->info(), 0, _num_elems_written_per_iteration);
    update_window_and_padding(win, input_access, weights_access, output_access);
    output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));

    INEKernel::configure(win);
}

void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

    const int kernel_size = _weights->info()->dimension(0);

    switch(kernel_size)
    {
        case 1:
        {
            switch(_input->info()->data_type())
            {
                case DataType::QS8:
                    convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::QS16:
                    convolve_1x1<qint16_t, qint32_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::F32:
                    convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#ifdef ARM_COMPUTE_ENABLE_FP16
                case DataType::F16:
                    convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#endif /* ARM_COMPUTE_ENABLE_FP16 */
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        case 3:
        {
            switch(_input->info()->data_type())
            {
                case DataType::QS8:
                    convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::F32:
                    convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#ifdef ARM_COMPUTE_ENABLE_FP16
                case DataType::F16:
                    convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#endif /* ARM_COMPUTE_ENABLE_FP16 */
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        case 5:
        {
            switch(_input->info()->data_type())
            {
                case DataType::F32:
                    convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
            break;
        }
    }
}
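
// Usage sketch (illustrative only; in practice the NEDirectConvolutionLayer function
// is the supported entry point and drives this kernel internally). Assuming `src`,
// `weights` and `dst` are allocated Tensor objects and the border of `src` has been
// filled, e.g. by NEFillBorderKernel:
//
//     NEDirectConvolutionLayerKernel kernel;
//     kernel.configure(&src, &weights, &dst, PadStrideInfo(1, 1, 0, 0));
//     NEScheduler::get().schedule(&kernel, Window::DimY);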