/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"

#include <algorithm>
#include <arm_neon.h>

using namespace arm_compute;

namespace
{
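// Stride-templated NEON helpers shared by the 1x1 and 3x3 convolvers below.
// internal_vld1q<1> is a plain contiguous load; internal_vld1q<2> and internal_vld1q<3>
// use the de-interleaving loads vld2q/vld3q, whose val[0] holds every 2nd/3rd element,
// so a strided read costs a single load instruction.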
template <unsigned int stridex>
qint16x8_t internal_vld1q(const qint16_t *in);

template <>
qint16x8_t internal_vld1q<1>(const qint16_t *in)
{
    return vld1q_qs16(in);
}

template <>
qint16x8_t internal_vld1q<2>(const qint16_t *in)
{
    const int16x8x2_t tmp = vld2q_s16(in);
    return tmp.val[0];
}

template <>
qint16x8_t internal_vld1q<3>(const qint16_t *in)
{
    const int16x8x3_t tmp = vld3q_s16(in);
    return tmp.val[0];
}

inline qint16x8_t internal_vdupq_n(qint16_t v)
{
    return vdupq_n_qs16(v);
}

#ifdef ARM_COMPUTE_ENABLE_FP16
template <unsigned int stridex>
float16x8_t internal_vld1q(const float16_t *in);

template <>
float16x8_t internal_vld1q<1>(const float16_t *in)
{
    return vld1q_f16(in);
}

template <>
float16x8_t internal_vld1q<2>(const float16_t *in)
{
    const float16x8x2_t tmp = vld2q_f16(in);
    return tmp.val[0];
}

template <>
float16x8_t internal_vld1q<3>(const float16_t *in)
{
    const float16x8x3_t tmp = vld3q_f16(in);
    return tmp.val[0];
}

inline float16x8_t internal_vdupq_n(float16_t v)
{
    return vdupq_n_f16(v);
}

inline void internal_vst1q(float16_t *p, const float16x8_t &v)
{
    vst1q_f16(p, v);
}

float16x8_t internal_vmull(const float16x8_t &x, const float16x8_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f16(x, y);
}

inline float16x8_t internal_vmlal(const float16x8_t &x, const float16x8_t &y, const float16x8_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vaddq_f16(x, vmulq_f16(y, z));
}
#endif /* ARM_COMPUTE_ENABLE_FP16 */

template <unsigned int stridex>
float32x4_t internal_vld1q(const float *in);

template <>
float32x4_t internal_vld1q<1>(const float *in)
{
    return vld1q_f32(in);
}

template <>
float32x4_t internal_vld1q<2>(const float *in)
{
    const float32x4x2_t tmp = vld2q_f32(in);
    return tmp.val[0];
}

template <>
float32x4_t internal_vld1q<3>(const float *in)
{
    const float32x4x3_t tmp = vld3q_f32(in);
    return tmp.val[0];
}

inline float32x4_t internal_vdupq_n(float v)
{
    return vdupq_n_f32(v);
}

inline void internal_vst1q(float *p, const float32x4_t &v)
{
    vst1q_f32(p, v);
}

float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f32(x, y);
}

inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmlaq_f32(x, y, z);
}

template <unsigned int stridex>
qint8x8_t internal_vld1q(const qint8_t *in);

template <>
qint8x8_t internal_vld1q<1>(const qint8_t *in)
{
    return vld1_qs8(in);
}

template <>
qint8x8_t internal_vld1q<2>(const qint8_t *in)
{
    const qint8x8x2_t tmp = vld2_s8(in);
    return tmp.val[0];
}

template <>
qint8x8_t internal_vld1q<3>(const qint8_t *in)
{
    const qint8x8x3_t tmp = vld3_s8(in);
    return tmp.val[0];
}

inline qint8x8_t internal_vdupq_n(qint8_t v)
{
    return vdup_n_qs8(v);
}

inline qint16x8_t internal_vmull(const qint8x8_t &x, const qint8x8_t &y, int fixed_point_position)
{
    return vmull_qs8(x, y, fixed_point_position);
}

inline qint16x8_t internal_vmlal(const qint16x8_t &x, const qint8x8_t &y, const qint8x8_t &z, int fixed_point_position)
{
    return vqmlal_qs8(x, y, z, fixed_point_position);
}

inline void internal_vst1q(qint16_t *p, const qint16x8_t &v)
{
    vst1q_qs16(p, v);
}

inline void internal_vst1q(int *p, const qint32x4x2_t &v)
{
    vst1q_s32(p, v.val[0]);
    vst1q_s32(p + 4, v.val[1]);
}

template <unsigned int stridex>
qint32x4x2_t internal_vld1q(const qint32_t *in);

template <>
qint32x4x2_t internal_vld1q<1>(const qint32_t *in)
{
    const qint32x4x2_t r =
    {
        {
            vld1q_s32(in),
            vld1q_s32(in + 4)
        }
    };
    return r;
}

inline qint32x4x2_t internal_vmull(const qint16x8_t &x, const qint16x8_t &y, int fixed_point_position)
{
    const qint32x4x2_t r =
    {
        {
            vmull_qs16(vget_low_s16(x), vget_low_s16(y), fixed_point_position),
            vmull_qs16(vget_high_s16(x), vget_high_s16(y), fixed_point_position),
        }
    };
    return r;
}

inline qint32x4x2_t internal_vmlal(const qint32x4x2_t &x, const qint16x8_t &y, const qint16x8_t &z, int fixed_point_position)
{
    const qint32x4x2_t r =
    {
        {
            vqmlal_qs16(x.val[0], vget_low_s16(y), vget_low_s16(z), fixed_point_position),
            vqmlal_qs16(x.val[1], vget_high_s16(y), vget_high_s16(z), fixed_point_position)
        }
    };
    return r;
}

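// A 1x1 convolution is a weighted sum over the kernel depth: every output plane is
// sum_p(input plane p * scalar weight p). Step 1 below writes the product for plane 0
// to initialize the output; Step 2 accumulates the remaining planes on top of it.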
template <typename T1, typename T2, unsigned int stridex>
class convolver_1x1
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        const int          input_stride_y       = input->info()->strides_in_bytes().y();
        const int          input_stride_z       = input->info()->strides_in_bytes().z();
        const int          output_stride_y      = output->info()->strides_in_bytes().y();
        const int          output_stride_z      = output->info()->strides_in_bytes().z();
        const int          kernel_stride_z      = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w      = weights->info()->strides_in_bytes()[3];
        const int          output_w             = output->info()->dimension(0);
        const int          output_h             = output->info()->dimension(1);
        const int          range_z              = window.z().end() - window.z().start();
        const int          kernel_depth         = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y        = std::get<1>(conv_info.stride());
        const int          fixed_point_position = input->info()->fixed_point_position();

        // Setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // Setup input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));
        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            /*
                For a detailed explanation of how the algorithm works, refer to class convolver_3x3 below.
            */
            const uint8_t *input_ptr = in.ptr();
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            for(int oz = 0; oz < range_z; ++oz)
            {
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
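                // Convolve input plane 0 with its weight and initialize the output plane with a plain store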
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk    = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto      in_val    = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
                        auto      p_out     = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }
                // Step 2
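                // Convolve the remaining kernel_depth - 1 planes, accumulating into the plane initialized in Step 1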
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk    = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto      in_val    = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
                        auto      p_out     = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }
            }
        },
        in, out);
    }
};

#ifdef ARM_COMPUTE_ENABLE_FP16
inline float16x8x3_t load_matrix_row(const float16_t *ptr)
{
    /* ptr is a pointer to a row in a 3x3 matrix; the function returns 3 vectors, each holding the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second and r.val[2] the third (broadcast to all lanes) */
    const float16x8x3_t r =
    {
        {
            vld1q_dup_f16(ptr),
            vld1q_dup_f16(1 + ptr),
            vld1q_dup_f16(2 + ptr)
        }
    };
    return r;
}

template <unsigned int stridex>
float16x8x2_t convolve_3x3(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                           int fixed_point_position);

template <>
float16x8x2_t convolve_3x3<1>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                              int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);

    const float16x8x3_t vtop =
    {
        {
            vld1q_f16(in_top),
            vld1q_f16(in_top + 8),
            vld1q_f16(in_top + 16)
        }
    };
    const float16x8x3_t vmid =
    {
        {
            vld1q_f16(in_mid),
            vld1q_f16(in_mid + 8),
            vld1q_f16(in_mid + 16)
        }
    };
    const float16x8x3_t vlow =
    {
        {
            vld1q_f16(in_low),
            vld1q_f16(in_low + 8),
            vld1q_f16(in_low + 16)
        }
    };
    float16x8x2_t out =
    {
        {
            vmulq_f16(vtop.val[0], m0.val[0]),
            vmulq_f16(vtop.val[1], m0.val[0])
        }
    };
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 1), m0.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vtop.val[0], vtop.val[1], 2), m0.val[2]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vmid.val[0], m1.val[0]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 1), m1.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vmid.val[0], vmid.val[1], 2), m1.val[2]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vlow.val[0], m2.val[0]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 1), m2.val[1]));
    out.val[0] = vaddq_f16(out.val[0], vmulq_f16(vextq_f16(vlow.val[0], vlow.val[1], 2), m2.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 1), m0.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vtop.val[1], vtop.val[2], 2), m0.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vmid.val[1], m1.val[0]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 1), m1.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vmid.val[1], vmid.val[2], 2), m1.val[2]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vlow.val[1], m2.val[0]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 1), m2.val[1]));
    out.val[1] = vaddq_f16(out.val[1], vmulq_f16(vextq_f16(vlow.val[1], vlow.val[2], 2), m2.val[2]));
    return out;
}

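// For strides 2 and 3 the results are first computed as for stride 1, then the lanes that
// correspond to the strided output columns are packed into the low lanes of out.val[0],
// which is the portion store_results<stridex> writes out (same packing as the qint16
// variants further down).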
template <>
inline float16x8x2_t convolve_3x3<2>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                     int fixed_point_position)
{
    float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 4), out.val[0], 2);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 3);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 0), out.val[0], 4);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 2), out.val[0], 5);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 4), out.val[0], 6);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 6), out.val[0], 7);
    return out;
}

template <>
inline float16x8x2_t convolve_3x3<3>(const float16_t *in_top, const float16_t *in_mid, const float16_t *in_low, const float16x8x3_t &m0, const float16x8x3_t &m1, const float16x8x3_t &m2,
                                     int fixed_point_position)
{
    float16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 3), out.val[0], 1);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[0], 6), out.val[0], 2);
    out.val[0] = vsetq_lane_f16(vgetq_lane_f16(out.val[1], 1), out.val[0], 3);
    return out;
}

template <unsigned int stridex>
void store_results(float16_t *buffer, const float16x8x2_t &values);

template <>
void store_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
    vst1q_f16(buffer + 8, values.val[1]);
}

template <>
void store_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, values.val[0]);
}

template <>
void store_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vget_low_f16(values.val[0]));
}

template <unsigned int stridex>
void accumulate_results(float16_t *buffer, const float16x8x2_t &values);

template <>
void accumulate_results<1>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
    vst1q_f16(buffer + 8, vaddq_f16(vld1q_f16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1q_f16(buffer, vaddq_f16(vld1q_f16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float16_t *buffer, const float16x8x2_t &values)
{
    vst1_f16(buffer, vadd_f16(vld1_f16(buffer), vget_low_f16(values.val[0])));
}

#endif /* ARM_COMPUTE_ENABLE_FP16 */

inline float32x4x3_t load_matrix_row(const float *ptr)
{
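    /* ptr is a pointer to a row in a 3x3 matrix; the function returns 3 vectors, each holding the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second and r.val[2] the third (broadcast to all lanes) */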
    const float32x4x3_t r =
    {
        {
            vld1q_dup_f32(ptr),
            vld1q_dup_f32(1 + ptr),
            vld1q_dup_f32(2 + ptr)
        }
    };
    return r;
}

inline qint8x8x3_t load_matrix_row(const qint8_t *ptr)
{
    /* ptr is a pointer to a row in a 3x3 matrix; the function returns 3 vectors, each holding the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second and r.val[2] the third (broadcast to all lanes) */
    const qint8x8x3_t r =
    {
        {
            vld1_dup_qs8(ptr),
            vld1_dup_qs8(1 + ptr),
            vld1_dup_qs8(2 + ptr)
        }
    };
    return r;
}

template <unsigned int stridex>
float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position);

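// Stride-1 3x3 convolution of eight consecutive output columns (two float32x4 halves).
// Each input row is loaded as three consecutive vectors, and vextq_f32(a, b, n) forms the
// window shifted by n columns, supplying the inputs for kernel columns 1 and 2 without
// extra loads from memory.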
template <>
inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);

    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + 4),
            vld1q_f32(in_top + 8)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + 4),
            vld1q_f32(in_mid + 8)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + 4),
            vld1q_f32(in_low + 8)
        }
    };
    float32x4x2_t out =
    {
        {
            vmulq_f32(vtop.val[0], m0.val[0]),
            vmulq_f32(vtop.val[1], m0.val[0])
        }
    };
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
    return out;
}

template <>
inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}

template <unsigned int stridex>
qint16x8x2_t convolve_3x3(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position);

template <>
inline qint16x8x2_t convolve_3x3<1>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    const qint8x8x3_t vtop =
    {
        {
            vld1_qs8(in_top),
            vld1_qs8(in_top + 8),
            vld1_qs8(in_top + 16)
        }
    };
    const qint8x8x3_t vmid =
    {
        {
            vld1_qs8(in_mid),
            vld1_qs8(in_mid + 8),
            vld1_qs8(in_mid + 16)
        }
    };
    const qint8x8x3_t vlow =
    {
        {
            vld1_qs8(in_low),
            vld1_qs8(in_low + 8),
            vld1_qs8(in_low + 16)
        }
    };
    qint16x8x2_t out =
    {
        {
            vmull_qs8(vtop.val[0], m0.val[0], fixed_point_position),
            vmull_qs8(vtop.val[1], m0.val[0], fixed_point_position)
        }
    };
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vtop.val[0], vtop.val[1], 1), m0.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vtop.val[0], vtop.val[1], 2), m0.val[2], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vmid.val[0], m1.val[0], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vmid.val[0], vmid.val[1], 1), m1.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vmid.val[0], vmid.val[1], 2), m1.val[2], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vlow.val[0], m2.val[0], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vlow.val[0], vlow.val[1], 1), m2.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vlow.val[0], vlow.val[1], 2), m2.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vtop.val[1], vtop.val[2], 1), m0.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vtop.val[1], vtop.val[2], 2), m0.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vmid.val[1], m1.val[0], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vmid.val[1], vmid.val[2], 1), m1.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vmid.val[1], vmid.val[2], 2), m1.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vlow.val[1], m2.val[0], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vlow.val[1], vlow.val[2], 1), m2.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vlow.val[1], vlow.val[2], 2), m2.val[2], fixed_point_position);
    return out;
}

template <>
inline qint16x8x2_t convolve_3x3<2>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    qint16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 4), out.val[0], 2);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 6), out.val[0], 3);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 0), out.val[0], 4);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 2), out.val[0], 5);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 4), out.val[0], 6);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 6), out.val[0], 7);
    return out;
}

template <>
inline qint16x8x2_t convolve_3x3<3>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    qint16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 3), out.val[0], 1);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 6), out.val[0], 2);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 1), out.val[0], 3);
    return out;
}

template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);

template <>
void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
    vst1q_f32(buffer + 4, values.val[1]);
}

template <>
void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
}

template <>
void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vget_low_f32(values.val[0]));
}

template <unsigned int stridex>
void store_results(qint16_t *buffer, const qint16x8x2_t &values);

template <>
void store_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, values.val[0]);
    vst1q_qs16(buffer + 8, values.val[1]);
}

template <>
void store_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, values.val[0]);
}

template <>
void store_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1_qs16(buffer, vget_low_s16(values.val[0]));
}

template <unsigned int stridex>
void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
    vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
}

template <>
void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}

template <unsigned int stridex>
void accumulate_results(qint16_t *buffer, const qint16x8x2_t &values);

template <>
void accumulate_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
    vst1q_qs16(buffer + 8, vqaddq_qs16(vld1q_qs16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1_qs16(buffer, vqadd_qs16(vld1_qs16(buffer), vget_low_s16(values.val[0])));
}

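// Input elements consumed per inner-loop iteration: the number of outputs written
// times the convolution stride along x.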
template <unsigned int stridex>
int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration);

template <>
int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration;
}

template <>
int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration << 1;
}

template <>
int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration * 3;
}

template <typename T1, typename T2, unsigned int stridex>
class convolver_3x3
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        const int          input_stride_x       = input->info()->strides_in_bytes().x();
        const int          input_stride_y       = input->info()->strides_in_bytes().y();
        const int          input_stride_z       = input->info()->strides_in_bytes().z();
        const int          output_stride_y      = output->info()->strides_in_bytes().y();
        const int          output_stride_z      = output->info()->strides_in_bytes().z();
        const int          kernel_stride_x      = weights->info()->strides_in_bytes().x();
        const int          kernel_stride_y      = weights->info()->strides_in_bytes().y();
        const int          kernel_stride_z      = weights->info()->strides_in_bytes().z();
        const int          kernel_stride_w      = weights->info()->strides_in_bytes()[3];
        const int          output_w             = output->info()->dimension(0);
        const int          output_h             = output->info()->dimension(1);
        const int          num_planes_z         = window.z().end() - window.z().start();
        const int          delta_input          = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int          kernel_depth         = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y        = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_x           = std::get<0>(conv_info.pad());
        const unsigned int conv_pad_y           = std::get<1>(conv_info.pad());
        const int          fixed_point_position = input->info()->fixed_point_position();

        // Setup output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // Setup input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (>3), so we set the first 3 dimensions to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_x * input_stride_x - conv_pad_y * input_stride_y;
            uint8_t       *out_ptr   = out.ptr();
            int            ih        = 0;
            int            oh        = 0;
            /*
                Each thread executing this kernel computes one or more planes of the output volume.

                Say the 3rd dimension of the output volume is 32; then with four threads the first thread computes the output for Z = [0,7],
                the second thread Z = [8,15], the third thread Z = [16,23] and the fourth thread Z = [24,31].

                The algorithm's outer loops iterate over Z, P, Y, X, where P is the depth/3rd dimension of each kernel. This order is not arbitrary:
                its main benefit is that we set up the NEON registers containing the kernel's values only once, and then compute each XY position
                using those preloaded registers, instead of reloading them for every XY value.

                The algorithm does not require allocating any additional memory and computes the results directly in place, in two stages:
                    1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
                    2) Convolve the remaining planes and accumulate the results into the output plane initialized in step 1.
            */

            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                const int zoffset    = id.z() + oz;
                uint8_t  *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto vk_r0    = load_matrix_row(ptr_k_r0);
                    const auto vk_r1    = load_matrix_row(ptr_k_r1);
                    const auto vk_r2    = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + zoffset * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto vk_r0    = load_matrix_row(ptr_k_r0);
                    const auto vk_r1    = load_matrix_row(ptr_k_r1);
                    const auto vk_r2    = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out  = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};

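// The stride along x is only known at run time, so these helpers dispatch to the
// matching compile-time stridex instantiation of the convolver.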
template <typename T1, typename T2>
inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <typename T1, typename T2>
inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}
} // namespace

NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
    : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_elems_read_per_iteration(0), _num_elems_written_per_iteration(0)
{
}

BorderSize NEDirectConvolutionLayerKernel::border_size() const
{
    return _border_size;
}

void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F16, DataType::QS16, DataType::F32);
    ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 1 && (std::get<0>(conv_info.pad()) || std::get<1>(conv_info.pad())),
                             "Pad > 0 not supported for 1x1 weights");
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 3 && (std::get<0>(conv_info.pad()) > 1 || std::get<1>(conv_info.pad()) > 1),
                             "Pad > 1 not supported for 3x3 weights");
    ARM_COMPUTE_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
    ARM_COMPUTE_ERROR_ON(weights->info()->dimension(2) != input->info()->dimension(2));
    ARM_COMPUTE_ERROR_ON(weights->info()->dimension(0) != weights->info()->dimension(1));
    ARM_COMPUTE_ERROR_ON(weights->info()->num_dimensions() > 4);

    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    const unsigned int conv_pad_x    = std::get<0>(conv_info.pad());
    const unsigned int conv_pad_y    = std::get<1>(conv_info.pad());

    _input       = input;
    _weights     = weights;
    _output      = output;
    _conv_info   = conv_info;
    _kernel_size = weights->info()->dimension(0);
    _border_size = BorderSize(conv_pad_y, conv_pad_x);

    const unsigned int kernel_size = weights->info()->dimension(0);

    // Get convolved dimensions
    unsigned int output_width  = 0;
    unsigned int output_height = 0;
    std::tie(output_width, output_height) = scaled_dimensions(input->info()->dimension(0), input->info()->dimension(1), kernel_size, kernel_size, conv_info);

    TensorShape output_shape = input->info()->tensor_shape();
    output_shape.set(0, output_width);
    output_shape.set(1, output_height);
    output_shape.set(2, weights->info()->dimension(3));

    DataType data_type = input->info()->data_type();

    if(is_data_type_fixed_point(data_type))
    {
        // Promote data type in case of fixed point
        data_type = ((data_type == DataType::QS8) ? DataType::QS16 : DataType::QS32);
    }

    // Auto-initialize the output if it has not been initialized yet
    auto_init_if_empty(*output->info(), output_shape, 1, data_type, input->info()->fixed_point_position());

    ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(output->info()->tensor_shape(), output_shape);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, output->info()->data_type());

    Window win = calculate_max_window(*output->info());

    switch(_kernel_size)
    {
        case 1:
        {
            switch(input->info()->data_type())
            {
#ifdef ARM_COMPUTE_ENABLE_FP16
                case DataType::F16:
#endif /* ARM_COMPUTE_ENABLE_FP16 */
                case DataType::QS8:
                case DataType::QS16:
                    _num_elems_written_per_iteration = 8;
                    break;
                case DataType::F32:
                    _num_elems_written_per_iteration = 4;
                    break;
                default:
                    ARM_COMPUTE_ERROR("Data type not supported.");
                    break;
            }

            _num_elems_read_per_iteration = conv_stride_x * _num_elems_written_per_iteration;
            win                           = calculate_max_window(*output->info(), Steps(_num_elems_written_per_iteration));
            AccessWindowHorizontal input_access(input->info(), 0, _num_elems_read_per_iteration);
            AccessWindowHorizontal output_access(output->info(), 0, _num_elems_written_per_iteration);
            update_window_and_padding(win, input_access, output_access);
            output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
            break;
        }
        case 3:
        {
            if(input->info()->data_type() == DataType::F32)
            {
                _num_elems_read_per_iteration    = 12;
                _num_elems_written_per_iteration = 16 >> conv_stride_x;
            }
            else
            {
                _num_elems_read_per_iteration    = 24;
                _num_elems_written_per_iteration = 32 >> conv_stride_x;
            }

            // Calculate right and bottom border
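            // The inner loops read _num_elems_read_per_iteration elements at a time, so the
            // rightmost/bottom accesses can run past the input; the border must cover that
            // overrun and at least the kernel size.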
            const unsigned int conv_stride_y = std::get<1>(_conv_info.stride());
            const int          input_width   = input->info()->dimension(0);
            const int          input_height  = input->info()->dimension(1);
            const int          upper_bound_w = ceil_to_multiple(((output->info()->dimension(0) - 1) * conv_stride_x + _kernel_size), _num_elems_read_per_iteration) - conv_pad_x - input_width;
            const int          upper_bound_h = ((output->info()->dimension(1) - 1) * conv_stride_y - conv_pad_y + _kernel_size) - input_height;
            _border_size.right  = std::max(upper_bound_w, static_cast<int>(_kernel_size));
            _border_size.bottom = std::max(upper_bound_h, static_cast<int>(_kernel_size));

            // Create window and update padding
            win = calculate_max_window(*output->info(), Steps(_num_elems_written_per_iteration));
            AccessWindowStatic     input_access(input->info(), -conv_pad_x, -conv_pad_y, input_width + _border_size.right, input_height + _border_size.bottom);
            AccessWindowStatic     weights_access(weights->info(), 0, 0, _kernel_size, _kernel_size);
            AccessWindowHorizontal output_access(output->info(), 0, _num_elems_written_per_iteration);
            update_window_and_padding(win, input_access, weights_access, output_access);
            output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("Not implemented");
            break;
        }
    }

    INEKernel::configure(win);
}

void NEDirectConvolutionLayerKernel::run(const Window &window)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

    const int kernel_size = _weights->info()->dimension(0);

    switch(kernel_size)
    {
        case 1:
        {
            switch(_input->info()->data_type())
            {
                case DataType::QS8:
                    convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::QS16:
                    convolve_1x1<qint16_t, qint32_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::F32:
                    convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#ifdef ARM_COMPUTE_ENABLE_FP16
                case DataType::F16:
                    convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#endif /* ARM_COMPUTE_ENABLE_FP16 */
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        case 3:
        {
            switch(_input->info()->data_type())
            {
                case DataType::QS8:
                    convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
                case DataType::F32:
                    convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#ifdef ARM_COMPUTE_ENABLE_FP16
                case DataType::F16:
                    convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
                    break;
#endif /* ARM_COMPUTE_ENABLE_FP16 */
                default:
                    ARM_COMPUTE_ERROR("Data type not supported");
                    break;
            }
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("Only kernel sizes 1x1 and 3x3 are supported.");
            break;
        }
    }
}