/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "arm_compute/core/NEON/kernels/NEDirectConvolutionLayerKernel.h"

#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/IAccessWindow.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/NEFixedPoint.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"

#include <algorithm>
#include <arm_neon.h>

using namespace arm_compute;

namespace
{
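// The internal_vld1q<stridex> overloads below implement strided loads: stride 1 uses a plain
// vld1q/vld1, while strides 2 and 3 use the de-interleaving loads (vld2q/vld3q, vld2/vld3) and keep
// only the first de-interleaved vector, i.e. every 2nd or 3rd element of the input row.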
template <unsigned int stridex>
float32x4_t internal_vld1q(const float *in);

template <>
float32x4_t internal_vld1q<1>(const float *in)
{
    return vld1q_f32(in);
}

template <>
float32x4_t internal_vld1q<2>(const float *in)
{
    const float32x4x2_t tmp = vld2q_f32(in);
    return tmp.val[0];
}

template <>
float32x4_t internal_vld1q<3>(const float *in)
{
    const float32x4x3_t tmp = vld3q_f32(in);
    return tmp.val[0];
}

template <unsigned int stridex>
qint8x8_t internal_vld1q(const qint8_t *in);

template <>
qint8x8_t internal_vld1q<1>(const qint8_t *in)
{
    return vld1_qs8(in);
}

template <>
qint8x8_t internal_vld1q<2>(const qint8_t *in)
{
    const qint8x8x2_t tmp = vld2_s8(in);
    return tmp.val[0];
}

template <>
qint8x8_t internal_vld1q<3>(const qint8_t *in)
{
    const qint8x8x3_t tmp = vld3_s8(in);
    return tmp.val[0];
}

template <unsigned int stridex>
qint16x8_t internal_vld1q(const qint16_t *in);

template <>
qint16x8_t internal_vld1q<1>(const qint16_t *in)
{
    return vld1q_s16(in);
}

inline float32x4_t internal_vdupq_n(float v)
{
    return vdupq_n_f32(v);
}

inline qint8x8_t internal_vdupq_n(qint8_t v)
{
    return vdup_n_qs8(v);
}

inline void internal_vst1q(float *p, const float32x4_t &v)
{
    vst1q_f32(p, v);
}

inline void internal_vst1q(qint16_t *p, const qint16x8_t &v)
{
    vst1q_qs16(p, v);
}

float32x4_t internal_vmull(const float32x4_t &x, const float32x4_t &y, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmulq_f32(x, y);
}

qint16x8_t internal_vmull(const qint8x8_t &x, const qint8x8_t &y, int fixed_point_position)
{
    return vmull_qs8(x, y, fixed_point_position);
}

inline float32x4_t internal_vmlal(const float32x4_t &x, const float32x4_t &y, const float32x4_t &z, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);
    return vmlaq_f32(x, y, z);
}

inline qint16x8_t internal_vmlal(const qint16x8_t &x, const qint8x8_t &y, const qint8x8_t &z, int fixed_point_position)
{
    return vqmlal_qs8(x, y, z, fixed_point_position);
}

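// 1x1 direct convolution. As in the 3x3 case below, each output plane is computed in two steps:
// the first input plane initialises the output (Step 1, internal_vmull) and every remaining input
// plane is multiply-accumulated on top of it (Step 2, internal_vmlal), so no scratch memory is needed.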
template <typename T1, typename T2, unsigned int stridex>
class convolver_1x1
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int range_z = window.z().end() - window.z().start();
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const int fixed_point_position = input->info()->fixed_point_position();

        // Set up the output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), range_z));

        // Set up the input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (>3), so the first 3 dimensions are set to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            /*
              For a detailed explanation of how the algorithm works, refer to convolver_3x3 below.
            */
            const uint8_t *input_ptr = in.ptr();
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            for(int oz = 0; oz < range_z; ++oz)
            {
                auto p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const T1 *>(input_ptr + (0 * input_stride_z + offset_xy));
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmull(vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto k_val = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w);
                    const auto vk = internal_vdupq_n(*k_val);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        const int offset_xy = ih * input_stride_y;
                        auto in_val = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + offset_xy);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration, in_val += num_elems_read_per_iteration, p_out += num_elems_written_per_iteration)
                        {
                            internal_vst1q(p_out, internal_vmlal(internal_vld1q<1>(p_out), vk, internal_vld1q<stridex>(in_val), fixed_point_position));
                        }
                    }
                }
            }
        },
        in, out);
    }
};

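// Each of the three vectors returned by load_matrix_row() holds one kernel-row value duplicated
// across all lanes, ready to be used directly as one operand of the multiply-accumulate below.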
inline float32x4x3_t load_matrix_row(const float *ptr)
{
    const float32x4x3_t r =
    {
        {
            vld1q_dup_f32(ptr),
            vld1q_dup_f32(1 + ptr),
            vld1q_dup_f32(2 + ptr)
        }
    };
    return r;
}

inline qint8x8x3_t load_matrix_row(const qint8_t *ptr)
{
    /* ptr is a pointer to a row in a 3x3 matrix; the function returns 3 vectors holding exactly the same value in all lanes:
       r.val[0] contains the first element, r.val[1] the second element and r.val[2] the third element (in all lanes). */
    const qint8x8x3_t r =
    {
        {
            vld1_dup_qs8(ptr),
            vld1_dup_qs8(1 + ptr),
            vld1_dup_qs8(2 + ptr)
        }
    };
    return r;
}

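// convolve_3x3<stridex> computes one horizontal strip of 3x3 convolution results from three input rows.
// For stride 1, three consecutive vector loads per row plus vext shifts by 1 and 2 elements provide the
// three kernel-column alignments without reloading, yielding 8 contiguous F32 (16 QS8) outputs per call.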
template <unsigned int stridex>
float32x4x2_t convolve_3x3(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position);

template <>
inline float32x4x2_t convolve_3x3<1>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);

    const float32x4x3_t vtop =
    {
        {
            vld1q_f32(in_top),
            vld1q_f32(in_top + 4),
            vld1q_f32(in_top + 8)
        }
    };
    const float32x4x3_t vmid =
    {
        {
            vld1q_f32(in_mid),
            vld1q_f32(in_mid + 4),
            vld1q_f32(in_mid + 8)
        }
    };
    const float32x4x3_t vlow =
    {
        {
            vld1q_f32(in_low),
            vld1q_f32(in_low + 4),
            vld1q_f32(in_low + 8)
        }
    };
    float32x4x2_t out =
    {
        {
            vmulq_f32(vtop.val[0], m0.val[0]),
            vmulq_f32(vtop.val[1], m0.val[0])
        }
    };
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 1), m0.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vtop.val[0], vtop.val[1], 2), m0.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vmid.val[0], m1.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 1), m1.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vmid.val[0], vmid.val[1], 2), m1.val[2]);
    out.val[0] = vmlaq_f32(out.val[0], vlow.val[0], m2.val[0]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 1), m2.val[1]);
    out.val[0] = vmlaq_f32(out.val[0], vextq_f32(vlow.val[0], vlow.val[1], 2), m2.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 1), m0.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vtop.val[1], vtop.val[2], 2), m0.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vmid.val[1], m1.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 1), m1.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vmid.val[1], vmid.val[2], 2), m1.val[2]);
    out.val[1] = vmlaq_f32(out.val[1], vlow.val[1], m2.val[0]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 1), m2.val[1]);
    out.val[1] = vmlaq_f32(out.val[1], vextq_f32(vlow.val[1], vlow.val[2], 2), m2.val[2]);
    return out;
}

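// For strides 2 and 3 only every 2nd or 3rd lane of the stride-1 result is a valid output, so the
// specializations below compute the stride-1 result and compact the valid lanes into the low lanes of out.val[0].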
template <>
inline float32x4x2_t convolve_3x3<2>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 0), out.val[0], 2);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[1], 2), out.val[0], 3);
    return out;
}

template <>
inline float32x4x2_t convolve_3x3<3>(const float *in_top, const float *in_mid, const float *in_low, const float32x4x3_t &m0, const float32x4x3_t &m1, const float32x4x3_t &m2, int fixed_point_position)
{
    float32x4x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_f32(vgetq_lane_f32(out.val[0], 3), out.val[0], 1);
    return out;
}

template <unsigned int stridex>
qint16x8x2_t convolve_3x3(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position);

template <>
inline qint16x8x2_t convolve_3x3<1>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    ARM_COMPUTE_UNUSED(fixed_point_position);

    const qint8x8x3_t vtop =
    {
        {
            vld1_qs8(in_top),
            vld1_qs8(in_top + 8),
            vld1_qs8(in_top + 16)
        }
    };
    const qint8x8x3_t vmid =
    {
        {
            vld1_qs8(in_mid),
            vld1_qs8(in_mid + 8),
            vld1_qs8(in_mid + 16)
        }
    };
    const qint8x8x3_t vlow =
    {
        {
            vld1_qs8(in_low),
            vld1_qs8(in_low + 8),
            vld1_qs8(in_low + 16)
        }
    };
    qint16x8x2_t out =
    {
        {
            vmull_qs8(vtop.val[0], m0.val[0], fixed_point_position),
            vmull_qs8(vtop.val[1], m0.val[0], fixed_point_position)
        }
    };
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vtop.val[0], vtop.val[1], 1), m0.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vtop.val[0], vtop.val[1], 2), m0.val[2], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vmid.val[0], m1.val[0], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vmid.val[0], vmid.val[1], 1), m1.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vmid.val[0], vmid.val[1], 2), m1.val[2], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vlow.val[0], m2.val[0], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vlow.val[0], vlow.val[1], 1), m2.val[1], fixed_point_position);
    out.val[0] = vqmlal_qs8(out.val[0], vext_s8(vlow.val[0], vlow.val[1], 2), m2.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vtop.val[1], vtop.val[2], 1), m0.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vtop.val[1], vtop.val[2], 2), m0.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vmid.val[1], m1.val[0], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vmid.val[1], vmid.val[2], 1), m1.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vmid.val[1], vmid.val[2], 2), m1.val[2], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vlow.val[1], m2.val[0], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vlow.val[1], vlow.val[2], 1), m2.val[1], fixed_point_position);
    out.val[1] = vqmlal_qs8(out.val[1], vext_s8(vlow.val[1], vlow.val[2], 2), m2.val[2], fixed_point_position);
    return out;
}

template <>
inline qint16x8x2_t convolve_3x3<2>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    qint16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 2), out.val[0], 1);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 4), out.val[0], 2);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 6), out.val[0], 3);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 0), out.val[0], 4);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 2), out.val[0], 5);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 4), out.val[0], 6);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 6), out.val[0], 7);
    return out;
}

template <>
inline qint16x8x2_t convolve_3x3<3>(const qint8_t *in_top, const qint8_t *in_mid, const qint8_t *in_low, const qint8x8x3_t &m0, const qint8x8x3_t &m1, const qint8x8x3_t &m2, int fixed_point_position)
{
    qint16x8x2_t out = convolve_3x3<1>(in_top, in_mid, in_low, m0, m1, m2, fixed_point_position);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 3), out.val[0], 1);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[0], 6), out.val[0], 2);
    out.val[0] = vsetq_lane_s16(vgetq_lane_s16(out.val[1], 1), out.val[0], 3);
    return out;
}

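// store_results<stridex> writes back only the lanes that are valid for the given stride:
// 8, 4 or 2 F32 values (16, 8 or 4 QS16 values) per iteration for stride 1, 2 or 3 respectively.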
template <unsigned int stridex>
void store_results(float *buffer, const float32x4x2_t &values);

template <>
void store_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
    vst1q_f32(buffer + 4, values.val[1]);
}

template <>
void store_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, values.val[0]);
}

template <>
void store_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vget_low_f32(values.val[0]));
}

template <unsigned int stridex>
void store_results(qint16_t *buffer, const qint16x8x2_t &values);

template <>
void store_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, values.val[0]);
    vst1q_qs16(buffer + 8, values.val[1]);
}

template <>
void store_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, values.val[0]);
}

template <>
void store_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1_qs16(buffer, vget_low_s16(values.val[0]));
}

template <unsigned int stridex>
void accumulate_results(float *buffer, const float32x4x2_t &values);

template <>
void accumulate_results<1>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
    vst1q_f32(buffer + 4, vaddq_f32(vld1q_f32(buffer + 4), values.val[1]));
}

template <>
void accumulate_results<2>(float *buffer, const float32x4x2_t &values)
{
    vst1q_f32(buffer, vaddq_f32(vld1q_f32(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(float *buffer, const float32x4x2_t &values)
{
    vst1_f32(buffer, vadd_f32(vld1_f32(buffer), vget_low_f32(values.val[0])));
}

template <unsigned int stridex>
void accumulate_results(qint16_t *buffer, const qint16x8x2_t &values);

template <>
void accumulate_results<1>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
    vst1q_qs16(buffer + 8, vqaddq_qs16(vld1q_qs16(buffer + 8), values.val[1]));
}

template <>
void accumulate_results<2>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1q_qs16(buffer, vqaddq_qs16(vld1q_qs16(buffer), values.val[0]));
}

template <>
void accumulate_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
{
    vst1_qs16(buffer, vqadd_qs16(vld1_qs16(buffer), vget_low_s16(values.val[0])));
}

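// Number of input elements the read pointers are advanced by per iteration: the number of outputs
// written multiplied by the convolution stride along x.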
template <unsigned int stridex>
int get_input_num_elems_processed(unsigned int num_elems_written_per_iteration);

template <>
int get_input_num_elems_processed<1>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration;
}

template <>
int get_input_num_elems_processed<2>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration << 1;
}

template <>
int get_input_num_elems_processed<3>(unsigned int num_elems_written_per_iteration)
{
    return num_elems_written_per_iteration * 3;
}

template <typename T1, typename T2, unsigned int stridex>
class convolver_3x3
{
public:
    static void convolve(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
    {
        ARM_COMPUTE_UNUSED(num_elems_read_per_iteration);
        const int input_stride_x = input->info()->strides_in_bytes().x();
        const int input_stride_y = input->info()->strides_in_bytes().y();
        const int input_stride_z = input->info()->strides_in_bytes().z();
        const int output_stride_y = output->info()->strides_in_bytes().y();
        const int output_stride_z = output->info()->strides_in_bytes().z();
        const int kernel_stride_x = weights->info()->strides_in_bytes().x();
        const int kernel_stride_y = weights->info()->strides_in_bytes().y();
        const int kernel_stride_z = weights->info()->strides_in_bytes().z();
        const int kernel_stride_w = weights->info()->strides_in_bytes()[3];
        const int output_w = output->info()->dimension(0);
        const int output_h = output->info()->dimension(1);
        const int num_planes_z = window.z().end() - window.z().start();
        const int delta_input = get_input_num_elems_processed<stridex>(num_elems_written_per_iteration);
        const int kernel_depth = weights->info()->dimension(Window::DimZ);
        const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
        const unsigned int conv_pad_x = std::get<0>(conv_info.pad());
        const unsigned int conv_pad_y = std::get<1>(conv_info.pad());
        const int fixed_point_position = input->info()->fixed_point_position();

        // Set up the output window for the iterator
        Window window_out = window;
        window_out.set(Window::DimX, Window::Dimension(0, output->info()->dimension(Window::DimX), output->info()->dimension(Window::DimX)));
        window_out.set(Window::DimY, Window::Dimension(0, output->info()->dimension(Window::DimY), output->info()->dimension(Window::DimY)));
        window_out.set(Window::DimZ, Window::Dimension(window.z().start(), window.z().end(), num_planes_z));

        // Set up the input window for the iterator
        Window window_in = window;
        // We just want execute_window_loop to iterate over the higher dimensions (>3), so the first 3 dimensions are set to 0
        window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
        window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));

        Window window_k = calculate_max_window(*weights->info(), Steps(1u));

        Iterator out(output, window_out);
        Iterator in(input, window_in);
        Iterator k(weights, window_k);

        const uint8_t *k_ptr = k.ptr();

        execute_window_loop(window_out, [&](const Coordinates & id)
        {
            const uint8_t *input_ptr = in.ptr() - conv_pad_x * input_stride_x - conv_pad_y * input_stride_y;
            uint8_t *out_ptr = out.ptr();
            int ih = 0;
            int oh = 0;
            /*
              Each thread executing this kernel computes one or more planes of the output volume.

              For example, if the 3rd dimension of the output volume is 32, the first thread computes the output for Z = [0,7], the second thread for Z = [8,15],
              the third thread for Z = [16,23] and the fourth thread for Z = [24,31].

              The outer loops of the algorithm iterate over Z, P, Y, X, where P is the depth (3rd dimension) of each kernel. This order is not arbitrary: its main benefit
              is that the NEON registers holding the kernel's values are set up only once and are then reused for every XY position, as opposed to reloading them for every XY value.

              The algorithm does not require allocating any additional memory and computes the results directly in place in two stages:
              1) Convolve plane 0 with kernel 0 and initialize the corresponding output plane with these values.
              2) Convolve the remaining planes and accumulate the results into the output plane initialized in step 1.
            */

            for(int oz = 0; oz < num_planes_z; ++oz)
            {
                uint8_t *p_out_base = out_ptr + oz * output_stride_z;
                // Step 1
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + 0 * kernel_stride_z + (id.z() + oz) * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto vk_r0 = load_matrix_row(ptr_k_r0);
                    const auto vk_r1 = load_matrix_row(ptr_k_r1);
                    const auto vk_r2 = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + 0 * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            store_results<stridex>(p_out, vres);
                        }
                    }
                }
                // Step 2
                for(int p = 1; p < kernel_depth; ++p)
                {
                    const auto ptr_k_r0 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w + 0 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r1 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w + 1 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto ptr_k_r2 = reinterpret_cast<const T1 *>(k_ptr + p * kernel_stride_z + (id.z() + oz) * kernel_stride_w + 2 * kernel_stride_y + 0 * kernel_stride_x);
                    const auto vk_r0 = load_matrix_row(ptr_k_r0);
                    const auto vk_r1 = load_matrix_row(ptr_k_r1);
                    const auto vk_r2 = load_matrix_row(ptr_k_r2);
                    for(ih = 0, oh = 0; oh < output_h; ++oh, ih += conv_stride_y)
                    {
                        auto in_top = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 0) * input_stride_y);
                        auto in_mid = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 1) * input_stride_y);
                        auto in_low = reinterpret_cast<const T1 *>(input_ptr + p * input_stride_z + (ih + 2) * input_stride_y);
                        auto p_out = reinterpret_cast<T2 *>(p_out_base + oh * output_stride_y);
                        for(int ow = 0; ow < output_w; ow += num_elems_written_per_iteration,
                            in_top += delta_input, in_mid += delta_input, in_low += delta_input, p_out += num_elems_written_per_iteration)
                        {
                            auto vres = convolve_3x3<stridex>(in_top, in_mid, in_low, vk_r0, vk_r1, vk_r2, fixed_point_position);
                            accumulate_results<stridex>(p_out, vres);
                        }
                    }
                }
            }
        },
        in, out);
    }
};

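// The convolve_1x1 / convolve_3x3 helpers below dispatch to the stride-specific instantiation of the
// corresponding convolver class, based on the convolution stride along x.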
template <typename T1, typename T2>
inline void convolve_1x1(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_1x1<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_1x1<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_1x1<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}

template <typename T1, typename T2>
inline void convolve_3x3(const Window &window, unsigned int num_elems_read_per_iteration, unsigned int num_elems_written_per_iteration,
                         const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    switch(conv_stride_x)
    {
        case 1:
            convolver_3x3<T1, T2, 1>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 2:
            convolver_3x3<T1, T2, 2>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        case 3:
            convolver_3x3<T1, T2, 3>::convolve(window, num_elems_read_per_iteration, num_elems_written_per_iteration, input, weights, output, conv_info);
            break;
        default:
            ARM_COMPUTE_ERROR("Not implemented");
    }
}
} // namespace

NEDirectConvolutionLayerKernel::NEDirectConvolutionLayerKernel()
    : _input(nullptr), _weights(nullptr), _output(nullptr), _conv_info(), _border_size(0), _kernel_size(0), _num_elems_read_per_iteration(0), _num_elems_written_per_iteration(0)
{
}

BorderSize NEDirectConvolutionLayerKernel::border_size() const
{
    return _border_size;
}

void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
{
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::F32);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QS8, DataType::F32);
    ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS16, DataType::F32);
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 1 && (std::get<0>(conv_info.pad()) || std::get<1>(conv_info.pad())),
                             "Pad > 0 not supported for 1x1 weights");
    ARM_COMPUTE_ERROR_ON_MSG(weights->info()->dimension(0) == 3 && (std::get<0>(conv_info.pad()) > 1 || std::get<1>(conv_info.pad()) > 1),
                             "Pad > 1 not supported for 3x3 weights");
    ARM_COMPUTE_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");

    const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
    const unsigned int conv_pad_x = std::get<0>(conv_info.pad());
    const unsigned int conv_pad_y = std::get<1>(conv_info.pad());

    _input = input;
    _weights = weights;
    _output = output;
    _conv_info = conv_info;
    _kernel_size = weights->info()->dimension(0);
    _border_size = BorderSize(conv_pad_y, conv_pad_x);

    Window win = calculate_max_window(*output->info());

    switch(_kernel_size)
    {
        case 1:
        {
            _num_elems_written_per_iteration = (input->info()->data_type() == DataType::QS8) ? 8 : 4;
            _num_elems_read_per_iteration = conv_stride_x * _num_elems_written_per_iteration;

            win = calculate_max_window(*output->info(), Steps(_num_elems_written_per_iteration));
            AccessWindowHorizontal input_access(input->info(), 0, _num_elems_read_per_iteration);
            AccessWindowHorizontal output_access(output->info(), 0, _num_elems_written_per_iteration);
            update_window_and_padding(win, input_access, output_access);
            output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
            break;
        }
        case 3:
        {
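            // Elements processed per iteration for the 3x3 case:
            // F32: 12 input elements are read along x per iteration and 16 >> conv_stride_x outputs are written
            //      (8 for stride 1, 4 for stride 2, 2 for stride 3), matching store_results<stridex>.
            // QS8: 24 input elements are read along x per iteration and 32 >> conv_stride_x outputs are written.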
            if(input->info()->data_type() == DataType::F32)
            {
                _num_elems_read_per_iteration = 12;
                _num_elems_written_per_iteration = 16 >> conv_stride_x;
            }
            else
            {
                _num_elems_read_per_iteration = 24;
                _num_elems_written_per_iteration = 32 >> conv_stride_x;
            }

            // Calculate right and bottom border
            const unsigned int conv_stride_y = std::get<1>(_conv_info.stride());
            const int input_width = input->info()->dimension(0);
            const int input_height = input->info()->dimension(1);
            const int upper_bound_w = ceil_to_multiple(((output->info()->dimension(0) - 1) * conv_stride_x + _kernel_size), _num_elems_read_per_iteration) - conv_pad_x - input_width;
            const int upper_bound_h = ((output->info()->dimension(1) - 1) * conv_stride_y - conv_pad_y + _kernel_size) - input_height;
            _border_size.right = std::max(upper_bound_w, static_cast<int>(_kernel_size));
            _border_size.bottom = std::max(upper_bound_h, static_cast<int>(_kernel_size));

            // Create window and update padding
            win = calculate_max_window(*output->info(), Steps(_num_elems_written_per_iteration));
            AccessWindowStatic input_access(input->info(), -conv_pad_x, -conv_pad_y, input_width + _border_size.right, input_height + _border_size.bottom);
            AccessWindowStatic weights_access(weights->info(), 0, 0, _kernel_size, _kernel_size);
            AccessWindowHorizontal output_access(output->info(), 0, _num_elems_written_per_iteration);
            update_window_and_padding(win, input_access, weights_access, output_access);
            output_access.set_valid_region(win, ValidRegion(Coordinates(), output->info()->tensor_shape()));
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("Not implemented");
            break;
        }
    }

    INEKernel::configure(win);
}

void NEDirectConvolutionLayerKernel::run(const Window &window)
{
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);

    const int kernel_size = _weights->info()->dimension(0);

    switch(kernel_size)
    {
        case 1:
        {
            if(_input->info()->data_type() == DataType::QS8)
            {
                convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
            }
            else
            {
                convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
            }
            break;
        }
        case 3:
        {
            if(_input->info()->data_type() == DataType::QS8)
            {
                convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
            }
            else
            {
                convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
            }
            break;
        }
        default:
        {
            ARM_COMPUTE_ERROR("Only kernel sizes 1x1 and 3x3 are supported.");
            break;
        }
    }
}